Skip to content

Commit

Permalink
fix: do not report types implementing (Un)Marshaler
Browse files Browse the repository at this point in the history
  • Loading branch information
tmzane committed Oct 7, 2023
1 parent 11f0e6c commit 799a478
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 3 deletions.
2 changes: 1 addition & 1 deletion builtins.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package musttag
// builtins is a set of functions supported out of the box.
var builtins = []Func{
// https://pkg.go.dev/encoding/json
{Name: "encoding/json.Marshal", Tag: "json", ArgPos: 0},
{Name: "encoding/json.Marshal", Tag: "json", ArgPos: 0, ifaceWhitelist: []string{"Marshaler"}},
{Name: "encoding/json.MarshalIndent", Tag: "json", ArgPos: 0},
{Name: "encoding/json.Unmarshal", Tag: "json", ArgPos: 1},
{Name: "(*encoding/json.Encoder).Encode", Tag: "json", ArgPos: 0},
Expand Down
41 changes: 39 additions & 2 deletions musttag.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,25 @@ type Func struct {
Name string // Name is the full name of the function, including the package.
Tag string // Tag is the struct tag whose presence should be ensured.
ArgPos int // ArgPos is the position of the argument to check.

// a list of interfaces from the same package;
// if at least one is implemented by the argument, no check is performed.
ifaceWhitelist []string
}

func (fn Func) shortName() string {
name := strings.NewReplacer("*", "", "(", "", ")", "").Replace(fn.Name)
return path.Base(name)
}

func (fn Func) pkgPath() string {
name := strings.NewReplacer("*", "", "(", "", ")", "").Replace(fn.Name)
if idx := strings.LastIndex(name, "."); idx != -1 {
return name[:idx]
}
return ""
}

// New creates a new musttag analyzer.
// To report a custom function provide its description via Func,
// it will be added to the builtin ones.
Expand Down Expand Up @@ -144,13 +156,38 @@ func run(pass *analysis.Pass, mainModule string, funcs map[string]Func) (any, er
initialPos = arg.Pos()
}

argType := pass.TypesInfo.TypeOf(arg)
if argType == nil {
return // no type info found.
}

for _, pkg := range pass.Pkg.Imports() {
if pkg.Path() != fn.pkgPath() {
continue
}
for _, ifaceName := range fn.ifaceWhitelist {
obj := pkg.Scope().Lookup(ifaceName)
if obj == nil {
continue
}
iface, ok := obj.Type().Underlying().(*types.Interface)
if !ok {
continue
}
if types.Implements(argType, iface) {
pass.Reportf(initialPos, "implements %s", iface)
return // the argument implements an (Un)Marshaler interface, no need to check; see issue #64.
}
}
break
}

checker := checker{
mainModule: mainModule,
seenTypes: make(map[string]struct{}),
}

t := pass.TypesInfo.TypeOf(arg)
st, ok := checker.parseStructType(t, initialPos)
st, ok := checker.parseStructType(argType, initialPos)
if !ok {
return // not a struct argument.
}
Expand Down
10 changes: 10 additions & 0 deletions testdata/src/tests/tests.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,13 @@ func nothingToReport() {
json.NewEncoder(nil).Encode(Foo{})
json.NewDecoder(nil).Decode(&Foo{})
}

type marshaler struct{}

func (marshaler) MarshalJSON() ([]byte, error) { return nil, nil }

func implementsInterface() {
var m marshaler
json.Marshal(m)
json.Marshal(marshaler{})
}

0 comments on commit 799a478

Please sign in to comment.