diff --git a/builtins.go b/builtins.go index 3305513..60fa894 100644 --- a/builtins.go +++ b/builtins.go @@ -3,131 +3,65 @@ package musttag // builtins is a set of functions supported out of the box. var builtins = []Func{ // https://pkg.go.dev/encoding/json - { - Name: "encoding/json.Marshal", Tag: "json", ArgPos: 0, - ifaceWhitelist: []string{"encoding/json.Marshaler", "encoding.TextMarshaler"}, - }, - { - Name: "encoding/json.MarshalIndent", Tag: "json", ArgPos: 0, - ifaceWhitelist: []string{"encoding/json.Marshaler", "encoding.TextMarshaler"}, - }, - { - Name: "encoding/json.Unmarshal", Tag: "json", ArgPos: 1, - ifaceWhitelist: []string{"encoding/json.Unmarshaler", "encoding.TextUnmarshaler"}, - }, - { - Name: "(*encoding/json.Encoder).Encode", Tag: "json", ArgPos: 0, - ifaceWhitelist: []string{"encoding/json.Marshaler", "encoding.TextMarshaler"}, - }, - { - Name: "(*encoding/json.Decoder).Decode", Tag: "json", ArgPos: 0, - ifaceWhitelist: []string{"encoding/json.Unmarshaler", "encoding.TextUnmarshaler"}, - }, + {"encoding/json.Marshal", "json", 0, []string{"encoding/json.Marshaler", "encoding.TextMarshaler"}}, + {"encoding/json.MarshalIndent", "json", 0, []string{"encoding/json.Marshaler", "encoding.TextMarshaler"}}, + {"encoding/json.Unmarshal", "json", 1, []string{"encoding/json.Unmarshaler", "encoding.TextUnmarshaler"}}, + {"(*encoding/json.Encoder).Encode", "json", 0, []string{"encoding/json.Marshaler", "encoding.TextMarshaler"}}, + {"(*encoding/json.Decoder).Decode", "json", 0, []string{"encoding/json.Unmarshaler", "encoding.TextUnmarshaler"}}, // https://pkg.go.dev/encoding/xml - { - Name: "encoding/xml.Marshal", Tag: "xml", ArgPos: 0, - ifaceWhitelist: []string{"encoding/xml.Marshaler", "encoding.TextMarshaler"}, - }, - { - Name: "encoding/xml.MarshalIndent", Tag: "xml", ArgPos: 0, - ifaceWhitelist: []string{"encoding/xml.Marshaler", "encoding.TextMarshaler"}, - }, - { - Name: "encoding/xml.Unmarshal", Tag: "xml", ArgPos: 1, - ifaceWhitelist: []string{"encoding/xml.Unmarshaler", "encoding.TextUnmarshaler"}, - }, - { - Name: "(*encoding/xml.Encoder).Encode", Tag: "xml", ArgPos: 0, - ifaceWhitelist: []string{"encoding/xml.Marshaler", "encoding.TextMarshaler"}, - }, - { - Name: "(*encoding/xml.Decoder).Decode", Tag: "xml", ArgPos: 0, - ifaceWhitelist: []string{"encoding/xml.Unmarshaler", "encoding.TextUnmarshaler"}, - }, - { - Name: "(*encoding/xml.Encoder).EncodeElement", Tag: "xml", ArgPos: 0, - ifaceWhitelist: []string{"encoding/xml.Marshaler", "encoding.TextMarshaler"}, - }, - { - Name: "(*encoding/xml.Decoder).DecodeElement", Tag: "xml", ArgPos: 0, - ifaceWhitelist: []string{"encoding/xml.Unmarshaler", "encoding.TextUnmarshaler"}, - }, + {"encoding/xml.Marshal", "xml", 0, []string{"encoding/xml.Marshaler", "encoding.TextMarshaler"}}, + {"encoding/xml.MarshalIndent", "xml", 0, []string{"encoding/xml.Marshaler", "encoding.TextMarshaler"}}, + {"encoding/xml.Unmarshal", "xml", 1, []string{"encoding/xml.Unmarshaler", "encoding.TextUnmarshaler"}}, + {"(*encoding/xml.Encoder).Encode", "xml", 0, []string{"encoding/xml.Marshaler", "encoding.TextMarshaler"}}, + {"(*encoding/xml.Decoder).Decode", "xml", 0, []string{"encoding/xml.Unmarshaler", "encoding.TextUnmarshaler"}}, + {"(*encoding/xml.Encoder).EncodeElement", "xml", 0, []string{"encoding/xml.Marshaler", "encoding.TextMarshaler"}}, + {"(*encoding/xml.Decoder).DecodeElement", "xml", 0, []string{"encoding/xml.Unmarshaler", "encoding.TextUnmarshaler"}}, // https://pkg.go.dev/gopkg.in/yaml.v3 - { - Name: "gopkg.in/yaml.v3.Marshal", Tag: "yaml", ArgPos: 0, - ifaceWhitelist: []string{"gopkg.in/yaml.v3.Marshaler"}, - }, - { - Name: "gopkg.in/yaml.v3.Unmarshal", Tag: "yaml", ArgPos: 1, - ifaceWhitelist: []string{"gopkg.in/yaml.v3.Unmarshaler"}, - }, - { - Name: "(*gopkg.in/yaml.v3.Encoder).Encode", Tag: "yaml", ArgPos: 0, - ifaceWhitelist: []string{"gopkg.in/yaml.v3.Marshaler"}, - }, - { - Name: "(*gopkg.in/yaml.v3.Decoder).Decode", Tag: "yaml", ArgPos: 0, - ifaceWhitelist: []string{"gopkg.in/yaml.v3.Unmarshaler"}, - }, + {"gopkg.in/yaml.v3.Marshal", "yaml", 0, []string{"gopkg.in/yaml.v3.Marshaler"}}, + {"gopkg.in/yaml.v3.Unmarshal", "yaml", 1, []string{"gopkg.in/yaml.v3.Unmarshaler"}}, + {"(*gopkg.in/yaml.v3.Encoder).Encode", "yaml", 0, []string{"gopkg.in/yaml.v3.Marshaler"}}, + {"(*gopkg.in/yaml.v3.Decoder).Decode", "yaml", 0, []string{"gopkg.in/yaml.v3.Unmarshaler"}}, // https://pkg.go.dev/github.com/BurntSushi/toml - { - Name: "github.com/BurntSushi/toml.Unmarshal", Tag: "toml", ArgPos: 1, - ifaceWhitelist: []string{"github.com/BurntSushi/toml.Unmarshaler", "encoding.TextUnmarshaler"}, - }, - { - Name: "github.com/BurntSushi/toml.Decode", Tag: "toml", ArgPos: 1, - ifaceWhitelist: []string{"github.com/BurntSushi/toml.Unmarshaler", "encoding.TextUnmarshaler"}, - }, - { - Name: "github.com/BurntSushi/toml.DecodeFS", Tag: "toml", ArgPos: 2, - ifaceWhitelist: []string{"github.com/BurntSushi/toml.Unmarshaler", "encoding.TextUnmarshaler"}, - }, - { - Name: "github.com/BurntSushi/toml.DecodeFile", Tag: "toml", ArgPos: 1, - ifaceWhitelist: []string{"github.com/BurntSushi/toml.Unmarshaler", "encoding.TextUnmarshaler"}, - }, - { - Name: "(*github.com/BurntSushi/toml.Encoder).Encode", Tag: "toml", ArgPos: 0, - ifaceWhitelist: []string{"encoding.TextMarshaler"}, - }, - { - Name: "(*github.com/BurntSushi/toml.Decoder).Decode", Tag: "toml", ArgPos: 0, - ifaceWhitelist: []string{"github.com/BurntSushi/toml.Unmarshaler", "encoding.TextUnmarshaler"}, - }, + {"github.com/BurntSushi/toml.Unmarshal", "toml", 1, []string{"github.com/BurntSushi/toml.Unmarshaler", "encoding.TextUnmarshaler"}}, + {"github.com/BurntSushi/toml.Decode", "toml", 1, []string{"github.com/BurntSushi/toml.Unmarshaler", "encoding.TextUnmarshaler"}}, + {"github.com/BurntSushi/toml.DecodeFS", "toml", 2, []string{"github.com/BurntSushi/toml.Unmarshaler", "encoding.TextUnmarshaler"}}, + {"github.com/BurntSushi/toml.DecodeFile", "toml", 1, []string{"github.com/BurntSushi/toml.Unmarshaler", "encoding.TextUnmarshaler"}}, + {"(*github.com/BurntSushi/toml.Encoder).Encode", "toml", 0, []string{"encoding.TextMarshaler"}}, + {"(*github.com/BurntSushi/toml.Decoder).Decode", "toml", 0, []string{"github.com/BurntSushi/toml.Unmarshaler", "encoding.TextUnmarshaler"}}, // https://pkg.go.dev/github.com/mitchellh/mapstructure - {Name: "github.com/mitchellh/mapstructure.Decode", Tag: "mapstructure", ArgPos: 1}, - {Name: "github.com/mitchellh/mapstructure.DecodeMetadata", Tag: "mapstructure", ArgPos: 1}, - {Name: "github.com/mitchellh/mapstructure.WeakDecode", Tag: "mapstructure", ArgPos: 1}, - {Name: "github.com/mitchellh/mapstructure.WeakDecodeMetadata", Tag: "mapstructure", ArgPos: 1}, + {"github.com/mitchellh/mapstructure.Decode", "mapstructure", 1, nil}, + {"github.com/mitchellh/mapstructure.DecodeMetadata", "mapstructure", 1, nil}, + {"github.com/mitchellh/mapstructure.WeakDecode", "mapstructure", 1, nil}, + {"github.com/mitchellh/mapstructure.WeakDecodeMetadata", "mapstructure", 1, nil}, // https://pkg.go.dev/github.com/jmoiron/sqlx - {Name: "github.com/jmoiron/sqlx.Get", Tag: "db", ArgPos: 1}, - {Name: "github.com/jmoiron/sqlx.GetContext", Tag: "db", ArgPos: 2}, - {Name: "github.com/jmoiron/sqlx.Select", Tag: "db", ArgPos: 1}, - {Name: "github.com/jmoiron/sqlx.SelectContext", Tag: "db", ArgPos: 2}, - {Name: "github.com/jmoiron/sqlx.StructScan", Tag: "db", ArgPos: 1}, - {Name: "(*github.com/jmoiron/sqlx.Conn).GetContext", Tag: "db", ArgPos: 1}, - {Name: "(*github.com/jmoiron/sqlx.Conn).SelectContext", Tag: "db", ArgPos: 1}, - {Name: "(*github.com/jmoiron/sqlx.DB).Get", Tag: "db", ArgPos: 0}, - {Name: "(*github.com/jmoiron/sqlx.DB).GetContext", Tag: "db", ArgPos: 1}, - {Name: "(*github.com/jmoiron/sqlx.DB).Select", Tag: "db", ArgPos: 0}, - {Name: "(*github.com/jmoiron/sqlx.DB).SelectContext", Tag: "db", ArgPos: 1}, - {Name: "(*github.com/jmoiron/sqlx.NamedStmt).Get", Tag: "db", ArgPos: 0}, - {Name: "(*github.com/jmoiron/sqlx.NamedStmt).GetContext", Tag: "db", ArgPos: 1}, - {Name: "(*github.com/jmoiron/sqlx.NamedStmt).Select", Tag: "db", ArgPos: 0}, - {Name: "(*github.com/jmoiron/sqlx.NamedStmt).SelectContext", Tag: "db", ArgPos: 1}, - {Name: "(*github.com/jmoiron/sqlx.Row).StructScan", Tag: "db", ArgPos: 0}, - {Name: "(*github.com/jmoiron/sqlx.Rows).StructScan", Tag: "db", ArgPos: 0}, - {Name: "(*github.com/jmoiron/sqlx.Stmt).Get", Tag: "db", ArgPos: 0}, - {Name: "(*github.com/jmoiron/sqlx.Stmt).GetContext", Tag: "db", ArgPos: 1}, - {Name: "(*github.com/jmoiron/sqlx.Stmt).Select", Tag: "db", ArgPos: 0}, - {Name: "(*github.com/jmoiron/sqlx.Stmt).SelectContext", Tag: "db", ArgPos: 1}, - {Name: "(*github.com/jmoiron/sqlx.Tx).Get", Tag: "db", ArgPos: 0}, - {Name: "(*github.com/jmoiron/sqlx.Tx).GetContext", Tag: "db", ArgPos: 1}, - {Name: "(*github.com/jmoiron/sqlx.Tx).Select", Tag: "db", ArgPos: 0}, - {Name: "(*github.com/jmoiron/sqlx.Tx).SelectContext", Tag: "db", ArgPos: 1}, + {"github.com/jmoiron/sqlx.Get", "db", 1, []string{"database/sql.Scanner"}}, + {"github.com/jmoiron/sqlx.GetContext", "db", 2, []string{"database/sql.Scanner"}}, + {"github.com/jmoiron/sqlx.Select", "db", 1, []string{"database/sql.Scanner"}}, + {"github.com/jmoiron/sqlx.SelectContext", "db", 2, []string{"database/sql.Scanner"}}, + {"github.com/jmoiron/sqlx.StructScan", "db", 1, []string{"database/sql.Scanner"}}, + {"(*github.com/jmoiron/sqlx.Conn).GetContext", "db", 1, []string{"database/sql.Scanner"}}, + {"(*github.com/jmoiron/sqlx.Conn).SelectContext", "db", 1, []string{"database/sql.Scanner"}}, + {"(*github.com/jmoiron/sqlx.DB).Get", "db", 0, []string{"database/sql.Scanner"}}, + {"(*github.com/jmoiron/sqlx.DB).GetContext", "db", 1, []string{"database/sql.Scanner"}}, + {"(*github.com/jmoiron/sqlx.DB).Select", "db", 0, []string{"database/sql.Scanner"}}, + {"(*github.com/jmoiron/sqlx.DB).SelectContext", "db", 1, []string{"database/sql.Scanner"}}, + {"(*github.com/jmoiron/sqlx.NamedStmt).Get", "db", 0, []string{"database/sql.Scanner"}}, + {"(*github.com/jmoiron/sqlx.NamedStmt).GetContext", "db", 1, []string{"database/sql.Scanner"}}, + {"(*github.com/jmoiron/sqlx.NamedStmt).Select", "db", 0, []string{"database/sql.Scanner"}}, + {"(*github.com/jmoiron/sqlx.NamedStmt).SelectContext", "db", 1, []string{"database/sql.Scanner"}}, + {"(*github.com/jmoiron/sqlx.Row).StructScan", "db", 0, []string{"database/sql.Scanner"}}, + {"(*github.com/jmoiron/sqlx.Rows).StructScan", "db", 0, []string{"database/sql.Scanner"}}, + {"(*github.com/jmoiron/sqlx.Stmt).Get", "db", 0, []string{"database/sql.Scanner"}}, + {"(*github.com/jmoiron/sqlx.Stmt).GetContext", "db", 1, []string{"database/sql.Scanner"}}, + {"(*github.com/jmoiron/sqlx.Stmt).Select", "db", 0, []string{"database/sql.Scanner"}}, + {"(*github.com/jmoiron/sqlx.Stmt).SelectContext", "db", 1, []string{"database/sql.Scanner"}}, + {"(*github.com/jmoiron/sqlx.Tx).Get", "db", 0, []string{"database/sql.Scanner"}}, + {"(*github.com/jmoiron/sqlx.Tx).GetContext", "db", 1, []string{"database/sql.Scanner"}}, + {"(*github.com/jmoiron/sqlx.Tx).Select", "db", 0, []string{"database/sql.Scanner"}}, + {"(*github.com/jmoiron/sqlx.Tx).SelectContext", "db", 1, []string{"database/sql.Scanner"}}, } diff --git a/musttag.go b/musttag.go index 70c8420..f8c0352 100644 --- a/musttag.go +++ b/musttag.go @@ -91,17 +91,17 @@ func run(pass *analysis.Pass, mainModule string, funcs map[string]Func) (_ any, call, ok := node.(*ast.CallExpr) if !ok { - return // not a function call. + return } callee := typeutil.StaticCallee(pass.TypesInfo, call) if callee == nil { - return // not a static call. + return } fn, ok := funcs[cutVendor(callee.FullName())] if !ok { - return // unsupported function. + return } if len(call.Args) <= fn.ArgPos { @@ -116,7 +116,7 @@ func run(pass *analysis.Pass, mainModule string, funcs map[string]Func) (_ any, typ := pass.TypesInfo.TypeOf(arg) if typ == nil { - return // no type info found. + return } checker := checker{ @@ -125,9 +125,8 @@ func run(pass *analysis.Pass, mainModule string, funcs map[string]Func) (_ any, ifaceWhitelist: fn.ifaceWhitelist, imports: pass.Pkg.Imports(), } - - if valid := checker.checkType(typ, fn.Tag); valid { - return // nothing to report. + if checker.isValidType(typ, fn.Tag) { + return } pass.Reportf(arg.Pos(), "the given struct should be annotated with the `%s` tag", fn.Tag) @@ -143,43 +142,32 @@ type checker struct { imports []*types.Package } -func (c *checker) checkType(typ types.Type, tag string) bool { +func (c *checker) isValidType(typ types.Type, tag string) bool { if _, ok := c.seenTypes[typ.String()]; ok { - return true // already checked. + return true } c.seenTypes[typ.String()] = struct{}{} styp, ok := c.parseStruct(typ) if !ok { - return true // not a struct. + return true } - return c.checkStruct(styp, tag) + return c.isValidStruct(styp, tag) } -// recursively unwrap a type until we get to an underlying -// raw struct type that should have its fields checked -// -// SomeStruct -> struct{SomeStructField: ... } -// []*SomeStruct -> struct{SomeStructField: ... } -// ... -// -// exits early if it hits a type that implements a whitelisted interface func (c *checker) parseStruct(typ types.Type) (*types.Struct, bool) { if implementsInterface(typ, c.ifaceWhitelist, c.imports) { - return nil, false // the type implements a Marshaler interface; see issue #64. + return nil, false } switch typ := typ.(type) { case *types.Pointer: return c.parseStruct(typ.Elem()) - case *types.Array: return c.parseStruct(typ.Elem()) - case *types.Slice: return c.parseStruct(typ.Elem()) - case *types.Map: return c.parseStruct(typ.Elem()) @@ -205,7 +193,7 @@ func (c *checker) parseStruct(typ types.Type) (*types.Struct, bool) { } } -func (c *checker) checkStruct(styp *types.Struct, tag string) (valid bool) { +func (c *checker) isValidStruct(styp *types.Struct, tag string) bool { for i := 0; i < styp.NumFields(); i++ { field := styp.Field(i) if !field.Exported() { @@ -214,18 +202,18 @@ func (c *checker) checkStruct(styp *types.Struct, tag string) (valid bool) { tagValue, ok := reflect.StructTag(styp.Tag(i)).Lookup(tag) if !ok { - // tag is not required for embedded types; see issue #12. + // tag is not required for embedded types. if !field.Embedded() { return false } } - // Do not recurse into ignored fields. + // the field is explicitly ignored. if tagValue == "-" { continue } - if valid := c.checkType(field.Type(), tag); !valid { + if !c.isValidType(field.Type(), tag) { return false } } @@ -254,25 +242,29 @@ func implementsInterface(typ types.Type, ifaces []string, imports []*types.Packa } for _, ifacePath := range ifaces { - // "encoding/json.Marshaler" -> "encoding/json" + "Marshaler" + // e.g. "encoding/json.Marshaler" -> "encoding/json" + "Marshaler". idx := strings.LastIndex(ifacePath, ".") if idx == -1 { continue } + pkgName, ifaceName := ifacePath[:idx], ifacePath[idx+1:] scope, ok := findScope(pkgName) if !ok { continue } + obj := scope.Lookup(ifaceName) if obj == nil { continue } + iface, ok := obj.Type().Underlying().(*types.Interface) if !ok { continue } + if types.Implements(typ, iface) || types.Implements(types.NewPointer(typ), iface) { return true } diff --git a/testdata/src/tests/builtins.go b/testdata/src/tests/builtins.go index d576da4..482839b 100644 --- a/testdata/src/tests/builtins.go +++ b/testdata/src/tests/builtins.go @@ -28,6 +28,10 @@ type TextMarshaler struct{ NoTag string } func (TextMarshaler) MarshalText() ([]byte, error) { return nil, nil } func (*TextMarshaler) UnmarshalText([]byte) error { return nil } +type Scanner struct{ NotTag string } + +func (*Scanner) Scan(any) error { return nil } + func testJSON() { var st Struct json.Marshal(st) // want "the given struct should be annotated with the `json` tag" @@ -154,6 +158,33 @@ func testSQLX() { new(sqlx.Tx).GetContext(nil, &st, "") // want "the given struct should be annotated with the `db` tag" new(sqlx.Tx).Select(&st, "") // want "the given struct should be annotated with the `db` tag" new(sqlx.Tx).SelectContext(nil, &st, "") // want "the given struct should be annotated with the `db` tag" + + var sc Scanner + sqlx.Get(nil, &sc, "") + sqlx.GetContext(nil, nil, &sc, "") + sqlx.Select(nil, &sc, "") + sqlx.SelectContext(nil, nil, &sc, "") + sqlx.StructScan(nil, &sc) + new(sqlx.Conn).GetContext(nil, &sc, "") + new(sqlx.Conn).SelectContext(nil, &sc, "") + new(sqlx.DB).Get(&sc, "") + new(sqlx.DB).GetContext(nil, &sc, "") + new(sqlx.DB).Select(&sc, "") + new(sqlx.DB).SelectContext(nil, &sc, "") + new(sqlx.NamedStmt).Get(&sc, nil) + new(sqlx.NamedStmt).GetContext(nil, &sc, nil) + new(sqlx.NamedStmt).Select(&sc, nil) + new(sqlx.NamedStmt).SelectContext(nil, &sc, nil) + new(sqlx.Row).StructScan(&sc) + new(sqlx.Rows).StructScan(&sc) + new(sqlx.Stmt).Get(&sc) + new(sqlx.Stmt).GetContext(nil, &sc) + new(sqlx.Stmt).Select(&sc) + new(sqlx.Stmt).SelectContext(nil, &sc) + new(sqlx.Tx).Get(&sc, "") + new(sqlx.Tx).GetContext(nil, &sc, "") + new(sqlx.Tx).Select(&sc, "") + new(sqlx.Tx).SelectContext(nil, &sc, "") } func testCustom() {