Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: use cutVendor() instead of regexp and move vendor #71

Merged
merged 1 commit into from
Oct 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions musttag.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"go/ast"
"go/types"
"reflect"
"regexp"
"strconv"
"strings"

Expand Down Expand Up @@ -81,8 +80,6 @@ func flags(funcs *[]Func) flag.FlagSet {
return *fs
}

var trimVendor = regexp.MustCompile("([^*/(]+/vendor/)")

func run(pass *analysis.Pass, mainModule string, funcs map[string]Func) (_ any, err error) {
visit := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
filter := []ast.Node{(*ast.CallExpr)(nil)}
Expand All @@ -102,8 +99,7 @@ func run(pass *analysis.Pass, mainModule string, funcs map[string]Func) (_ any,
return // not a static call.
}

name := trimVendor.ReplaceAllString(callee.FullName(), "")
fn, ok := funcs[name]
fn, ok := funcs[cutVendor(callee.FullName())]
if !ok {
return // unsupported function.
}
Expand Down Expand Up @@ -221,15 +217,15 @@ func implementsInterface(typ types.Type, ifaces []string, imports []*types.Packa
findScope := func(pkgName string) (*types.Scope, bool) {
// fast path: check direct imports (e.g. looking for "encoding/json.Marshaler").
for _, direct := range imports {
if pkgName == trimVendor.ReplaceAllString(direct.Path(), "") {
if pkgName == cutVendor(direct.Path()) {
return direct.Scope(), true
}
}
// slow path: check indirect imports (e.g. looking for "encoding.TextMarshaler").
// TODO: only check indirect imports from the package (e.g. "encoding/json") of the analyzed function (e.g. "encoding/json.Marshal").
for _, direct := range imports {
for _, indirect := range direct.Imports() {
if pkgName == trimVendor.ReplaceAllString(indirect.Path(), "") {
if pkgName == cutVendor(indirect.Path()) {
return indirect.Scope(), true
}
}
Expand Down
2 changes: 1 addition & 1 deletion musttag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func (nopT) Errorf(string, ...any) {}
func setupModules(t *testing.T, testdata string) {
t.Helper()

err := os.Chdir(filepath.Join(testdata, "src", "tests"))
err := os.Chdir(filepath.Join(testdata, "src"))
assert.NoErr[F](t, err)

err = exec.Command("go", "mod", "vendor").Run()
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion testdata/src/tests/go.mod → testdata/src/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ require (
gopkg.in/yaml.v3 v3.0.1
)

replace example.com/custom => ../custom
replace example.com/custom => ./custom
File renamed without changes.
2 changes: 1 addition & 1 deletion testdata/src/go.work
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
go 1.21

use (
.
./custom
./tests
)
21 changes: 19 additions & 2 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ import (
)

var (
getwd = os.Getwd

getwd = os.Getwd
commandOutput = func(name string, args ...string) (string, error) {
output, err := exec.Command(name, args...).Output()
return string(output), err
Expand Down Expand Up @@ -55,3 +54,21 @@ func getMainModule() (string, error) {
}
}
}

// based on golang.org/x/tools/imports.VendorlessPath
func cutVendor(path string) string {
var prefix string
switch {
case strings.HasPrefix(path, "(*"):
prefix, path = "(*", path[len("(*"):]
case strings.HasPrefix(path, "("):
prefix, path = "(", path[len("("):]
}
if i := strings.LastIndex(path, "/vendor/"); i >= 0 {
return prefix + path[i+len("/vendor/"):]
}
if strings.HasPrefix(path, "vendor/") {
return prefix + path[len("vendor/"):]
}
return prefix + path
}
69 changes: 38 additions & 31 deletions utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,47 +8,54 @@ import (
)

func Test_getMainModule(t *testing.T) {
test := func(name, want, output string) {
t.Helper()
t.Run(name, func(t *testing.T) {
t.Helper()
tests := map[string]struct {
want, output string
}{
"single module": {
want: "module1",
output: `
{"Path": "module1", "Main": true, "Dir": "/path/to/module1"}`,
},
"multiple modules": {
want: "module1",
output: `
{"Path": "module1", "Main": true, "Dir": "/path/to/module1"}
{"Path": "module2", "Main": true, "Dir": "/path/to/module2"}`,
},
}

gwd := getwd
co := commandOutput
defer func() {
getwd = gwd
commandOutput = co
}()
for name, test := range tests {
t.Run(name, func(t *testing.T) {
gwd, co := getwd, commandOutput
defer func() { getwd, commandOutput = gwd, co }()

getwd = func() (string, error) {
return "/path/to/module1/pkg", nil
}
commandOutput = func(name string, args ...string) (string, error) {
return output, nil
commandOutput = func(string, ...string) (string, error) {
return test.output, nil
}

got, err := getMainModule()
assert.NoErr[F](t, err)
assert.Equal[E](t, got, want)
assert.Equal[E](t, got, test.want)
})
}

test("single module", "module1", `
{
"Path": "module1",
"Main": true,
"Dir": "/path/to/module1"
}`)

test("multiple modules", "module1", `
{
"Path": "module1",
"Main": true,
"Dir": "/path/to/module1"
}
{
"Path": "module2",
"Main": true,
"Dir": "/path/to/module2"
}`)

func Test_cutVendor(t *testing.T) {
tests := []struct {
path, want string
}{
{"foo/bar.A", "foo/bar.A"},
{"vendor/foo/bar.A", "foo/bar.A"},
{"test/vendor/foo/bar.A", "foo/bar.A"},
{"(test/vendor/foo/bar.A).B", "(foo/bar.A).B"},
{"(*test/vendor/foo/bar.A).B", "(*foo/bar.A).B"},
}

for _, test := range tests {
got := cutVendor(test.path)
assert.Equal[E](t, got, test.want)
}
}
Loading