diff --git a/cmd/slackdump/internal/apiconfig/apiconfig.go b/cmd/slackdump/internal/apiconfig/apiconfig.go index ed548616..2a8caf1c 100644 --- a/cmd/slackdump/internal/apiconfig/apiconfig.go +++ b/cmd/slackdump/internal/apiconfig/apiconfig.go @@ -42,7 +42,8 @@ func Load(filename string) (slackdump.Limits, error) { return readLimits(f) } -func Save(filename string, limits *slackdump.Limits) error { +// Save saves the config to the file. +func Save(filename string, limits slackdump.Limits) error { f, err := os.Create(filename) if err != nil { return err @@ -68,8 +69,8 @@ func readLimits(r io.Reader) (slackdump.Limits, error) { return limits, nil } -func writeLimits(w io.Writer, cfg *slackdump.Limits) error { - return yaml.NewEncoder(w).Encode(&slackdump.DefOptions.Limits) +func writeLimits(w io.Writer, cfg slackdump.Limits) error { + return yaml.NewEncoder(w).Encode(cfg) } // printErrors prints configuration errors, if error is not nill and is of diff --git a/cmd/slackdump/internal/apiconfig/new.go b/cmd/slackdump/internal/apiconfig/new.go index c15fd0cc..a38d33bb 100644 --- a/cmd/slackdump/internal/apiconfig/new.go +++ b/cmd/slackdump/internal/apiconfig/new.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "path/filepath" + "runtime/trace" "github.com/rusq/slackdump/v2" "github.com/rusq/slackdump/v2/cmd/slackdump/internal/cfg" @@ -38,41 +39,60 @@ func init() { } func runConfigNew(ctx context.Context, cmd *base.Command, args []string) error { + _, task := trace.NewTask(ctx, "runConfigNew") + defer task.End() + if len(args) == 0 { base.SetExitStatus(base.SInvalidParameters) return errors.New("config file name must be specified") } - filename := args[0] - if ext := filepath.Ext(filename); !(ext == ".yaml" || ext == ".yml") { - filename = maybeAddExt(filename, ".yaml") - } + filename := maybeFixExt(args[0]) - if _, err := os.Stat(filename); !*fNewOverride && err == nil { + if !shouldOverwrite(filename, *fNewOverride) { base.SetExitStatus(base.SUserError) - return fmt.Errorf("refusing to overwrite file %q, use -y flag to overwrite", filename) + return fmt.Errorf("file or directory exists: %q, use -y flag to overwrite (will not overwrite directory)", filename) } - if err := Save(filename, &slackdump.DefOptions.Limits); err != nil { + if err := Save(filename, slackdump.DefOptions.Limits); err != nil { base.SetExitStatus(base.SApplicationError) - return fmt.Errorf("error writing the API config %q: %w", filename, err) + return fmt.Errorf("error writing the API limits config %q: %w", filename, err) } - fmt.Printf("Your new API config is ready: %q\n", filename) + fmt.Printf("Your new API limits config is ready: %q\n", filename) return nil } -// maybeAddExt adds a filename extension ext if the filename has missing, or +// shouldOverwrite returns true if the file can be overwritten. If override +// is true and the file exists and not a directory, it will return true. +func shouldOverwrite(filename string, override bool) bool { + fi, err := os.Stat(filename) + if fi != nil && fi.IsDir() { + return false + } + return err != nil || override +} + +// maybeFixExt checks if the extension is one of .yaml or .yml, and if not +// appends it to teh file. +func maybeFixExt(filename string) string { + if ext := filepath.Ext(filename); !(ext == ".yaml" || ext == ".yml") { + return maybeAppendExt(filename, ".yaml") + } + return filename +} + +// maybeAppendExt adds a filename extension ext if the filename has missing, or // a different extension. -func maybeAddExt(filename string, ext string) string { +func maybeAppendExt(filename string, ext string) string { if len(ext) == 0 { return filename } - if filepath.Ext(filename) == ext { - return filename - } if ext[0] != '.' { ext = "." + ext } + if filepath.Ext(filename) == ext { + return filename + } return filename + ext } diff --git a/cmd/slackdump/internal/apiconfig/new_test.go b/cmd/slackdump/internal/apiconfig/new_test.go new file mode 100644 index 00000000..fe4ef584 --- /dev/null +++ b/cmd/slackdump/internal/apiconfig/new_test.go @@ -0,0 +1,203 @@ +package apiconfig + +import ( + "context" + "os" + "path/filepath" + "testing" +) + +func Test_maybeAppendExt(t *testing.T) { + type args struct { + filename string + ext string + } + tests := []struct { + name string + args args + want string + }{ + { + "appended", + args{"filename", ".ext"}, + "filename.ext", + }, + { + "empty ext", + args{"no_ext_here", ""}, + "no_ext_here", + }, + { + "dot is prepended to ext", + args{"foo", "bar"}, + "foo.bar", + }, + { + "same ext", + args{"foo.bar", ".bar"}, + "foo.bar", + }, + { + "already has an extension", + args{"filename.xxx", ".ext"}, + "filename.xxx.ext", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := maybeAppendExt(tt.args.filename, tt.args.ext); got != tt.want { + t.Errorf("maybeAppendExt() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_maybeFixExt(t *testing.T) { + type args struct { + filename string + } + tests := []struct { + name string + args args + want string + }{ + { + "already yaml", + args{filename: "lol.yaml"}, + "lol.yaml", + }, + { + "already yml", + args{filename: "lol.yml"}, + "lol.yml", + }, + { + "no extension", + args{filename: "foo"}, + "foo.yaml", + }, + { + "different extension", + args{filename: "foo.bar"}, + "foo.bar.yaml", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := maybeFixExt(tt.args.filename); got != tt.want { + t.Errorf("maybeFixExt() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_shouldOverwrite(t *testing.T) { + dir := t.TempDir() + existingFile, err := os.CreateTemp(dir, "unittest*") + if err != nil { + t.Fatal(err) + } + defer existingFile.Close() + + existingDir := filepath.Join(dir, "existing_dir") + if err := os.Mkdir(existingDir, 0755); err != nil { + t.Fatal(err) + } + type args struct { + filename string + override bool + } + tests := []struct { + name string + args args + want bool + }{ + { + "non-existing file", + args{"$$$$", false}, + true, + }, + { + "non-existing file override", + args{"$$$$", true}, + true, + }, + { + "existing file", + args{existingFile.Name(), false}, + false, + }, + { + "existing file override", + args{existingFile.Name(), true}, + true, + }, + { + "existing directory", + args{existingDir, false}, + false, + }, + { + "existing directory override", + args{existingDir, true}, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := shouldOverwrite(tt.args.filename, tt.args.override); got != tt.want { + t.Errorf("shouldOverwrite() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_runConfigNew(t *testing.T) { + dir := t.TempDir() + existingDir := filepath.Join(dir, "test.yaml") + if err := os.MkdirAll(existingDir, 0777); err != nil { + t.Fatal(err) + } + type args struct { + args []string + } + tests := []struct { + name string + args args + wantErr bool + shouldExist bool + }{ + { + "no arguments given", + args{}, + true, + false, + }, + { + "file is created", + args{[]string{filepath.Join(dir, "sample.yml")}}, + false, + true, + }, + { + "directory test.yaml", + args{[]string{existingDir}}, + true, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := runConfigNew(context.Background(), CmdConfigNew, tt.args.args); (err != nil) != tt.wantErr { + t.Errorf("runConfigNew() error = %v, wantErr %v", err, tt.wantErr) + } + if len(tt.args.args) == 0 { + return + } + _, err := os.Stat(tt.args.args[0]) + if (err == nil) != tt.shouldExist { + t.Errorf("file exist error: %s, shouldExist = %v", err, tt.shouldExist) + } + }) + } +}