diff --git a/parse.go b/parse.go index cbcb9ad..a3eca29 100644 --- a/parse.go +++ b/parse.go @@ -108,12 +108,7 @@ func parseBoolPos(args *[]string, argvals *[]reflect.Value, p Param) error { } func parseIntPos(args *[]string, argvals *[]reflect.Value, p Param) error { - var val int64 - if dflt, ok := p.Default.(int); ok { - val = int64(dflt) - } else if dflt, ok := p.Default.(int64); ok { - val = dflt - } + val := int64(asInt(p.Default)) // has to be int64 to receive the result of ParseInt below if len(*args) > 0 { var err error @@ -128,12 +123,7 @@ func parseIntPos(args *[]string, argvals *[]reflect.Value, p Param) error { } func parseInt64Pos(args *[]string, argvals *[]reflect.Value, p Param) error { - var val int64 - if dflt, ok := p.Default.(int); ok { - val = int64(dflt) - } else if dflt, ok := p.Default.(int64); ok { - val = dflt - } + val := asInt64(p.Default) if len(*args) > 0 { var err error @@ -148,12 +138,7 @@ func parseInt64Pos(args *[]string, argvals *[]reflect.Value, p Param) error { } func parseUintPos(args *[]string, argvals *[]reflect.Value, p Param) error { - var val uint64 - if dflt, ok := p.Default.(uint); ok { - val = uint64(dflt) - } else if dflt, ok := p.Default.(uint64); ok { - val = dflt - } + val := uint64(asUint(p.Default)) // has to be uint64 to receive the result of ParseUint below if len(*args) > 0 { var err error @@ -168,12 +153,7 @@ func parseUintPos(args *[]string, argvals *[]reflect.Value, p Param) error { } func parseUint64Pos(args *[]string, argvals *[]reflect.Value, p Param) error { - var val uint64 - if dflt, ok := p.Default.(uint); ok { - val = uint64(dflt) - } else if dflt, ok := p.Default.(uint64); ok { - val = dflt - } + val := asUint64(p.Default) if len(*args) > 0 { var err error @@ -198,21 +178,7 @@ func parseStringPos(args *[]string, argvals *[]reflect.Value, p Param) error { } func parseFloat64Pos(args *[]string, argvals *[]reflect.Value, p Param) error { - var val float64 - switch dflt := p.Default.(type) { - case int: - val = float64(dflt) - case int64: - val = float64(dflt) - case uint: - val = float64(dflt) - case uint64: - val = float64(dflt) - case float32: - val = float64(dflt) - case float64: - val = dflt - } + val := asFloat64(p.Default) if len(*args) > 0 { var err error @@ -227,15 +193,7 @@ func parseFloat64Pos(args *[]string, argvals *[]reflect.Value, p Param) error { } func parseDurationPos(args *[]string, argvals *[]reflect.Value, p Param) error { - var val time.Duration - switch dflt := p.Default.(type) { - case int: - val = time.Duration(dflt) - case int64: - val = time.Duration(dflt) - case time.Duration: - val = dflt - } + val := asDuration(p.Default) if len(*args) > 0 { var err error @@ -264,6 +222,139 @@ func parseValuePos(args *[]string, argvals *[]reflect.Value, p Param) error { return nil } +func asInt(val interface{}) int { + switch v := val.(type) { + case int: + return v + + case int8: + return int(v) + case int16: + return int(v) + case int32: + return int(v) + + case uint8: + return int(v) + case uint16: + return int(v) + } + return 0 +} + +func asInt64(val interface{}) int64 { + switch v := val.(type) { + case int: + return int64(v) + case int64: + return v + + case int8: + return int64(v) + case int16: + return int64(v) + case int32: + return int64(v) + + case uint8: + return int64(v) + case uint16: + return int64(v) + case uint32: + return int64(v) + } + return 0 +} + +func asUint(val interface{}) uint { + switch v := val.(type) { + case uint: + return v + + case int8: + return uint(v) + case int16: + return uint(v) + + case uint8: + return uint(v) + case uint16: + return uint(v) + case uint32: + return uint(v) + } + return 0 +} + +func asUint64(val interface{}) uint64 { + switch v := val.(type) { + case uint: + return uint64(v) + case uint64: + return v + + case int8: + return uint64(v) + case int16: + return uint64(v) + case int32: + return uint64(v) + + case uint8: + return uint64(v) + case uint16: + return uint64(v) + case uint32: + return uint64(v) + } + return 0 +} + +func asFloat64(val interface{}) float64 { + switch v := val.(type) { + case int: + return float64(v) + case uint: + return float64(v) + + case int8: + return float64(v) + case int16: + return float64(v) + case int32: + return float64(v) + case int64: + return float64(v) + + case uint8: + return float64(v) + case uint16: + return float64(v) + case uint32: + return float64(v) + case uint64: + return float64(v) + + case float32: + return float64(v) + case float64: + return v + } + return 0 +} + +func asDuration(val interface{}) time.Duration { + switch v := val.(type) { + case int: + return time.Duration(v) + case int64: + return time.Duration(v) + case time.Duration: + return v + } + return 0 +} + // ToFlagSet takes a slice of [Param] and produces: // // - a [flag.FlagSet], @@ -291,32 +382,26 @@ func ToFlagSet(params []Param) (fs *flag.FlagSet, ptrs []reflect.Value, position v = fs.Bool(name, dflt, p.Doc) case Int: - dflt, _ := p.Default.(int) - v = fs.Int(name, dflt, p.Doc) + v = fs.Int(name, asInt(p.Default), p.Doc) case Int64: - dflt, _ := p.Default.(int64) - v = fs.Int64(name, dflt, p.Doc) + v = fs.Int64(name, asInt64(p.Default), p.Doc) case Uint: - dflt, _ := p.Default.(uint) - v = fs.Uint(name, dflt, p.Doc) + v = fs.Uint(name, asUint(p.Default), p.Doc) case Uint64: - dflt, _ := p.Default.(uint64) - v = fs.Uint64(name, dflt, p.Doc) + v = fs.Uint64(name, asUint64(p.Default), p.Doc) case String: dflt, _ := p.Default.(string) v = fs.String(name, dflt, p.Doc) case Float64: - dflt, _ := p.Default.(float64) - v = fs.Float64(name, dflt, p.Doc) + v = fs.Float64(name, asFloat64(p.Default), p.Doc) case Duration: - dflt, _ := p.Default.(time.Duration) - v = fs.Duration(name, dflt, p.Doc) + v = fs.Duration(name, asDuration(p.Default), p.Doc) case Value: val, ok := p.Default.(flag.Value) diff --git a/parse_test.go b/parse_test.go new file mode 100644 index 0000000..eb78351 --- /dev/null +++ b/parse_test.go @@ -0,0 +1,45 @@ +package subcmd + +import ( + "flag" + "testing" +) + +func TestToFlagSet(t *testing.T) { + cases := []struct { + name string + params []Param + wantfs map[string]interface{} + }{{ + name: "float64_with_int_default", + params: []Param{{ + Name: "-float64", + Type: Float64, + Default: 1, + }}, + wantfs: map[string]interface{}{ + "float64": 1.0, + }, + }} + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + fs, _, _, err := ToFlagSet(c.params) + if err != nil { + t.Fatal(err) + } + + for k, v := range c.wantfs { + val := fs.Lookup(k) + if val == nil { + t.Fatalf("flag %s not found", k) + } + getter := val.Value.(flag.Getter) + got := getter.Get() + if got != v { + t.Errorf("got %v (%T), want %v (%T)", got, got, v, v) + } + } + }) + } +}