diff --git a/testdata/src/go.work b/testdata/src/go.work new file mode 100644 index 0000000..5595c7c --- /dev/null +++ b/testdata/src/go.work @@ -0,0 +1,7 @@ +go 1.18 + +use ( + ./builtins + ./custom + ./tests +) diff --git a/utils.go b/utils.go index 09b51a1..673747f 100644 --- a/utils.go +++ b/utils.go @@ -2,29 +2,56 @@ package musttag import ( "encoding/json" + "errors" "fmt" + "io" + "os" "os/exec" "strings" ) +var ( + getwd = os.Getwd + + commandOutput = func(name string, args ...string) (string, error) { + output, err := exec.Command(name, args...).Output() + return string(output), err + } +) + func getMainModule() (string, error) { args := []string{"go", "list", "-m", "-json"} - data, err := exec.Command(args[0], args[1:]...).Output() + output, err := commandOutput(args[0], args[1:]...) if err != nil { return "", fmt.Errorf("running `%s`: %w", strings.Join(args, " "), err) } - var module struct { - Path string `json:"Path"` - Main bool `json:"Main"` - Dir string `json:"Dir"` - GoMod string `json:"GoMod"` - GoVersion string `json:"GoVersion"` - } - if err := json.Unmarshal(data, &module); err != nil { - return "", fmt.Errorf("decoding json: %w: %s", err, string(data)) + cwd, err := getwd() + if err != nil { + return "", fmt.Errorf("getting wd: %w", err) } - return module.Path, nil + decoder := json.NewDecoder(strings.NewReader(output)) + + for { + // multiple JSON objects will be returned when using Go workspaces; see #63 for details. + var module struct { + Path string `json:"Path"` + Main bool `json:"Main"` + Dir string `json:"Dir"` + GoMod string `json:"GoMod"` + GoVersion string `json:"GoVersion"` + } + if err := decoder.Decode(&module); err != nil { + if errors.Is(err, io.EOF) { + return "", fmt.Errorf("main module not found\n%s", output) + } + return "", fmt.Errorf("decoding json: %w\n%s", err, output) + } + + if module.Main && strings.HasPrefix(cwd, module.Dir) { + return module.Path, nil + } + } } diff --git a/utils_test.go b/utils_test.go new file mode 100644 index 0000000..aa80671 --- /dev/null +++ b/utils_test.go @@ -0,0 +1,54 @@ +package musttag + +import ( + "testing" + + "go-simpler.org/assert" + . "go-simpler.org/assert/dotimport" +) + +func Test_getMainModule(t *testing.T) { + test := func(name, want, output string) { + t.Helper() + t.Run(name, func(t *testing.T) { + t.Helper() + + gwd := getwd + co := commandOutput + defer func() { + getwd = gwd + commandOutput = co + }() + + getwd = func() (string, error) { + return "/path/to/module1/pkg", nil + } + commandOutput = func(name string, args ...string) (string, error) { + return output, nil + } + + got, err := getMainModule() + assert.NoErr[F](t, err) + assert.Equal[E](t, got, want) + }) + } + + test("single module", "module1", ` +{ + "Path": "module1", + "Main": true, + "Dir": "/path/to/module1" +}`) + + test("multiple modules", "module1", ` +{ + "Path": "module1", + "Main": true, + "Dir": "/path/to/module1" +} +{ + "Path": "module2", + "Main": true, + "Dir": "/path/to/module2" +}`) +}