Skip to content

Commit

Permalink
refactor: use cutVendor() instead of regexp and move vendor (#71)
Browse files Browse the repository at this point in the history
  • Loading branch information
tmzane authored Oct 27, 2023
1 parent 9307251 commit dde3e87
Show file tree
Hide file tree
Showing 8 changed files with 63 additions and 43 deletions.
10 changes: 3 additions & 7 deletions musttag.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"go/ast"
"go/types"
"reflect"
"regexp"
"strconv"
"strings"

Expand Down Expand Up @@ -81,8 +80,6 @@ func flags(funcs *[]Func) flag.FlagSet {
return *fs
}

var trimVendor = regexp.MustCompile("([^*/(]+/vendor/)")

func run(pass *analysis.Pass, mainModule string, funcs map[string]Func) (_ any, err error) {
visit := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
filter := []ast.Node{(*ast.CallExpr)(nil)}
Expand All @@ -102,8 +99,7 @@ func run(pass *analysis.Pass, mainModule string, funcs map[string]Func) (_ any,
return // not a static call.
}

name := trimVendor.ReplaceAllString(callee.FullName(), "")
fn, ok := funcs[name]
fn, ok := funcs[cutVendor(callee.FullName())]
if !ok {
return // unsupported function.
}
Expand Down Expand Up @@ -221,15 +217,15 @@ func implementsInterface(typ types.Type, ifaces []string, imports []*types.Packa
findScope := func(pkgName string) (*types.Scope, bool) {
// fast path: check direct imports (e.g. looking for "encoding/json.Marshaler").
for _, direct := range imports {
if pkgName == trimVendor.ReplaceAllString(direct.Path(), "") {
if pkgName == cutVendor(direct.Path()) {
return direct.Scope(), true
}
}
// slow path: check indirect imports (e.g. looking for "encoding.TextMarshaler").
// TODO: only check indirect imports from the package (e.g. "encoding/json") of the analyzed function (e.g. "encoding/json.Marshal").
for _, direct := range imports {
for _, indirect := range direct.Imports() {
if pkgName == trimVendor.ReplaceAllString(indirect.Path(), "") {
if pkgName == cutVendor(indirect.Path()) {
return indirect.Scope(), true
}
}
Expand Down
2 changes: 1 addition & 1 deletion musttag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func (nopT) Errorf(string, ...any) {}
func setupModules(t *testing.T, testdata string) {
t.Helper()

err := os.Chdir(filepath.Join(testdata, "src", "tests"))
err := os.Chdir(filepath.Join(testdata, "src"))
assert.NoErr[F](t, err)

err = exec.Command("go", "mod", "vendor").Run()
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion testdata/src/tests/go.mod → testdata/src/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ require (
gopkg.in/yaml.v3 v3.0.1
)

replace example.com/custom => ../custom
replace example.com/custom => ./custom
File renamed without changes.
2 changes: 1 addition & 1 deletion testdata/src/go.work
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
go 1.21

use (
.
./custom
./tests
)
21 changes: 19 additions & 2 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ import (
)

var (
getwd = os.Getwd

getwd = os.Getwd
commandOutput = func(name string, args ...string) (string, error) {
output, err := exec.Command(name, args...).Output()
return string(output), err
Expand Down Expand Up @@ -55,3 +54,21 @@ func getMainModule() (string, error) {
}
}
}

// based on golang.org/x/tools/imports.VendorlessPath
func cutVendor(path string) string {
var prefix string
switch {
case strings.HasPrefix(path, "(*"):
prefix, path = "(*", path[len("(*"):]
case strings.HasPrefix(path, "("):
prefix, path = "(", path[len("("):]
}
if i := strings.LastIndex(path, "/vendor/"); i >= 0 {
return prefix + path[i+len("/vendor/"):]
}
if strings.HasPrefix(path, "vendor/") {
return prefix + path[len("vendor/"):]
}
return prefix + path
}
69 changes: 38 additions & 31 deletions utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,47 +8,54 @@ import (
)

func Test_getMainModule(t *testing.T) {
test := func(name, want, output string) {
t.Helper()
t.Run(name, func(t *testing.T) {
t.Helper()
tests := map[string]struct {
want, output string
}{
"single module": {
want: "module1",
output: `
{"Path": "module1", "Main": true, "Dir": "/path/to/module1"}`,
},
"multiple modules": {
want: "module1",
output: `
{"Path": "module1", "Main": true, "Dir": "/path/to/module1"}
{"Path": "module2", "Main": true, "Dir": "/path/to/module2"}`,
},
}

gwd := getwd
co := commandOutput
defer func() {
getwd = gwd
commandOutput = co
}()
for name, test := range tests {
t.Run(name, func(t *testing.T) {
gwd, co := getwd, commandOutput
defer func() { getwd, commandOutput = gwd, co }()

getwd = func() (string, error) {
return "/path/to/module1/pkg", nil
}
commandOutput = func(name string, args ...string) (string, error) {
return output, nil
commandOutput = func(string, ...string) (string, error) {
return test.output, nil
}

got, err := getMainModule()
assert.NoErr[F](t, err)
assert.Equal[E](t, got, want)
assert.Equal[E](t, got, test.want)
})
}

test("single module", "module1", `
{
"Path": "module1",
"Main": true,
"Dir": "/path/to/module1"
}`)

test("multiple modules", "module1", `
{
"Path": "module1",
"Main": true,
"Dir": "/path/to/module1"
}
{
"Path": "module2",
"Main": true,
"Dir": "/path/to/module2"
}`)

func Test_cutVendor(t *testing.T) {
tests := []struct {
path, want string
}{
{"foo/bar.A", "foo/bar.A"},
{"vendor/foo/bar.A", "foo/bar.A"},
{"test/vendor/foo/bar.A", "foo/bar.A"},
{"(test/vendor/foo/bar.A).B", "(foo/bar.A).B"},
{"(*test/vendor/foo/bar.A).B", "(*foo/bar.A).B"},
}

for _, test := range tests {
got := cutVendor(test.path)
assert.Equal[E](t, got, test.want)
}
}

0 comments on commit dde3e87

Please sign in to comment.