From d3527355ae1a11cbfed46f3bfb228de40067ae79 Mon Sep 17 00:00:00 2001 From: Tom Date: Sat, 25 Feb 2023 21:11:12 +0400 Subject: [PATCH] feat: return an error on bad Func.ArgPos (#30) --- musttag.go | 14 ++++++++++---- musttag_test.go | 44 +++++++++++++++++++++++++++++++------------- 2 files changed, 41 insertions(+), 17 deletions(-) diff --git a/musttag.go b/musttag.go index 1254288..0a4173f 100644 --- a/musttag.go +++ b/musttag.go @@ -3,6 +3,7 @@ package musttag import ( "flag" + "fmt" "go/ast" "go/token" "go/types" @@ -81,7 +82,7 @@ func flags(funcs *[]Func) flag.FlagSet { // for tests only. var ( - reportf = func(pass *analysis.Pass, st *structType, fn Func, fnPos token.Position) { + report = func(pass *analysis.Pass, st *structType, fn Func, fnPos token.Position) { const format = "`%s` should be annotated with the `%s` tag as it is passed to `%s` at %s" pass.Reportf(st.Pos, format, st.Name, fn.Tag, fn.shortName(), fnPos) } @@ -106,6 +107,10 @@ func run(pass *analysis.Pass, funcs map[string]Func) (any, error) { filter := []ast.Node{(*ast.CallExpr)(nil)} walk.Preorder(filter, func(n ast.Node) { + if err != nil { + return // there is already an error. + } + call, ok := n.(*ast.CallExpr) if !ok { return // not a function call. @@ -122,7 +127,8 @@ func run(pass *analysis.Pass, funcs map[string]Func) (any, error) { } if len(call.Args) <= fn.ArgPos { - return // TODO(junk1tm): return a proper error. + err = fmt.Errorf("Func.ArgPos cannot be %d: %s accepts only %d argument(s)", fn.ArgPos, fn.Name, len(call.Args)) + return } arg := call.Args[fn.ArgPos] @@ -159,10 +165,10 @@ func run(pass *analysis.Pass, funcs map[string]Func) (any, error) { p := pass.Fset.Position(call.Pos()) p.Filename, _ = filepath.Rel(moduleDir, p.Filename) - reportf(pass, result, fn, p) + report(pass, result, fn, p) }) - return nil, nil + return nil, err } // structType is an extension for types.Struct. diff --git a/musttag_test.go b/musttag_test.go index 0b81bbf..c7ad37f 100644 --- a/musttag_test.go +++ b/musttag_test.go @@ -19,22 +19,36 @@ func TestAnalyzer(t *testing.T) { prepareTestFiles(t) testPackages = []string{"tests", "builtins"} - analyzer := New( - Func{Name: "example.com/custom.Marshal", Tag: "custom", ArgPos: 0}, - Func{Name: "example.com/custom.Unmarshal", Tag: "custom", ArgPos: 1}, - ) + testdata := analysistest.TestData() + + t.Run("tests", func(t *testing.T) { + r := report + defer func() { report = r }() + report = func(pass *analysis.Pass, st *structType, fn Func, fnPos token.Position) { + pass.Reportf(st.Pos, fn.shortName()) + } + analyzer := New() + analysistest.Run(t, testdata, analyzer, "tests") + }) t.Run("builtins", func(t *testing.T) { - testdata := analysistest.TestData() + analyzer := New( + Func{Name: "example.com/custom.Marshal", Tag: "custom", ArgPos: 0}, + Func{Name: "example.com/custom.Unmarshal", Tag: "custom", ArgPos: 1}, + ) analysistest.Run(t, testdata, analyzer, "builtins") }) - t.Run("tests", func(t *testing.T) { - reportf = func(pass *analysis.Pass, st *structType, fn Func, fnPos token.Position) { - pass.Reportf(st.Pos, fn.shortName()) + t.Run("bad Func.ArgPos", func(t *testing.T) { + const want = `Func.ArgPos cannot be 10: encoding/json.Marshal accepts only 1 argument(s)` + analyzer := New( + // override the builtin function. + Func{Name: "encoding/json.Marshal", Tag: "json", ArgPos: 10}, + ) + result := analysistest.Run(nopT{}, testdata, analyzer, "tests")[0] + if got := result.Err.Error(); got != want { + t.Errorf("\ngot\t%s\nwant\t%s", got, want) } - testdata := analysistest.TestData() - analysistest.Run(t, testdata, analyzer, "tests") }) } @@ -46,7 +60,7 @@ func TestFlags(t *testing.T) { t.Run("ok", func(t *testing.T) { err := analyzer.Flags.Parse([]string{"-fn=test.Test:test:0"}) if err != nil { - t.Errorf("got %v; want no error", err) + t.Errorf("\ngot\t%s\nwant\tno error", err) } }) @@ -54,7 +68,7 @@ func TestFlags(t *testing.T) { const want = `invalid value "test.Test" for flag -fn: invalid syntax` err := analyzer.Flags.Parse([]string{"-fn=test.Test"}) if got := err.Error(); got != want { - t.Errorf("got %q; want %q", got, want) + t.Errorf("\ngot\t%s\nwant\t%s", got, want) } }) @@ -62,11 +76,15 @@ func TestFlags(t *testing.T) { const want = `invalid value "test.Test:test:-" for flag -fn: strconv.Atoi: parsing "-": invalid syntax` err := analyzer.Flags.Parse([]string{"-fn=test.Test:test:-"}) if got := err.Error(); got != want { - t.Errorf("got %q; want %q", got, want) + t.Errorf("\ngot\t%s\nwant\t%s", got, want) } }) } +type nopT struct{} + +func (nopT) Errorf(string, ...any) {} + func prepareTestFiles(t *testing.T) { testdata := analysistest.TestData()