From dde3e87b4d8ba368c0df39eb96c0dfa4fe9a3d3e Mon Sep 17 00:00:00 2001 From: Thomas <73077675+tmzane@users.noreply.github.com> Date: Fri, 27 Oct 2023 16:50:08 +0300 Subject: [PATCH] refactor: use `cutVendor()` instead of regexp and move vendor (#71) --- musttag.go | 10 ++--- musttag_test.go | 2 +- testdata/src/{tests => }/.gitignore | 0 testdata/src/{tests => }/go.mod | 2 +- testdata/src/{tests => }/go.sum | 0 testdata/src/go.work | 2 +- utils.go | 21 ++++++++- utils_test.go | 69 ++++++++++++++++------------- 8 files changed, 63 insertions(+), 43 deletions(-) rename testdata/src/{tests => }/.gitignore (100%) rename testdata/src/{tests => }/go.mod (83%) rename testdata/src/{tests => }/go.sum (100%) diff --git a/musttag.go b/musttag.go index 43c8381..1c2b3c0 100644 --- a/musttag.go +++ b/musttag.go @@ -7,7 +7,6 @@ import ( "go/ast" "go/types" "reflect" - "regexp" "strconv" "strings" @@ -81,8 +80,6 @@ func flags(funcs *[]Func) flag.FlagSet { return *fs } -var trimVendor = regexp.MustCompile("([^*/(]+/vendor/)") - func run(pass *analysis.Pass, mainModule string, funcs map[string]Func) (_ any, err error) { visit := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) filter := []ast.Node{(*ast.CallExpr)(nil)} @@ -102,8 +99,7 @@ func run(pass *analysis.Pass, mainModule string, funcs map[string]Func) (_ any, return // not a static call. } - name := trimVendor.ReplaceAllString(callee.FullName(), "") - fn, ok := funcs[name] + fn, ok := funcs[cutVendor(callee.FullName())] if !ok { return // unsupported function. } @@ -221,7 +217,7 @@ func implementsInterface(typ types.Type, ifaces []string, imports []*types.Packa findScope := func(pkgName string) (*types.Scope, bool) { // fast path: check direct imports (e.g. looking for "encoding/json.Marshaler"). for _, direct := range imports { - if pkgName == trimVendor.ReplaceAllString(direct.Path(), "") { + if pkgName == cutVendor(direct.Path()) { return direct.Scope(), true } } @@ -229,7 +225,7 @@ func implementsInterface(typ types.Type, ifaces []string, imports []*types.Packa // TODO: only check indirect imports from the package (e.g. "encoding/json") of the analyzed function (e.g. "encoding/json.Marshal"). for _, direct := range imports { for _, indirect := range direct.Imports() { - if pkgName == trimVendor.ReplaceAllString(indirect.Path(), "") { + if pkgName == cutVendor(indirect.Path()) { return indirect.Scope(), true } } diff --git a/musttag_test.go b/musttag_test.go index f41d1a4..76f1a1c 100644 --- a/musttag_test.go +++ b/musttag_test.go @@ -63,7 +63,7 @@ func (nopT) Errorf(string, ...any) {} func setupModules(t *testing.T, testdata string) { t.Helper() - err := os.Chdir(filepath.Join(testdata, "src", "tests")) + err := os.Chdir(filepath.Join(testdata, "src")) assert.NoErr[F](t, err) err = exec.Command("go", "mod", "vendor").Run() diff --git a/testdata/src/tests/.gitignore b/testdata/src/.gitignore similarity index 100% rename from testdata/src/tests/.gitignore rename to testdata/src/.gitignore diff --git a/testdata/src/tests/go.mod b/testdata/src/go.mod similarity index 83% rename from testdata/src/tests/go.mod rename to testdata/src/go.mod index 351c0ad..62766c8 100644 --- a/testdata/src/tests/go.mod +++ b/testdata/src/go.mod @@ -10,4 +10,4 @@ require ( gopkg.in/yaml.v3 v3.0.1 ) -replace example.com/custom => ../custom +replace example.com/custom => ./custom diff --git a/testdata/src/tests/go.sum b/testdata/src/go.sum similarity index 100% rename from testdata/src/tests/go.sum rename to testdata/src/go.sum diff --git a/testdata/src/go.work b/testdata/src/go.work index 1b4abe4..04b5a79 100644 --- a/testdata/src/go.work +++ b/testdata/src/go.work @@ -1,6 +1,6 @@ go 1.21 use ( + . ./custom - ./tests ) diff --git a/utils.go b/utils.go index 673747f..62d23ce 100644 --- a/utils.go +++ b/utils.go @@ -11,8 +11,7 @@ import ( ) var ( - getwd = os.Getwd - + getwd = os.Getwd commandOutput = func(name string, args ...string) (string, error) { output, err := exec.Command(name, args...).Output() return string(output), err @@ -55,3 +54,21 @@ func getMainModule() (string, error) { } } } + +// based on golang.org/x/tools/imports.VendorlessPath +func cutVendor(path string) string { + var prefix string + switch { + case strings.HasPrefix(path, "(*"): + prefix, path = "(*", path[len("(*"):] + case strings.HasPrefix(path, "("): + prefix, path = "(", path[len("("):] + } + if i := strings.LastIndex(path, "/vendor/"); i >= 0 { + return prefix + path[i+len("/vendor/"):] + } + if strings.HasPrefix(path, "vendor/") { + return prefix + path[len("vendor/"):] + } + return prefix + path +} diff --git a/utils_test.go b/utils_test.go index 5fa8137..1e59151 100644 --- a/utils_test.go +++ b/utils_test.go @@ -8,47 +8,54 @@ import ( ) func Test_getMainModule(t *testing.T) { - test := func(name, want, output string) { - t.Helper() - t.Run(name, func(t *testing.T) { - t.Helper() + tests := map[string]struct { + want, output string + }{ + "single module": { + want: "module1", + output: ` +{"Path": "module1", "Main": true, "Dir": "/path/to/module1"}`, + }, + "multiple modules": { + want: "module1", + output: ` +{"Path": "module1", "Main": true, "Dir": "/path/to/module1"} +{"Path": "module2", "Main": true, "Dir": "/path/to/module2"}`, + }, + } - gwd := getwd - co := commandOutput - defer func() { - getwd = gwd - commandOutput = co - }() + for name, test := range tests { + t.Run(name, func(t *testing.T) { + gwd, co := getwd, commandOutput + defer func() { getwd, commandOutput = gwd, co }() getwd = func() (string, error) { return "/path/to/module1/pkg", nil } - commandOutput = func(name string, args ...string) (string, error) { - return output, nil + commandOutput = func(string, ...string) (string, error) { + return test.output, nil } got, err := getMainModule() assert.NoErr[F](t, err) - assert.Equal[E](t, got, want) + assert.Equal[E](t, got, test.want) }) } - - test("single module", "module1", ` -{ - "Path": "module1", - "Main": true, - "Dir": "/path/to/module1" -}`) - - test("multiple modules", "module1", ` -{ - "Path": "module1", - "Main": true, - "Dir": "/path/to/module1" } -{ - "Path": "module2", - "Main": true, - "Dir": "/path/to/module2" -}`) + +func Test_cutVendor(t *testing.T) { + tests := []struct { + path, want string + }{ + {"foo/bar.A", "foo/bar.A"}, + {"vendor/foo/bar.A", "foo/bar.A"}, + {"test/vendor/foo/bar.A", "foo/bar.A"}, + {"(test/vendor/foo/bar.A).B", "(foo/bar.A).B"}, + {"(*test/vendor/foo/bar.A).B", "(*foo/bar.A).B"}, + } + + for _, test := range tests { + got := cutVendor(test.path) + assert.Equal[E](t, got, test.want) + } }