From 1105e1a8abee8ee26057b637a0c0bb07723ce7bc Mon Sep 17 00:00:00 2001 From: Tom <73077675+tmzane@users.noreply.github.com> Date: Sun, 8 Oct 2023 23:51:23 +0300 Subject: [PATCH] fix: do not report types implementing `(Un)Marshaler` (#67) --- builtins.go | 154 +++++++++++++++++++++++++----- musttag.go | 67 ++++++++++++- testdata/src/builtins/builtins.go | 70 ++++++++++++++ 3 files changed, 265 insertions(+), 26 deletions(-) diff --git a/builtins.go b/builtins.go index 66914fa..db86317 100644 --- a/builtins.go +++ b/builtins.go @@ -3,34 +3,144 @@ package musttag // builtins is a set of functions supported out of the box. var builtins = []Func{ // https://pkg.go.dev/encoding/json - {Name: "encoding/json.Marshal", Tag: "json", ArgPos: 0}, - {Name: "encoding/json.MarshalIndent", Tag: "json", ArgPos: 0}, - {Name: "encoding/json.Unmarshal", Tag: "json", ArgPos: 1}, - {Name: "(*encoding/json.Encoder).Encode", Tag: "json", ArgPos: 0}, - {Name: "(*encoding/json.Decoder).Decode", Tag: "json", ArgPos: 0}, + { + Name: "encoding/json.Marshal", + Tag: "json", + ArgPos: 0, + ifaceWhitelist: []string{"encoding/json.Marshaler", "encoding.TextMarshaler"}, + }, + { + Name: "encoding/json.MarshalIndent", + Tag: "json", + ArgPos: 0, + ifaceWhitelist: []string{"encoding/json.Marshaler", "encoding.TextMarshaler"}, + }, + { + Name: "encoding/json.Unmarshal", + Tag: "json", + ArgPos: 1, + ifaceWhitelist: []string{"encoding/json.Unmarshaler", "encoding.TextUnmarshaler"}, + }, + { + Name: "(*encoding/json.Encoder).Encode", + Tag: "json", + ArgPos: 0, + ifaceWhitelist: []string{"encoding/json.Marshaler", "encoding.TextMarshaler"}, + }, + { + Name: "(*encoding/json.Decoder).Decode", + Tag: "json", + ArgPos: 0, + ifaceWhitelist: []string{"encoding/json.Unmarshaler", "encoding.TextUnmarshaler"}, + }, // https://pkg.go.dev/encoding/xml - {Name: "encoding/xml.Marshal", Tag: "xml", ArgPos: 0}, - {Name: "encoding/xml.MarshalIndent", Tag: "xml", ArgPos: 0}, - {Name: "encoding/xml.Unmarshal", Tag: "xml", ArgPos: 1}, - {Name: "(*encoding/xml.Encoder).Encode", Tag: "xml", ArgPos: 0}, - {Name: "(*encoding/xml.Decoder).Decode", Tag: "xml", ArgPos: 0}, - {Name: "(*encoding/xml.Encoder).EncodeElement", Tag: "xml", ArgPos: 0}, - {Name: "(*encoding/xml.Decoder).DecodeElement", Tag: "xml", ArgPos: 0}, + { + Name: "encoding/xml.Marshal", + Tag: "xml", + ArgPos: 0, + ifaceWhitelist: []string{"encoding/xml.Marshaler", "encoding.TextMarshaler"}, + }, + { + Name: "encoding/xml.MarshalIndent", + Tag: "xml", + ArgPos: 0, + ifaceWhitelist: []string{"encoding/xml.Marshaler", "encoding.TextMarshaler"}, + }, + { + Name: "encoding/xml.Unmarshal", + Tag: "xml", + ArgPos: 1, + ifaceWhitelist: []string{"encoding/xml.Unmarshaler", "encoding.TextUnmarshaler"}, + }, + { + Name: "(*encoding/xml.Encoder).Encode", + Tag: "xml", + ArgPos: 0, + ifaceWhitelist: []string{"encoding/xml.Marshaler", "encoding.TextMarshaler"}, + }, + { + Name: "(*encoding/xml.Decoder).Decode", + Tag: "xml", + ArgPos: 0, + ifaceWhitelist: []string{"encoding/xml.Unmarshaler", "encoding.TextUnmarshaler"}, + }, + { + Name: "(*encoding/xml.Encoder).EncodeElement", + Tag: "xml", + ArgPos: 0, + ifaceWhitelist: []string{"encoding/xml.Marshaler", "encoding.TextMarshaler"}, + }, + { + Name: "(*encoding/xml.Decoder).DecodeElement", + Tag: "xml", + ArgPos: 0, + ifaceWhitelist: []string{"encoding/xml.Unmarshaler", "encoding.TextUnmarshaler"}, + }, // https://github.com/go-yaml/yaml - {Name: "gopkg.in/yaml.v3.Marshal", Tag: "yaml", ArgPos: 0}, - {Name: "gopkg.in/yaml.v3.Unmarshal", Tag: "yaml", ArgPos: 1}, - {Name: "(*gopkg.in/yaml.v3.Encoder).Encode", Tag: "yaml", ArgPos: 0}, - {Name: "(*gopkg.in/yaml.v3.Decoder).Decode", Tag: "yaml", ArgPos: 0}, + { + Name: "gopkg.in/yaml.v3.Marshal", + Tag: "yaml", + ArgPos: 0, + ifaceWhitelist: []string{"gopkg.in/yaml.v3.Marshaler"}, + }, + { + Name: "gopkg.in/yaml.v3.Unmarshal", + Tag: "yaml", + ArgPos: 1, + ifaceWhitelist: []string{"gopkg.in/yaml.v3.Unmarshaler"}, + }, + { + Name: "(*gopkg.in/yaml.v3.Encoder).Encode", + Tag: "yaml", + ArgPos: 0, + ifaceWhitelist: []string{"gopkg.in/yaml.v3.Marshaler"}, + }, + { + Name: "(*gopkg.in/yaml.v3.Decoder).Decode", + Tag: "yaml", + ArgPos: 0, + ifaceWhitelist: []string{"gopkg.in/yaml.v3.Unmarshaler"}, + }, // https://github.com/BurntSushi/toml - {Name: "github.com/BurntSushi/toml.Unmarshal", Tag: "toml", ArgPos: 1}, - {Name: "github.com/BurntSushi/toml.Decode", Tag: "toml", ArgPos: 1}, - {Name: "github.com/BurntSushi/toml.DecodeFS", Tag: "toml", ArgPos: 2}, - {Name: "github.com/BurntSushi/toml.DecodeFile", Tag: "toml", ArgPos: 1}, - {Name: "(*github.com/BurntSushi/toml.Encoder).Encode", Tag: "toml", ArgPos: 0}, - {Name: "(*github.com/BurntSushi/toml.Decoder).Decode", Tag: "toml", ArgPos: 0}, + { + Name: "github.com/BurntSushi/toml.Unmarshal", + Tag: "toml", + ArgPos: 1, + ifaceWhitelist: []string{"github.com/BurntSushi/toml.Unmarshaler", "encoding.TextUnmarshaler"}, + }, + { + Name: "github.com/BurntSushi/toml.Decode", + Tag: "toml", + ArgPos: 1, + ifaceWhitelist: []string{"github.com/BurntSushi/toml.Unmarshaler", "encoding.TextUnmarshaler"}, + }, + { + Name: "github.com/BurntSushi/toml.DecodeFS", + Tag: "toml", + ArgPos: 2, + ifaceWhitelist: []string{"github.com/BurntSushi/toml.Unmarshaler", "encoding.TextUnmarshaler"}, + }, + { + Name: "github.com/BurntSushi/toml.DecodeFile", + Tag: "toml", + ArgPos: 1, + ifaceWhitelist: []string{"github.com/BurntSushi/toml.Unmarshaler", "encoding.TextUnmarshaler"}, + }, + { + Name: "(*github.com/BurntSushi/toml.Encoder).Encode", + Tag: "toml", + ArgPos: 0, + ifaceWhitelist: []string{"encoding.TextMarshaler"}, + }, + { + Name: "(*github.com/BurntSushi/toml.Decoder).Decode", + Tag: "toml", + ArgPos: 0, + ifaceWhitelist: []string{"github.com/BurntSushi/toml.Unmarshaler", "encoding.TextUnmarshaler"}, + }, // https://github.com/mitchellh/mapstructure {Name: "github.com/mitchellh/mapstructure.Decode", Tag: "mapstructure", ArgPos: 1}, diff --git a/musttag.go b/musttag.go index 7f4e05e..248d11c 100644 --- a/musttag.go +++ b/musttag.go @@ -24,6 +24,10 @@ type Func struct { Name string // Name is the full name of the function, including the package. Tag string // Tag is the struct tag whose presence should be ensured. ArgPos int // ArgPos is the position of the argument to check. + + // a list of interface names (including the package); + // if at least one is implemented by the argument, no check is performed. + ifaceWhitelist []string } func (fn Func) shortName() string { @@ -93,7 +97,7 @@ var report = func(pass *analysis.Pass, st *structType, fn Func, fnPos token.Posi pass.Reportf(st.Pos, format, st.Name, fn.Tag, fn.shortName(), fnPos) } -var cleanFullName = regexp.MustCompile(`([^*/(]+/vendor/)`) +var trimVendor = regexp.MustCompile(`([^*/(]+/vendor/)`) // run starts the analysis. func run(pass *analysis.Pass, mainModule string, funcs map[string]Func) (any, error) { @@ -117,7 +121,7 @@ func run(pass *analysis.Pass, mainModule string, funcs map[string]Func) (any, er return // not a static call. } - name := cleanFullName.ReplaceAllString(callee.FullName(), "") + name := trimVendor.ReplaceAllString(callee.FullName(), "") fn, ok := funcs[name] if !ok { return // the function is not supported. @@ -144,13 +148,21 @@ func run(pass *analysis.Pass, mainModule string, funcs map[string]Func) (any, er initialPos = arg.Pos() } + typ := pass.TypesInfo.TypeOf(arg) + if typ == nil { + return // no type info found. + } + + if implementsInterface(typ, fn.ifaceWhitelist, pass.Pkg.Imports()) { + return // the type implements a Marshaler interface, nothing to check; see issue #64. + } + checker := checker{ mainModule: mainModule, seenTypes: make(map[string]struct{}), } - t := pass.TypesInfo.TypeOf(arg) - st, ok := checker.parseStructType(t, initialPos) + st, ok := checker.parseStructType(typ, initialPos) if !ok { return // not a struct argument. } @@ -257,3 +269,50 @@ func (c *checker) checkStructType(st *structType, tag string) (*structType, bool return nil, true } + +func implementsInterface(typ types.Type, ifaces []string, imports []*types.Package) bool { + findScope := func(pkgName string) (*types.Scope, bool) { + // fast path: check direct imports (e.g. looking for "encoding/json.Marshaler"). + for _, direct := range imports { + if pkgName == trimVendor.ReplaceAllString(direct.Path(), "") { + return direct.Scope(), true + } + } + // slow path: check indirect imports (e.g. looking for "encoding.TextMarshaler"). + for _, direct := range imports { + for _, indirect := range direct.Imports() { + if pkgName == trimVendor.ReplaceAllString(indirect.Path(), "") { + return indirect.Scope(), true + } + } + } + return nil, false + } + + for _, ifacePath := range ifaces { + // "encoding/json.Marshaler" -> "encoding/json" + "Marshaler" + idx := strings.LastIndex(ifacePath, ".") + if idx == -1 { + continue + } + pkgName, ifaceName := ifacePath[:idx], ifacePath[idx+1:] + + scope, ok := findScope(pkgName) + if !ok { + continue + } + obj := scope.Lookup(ifaceName) + if obj == nil { + continue + } + iface, ok := obj.Type().Underlying().(*types.Interface) + if !ok { + continue + } + if types.Implements(typ, iface) { + return true + } + } + + return false +} diff --git a/testdata/src/builtins/builtins.go b/testdata/src/builtins/builtins.go index 9d888d6..60e3c7a 100644 --- a/testdata/src/builtins/builtins.go +++ b/testdata/src/builtins/builtins.go @@ -76,6 +76,23 @@ type User struct { /* want Email string `json:"email" xml:"email" yaml:"email" toml:"email" mapstructure:"email" db:"email" custom:"email"` } +// TODO: Unmarshaler should be implemented using pointer semantics. + +type TextMarshaler struct{ NoTag string } + +func (TextMarshaler) MarshalText() ([]byte, error) { return nil, nil } +func (TextMarshaler) UnmarshalText([]byte) error { return nil } + +type Marshaler struct{ NoTag string } + +func (Marshaler) MarshalJSON() ([]byte, error) { return nil, nil } +func (Marshaler) UnmarshalJSON([]byte) error { return nil } +func (Marshaler) MarshalXML(e *xml.Encoder, start xml.StartElement) error { return nil } +func (Marshaler) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { return nil } +func (Marshaler) MarshalYAML() (any, error) { return nil, nil } +func (Marshaler) UnmarshalYAML(*yaml.Node) error { return nil } +func (Marshaler) UnmarshalTOML(any) error { return nil } + func testJSON() { var user User json.Marshal(user) @@ -83,6 +100,20 @@ func testJSON() { json.Unmarshal(nil, &user) json.NewEncoder(nil).Encode(user) json.NewDecoder(nil).Decode(&user) + + var m Marshaler + json.Marshal(m) + json.MarshalIndent(m, "", "") + json.Unmarshal(nil, &m) + json.NewEncoder(nil).Encode(m) + json.NewDecoder(nil).Decode(&m) + + var tm TextMarshaler + json.Marshal(tm) + json.MarshalIndent(tm, "", "") + json.Unmarshal(nil, &tm) + json.NewEncoder(nil).Encode(tm) + json.NewDecoder(nil).Decode(&tm) } func testXML() { @@ -94,6 +125,24 @@ func testXML() { xml.NewDecoder(nil).Decode(&user) xml.NewEncoder(nil).EncodeElement(user, xml.StartElement{}) xml.NewDecoder(nil).DecodeElement(&user, &xml.StartElement{}) + + var m Marshaler + xml.Marshal(m) + xml.MarshalIndent(m, "", "") + xml.Unmarshal(nil, &m) + xml.NewEncoder(nil).Encode(m) + xml.NewDecoder(nil).Decode(&m) + xml.NewEncoder(nil).EncodeElement(m, xml.StartElement{}) + xml.NewDecoder(nil).DecodeElement(&m, &xml.StartElement{}) + + var tm TextMarshaler + xml.Marshal(tm) + xml.MarshalIndent(tm, "", "") + xml.Unmarshal(nil, &tm) + xml.NewEncoder(nil).Encode(tm) + xml.NewDecoder(nil).Decode(&tm) + xml.NewEncoder(nil).EncodeElement(tm, xml.StartElement{}) + xml.NewDecoder(nil).DecodeElement(&tm, &xml.StartElement{}) } func testYAML() { @@ -102,6 +151,12 @@ func testYAML() { yaml.Unmarshal(nil, &user) yaml.NewEncoder(nil).Encode(user) yaml.NewDecoder(nil).Decode(&user) + + var m Marshaler + yaml.Marshal(m) + yaml.Unmarshal(nil, &m) + yaml.NewEncoder(nil).Encode(m) + yaml.NewDecoder(nil).Decode(&m) } func testTOML() { @@ -112,6 +167,21 @@ func testTOML() { toml.DecodeFile("", &user) toml.NewEncoder(nil).Encode(user) toml.NewDecoder(nil).Decode(&user) + + var m Marshaler + toml.Unmarshal(nil, &m) + toml.Decode("", &m) + toml.DecodeFS(nil, "", &m) + toml.DecodeFile("", &m) + toml.NewDecoder(nil).Decode(&m) + + var tm TextMarshaler + toml.Unmarshal(nil, &tm) + toml.Decode("", &tm) + toml.DecodeFS(nil, "", &tm) + toml.DecodeFile("", &tm) + toml.NewEncoder(nil).Encode(tm) + toml.NewDecoder(nil).Decode(&tm) } func testMapstructure() {