Skip to content

Commit

Permalink
When pulling a default parameter value out of a Param via type assert…
Browse files Browse the repository at this point in the history
…ion, try all valid types. Fixes #4. (#5)
  • Loading branch information
bobg authored Nov 27, 2023
1 parent 9489753 commit 301b833
Show file tree
Hide file tree
Showing 2 changed files with 190 additions and 60 deletions.
205 changes: 145 additions & 60 deletions parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,7 @@ func parseBoolPos(args *[]string, argvals *[]reflect.Value, p Param) error {
}

func parseIntPos(args *[]string, argvals *[]reflect.Value, p Param) error {
var val int64
if dflt, ok := p.Default.(int); ok {
val = int64(dflt)
} else if dflt, ok := p.Default.(int64); ok {
val = dflt
}
val := int64(asInt(p.Default)) // has to be int64 to receive the result of ParseInt below

if len(*args) > 0 {
var err error
Expand All @@ -128,12 +123,7 @@ func parseIntPos(args *[]string, argvals *[]reflect.Value, p Param) error {
}

func parseInt64Pos(args *[]string, argvals *[]reflect.Value, p Param) error {
var val int64
if dflt, ok := p.Default.(int); ok {
val = int64(dflt)
} else if dflt, ok := p.Default.(int64); ok {
val = dflt
}
val := asInt64(p.Default)

if len(*args) > 0 {
var err error
Expand All @@ -148,12 +138,7 @@ func parseInt64Pos(args *[]string, argvals *[]reflect.Value, p Param) error {
}

func parseUintPos(args *[]string, argvals *[]reflect.Value, p Param) error {
var val uint64
if dflt, ok := p.Default.(uint); ok {
val = uint64(dflt)
} else if dflt, ok := p.Default.(uint64); ok {
val = dflt
}
val := uint64(asUint(p.Default)) // has to be uint64 to receive the result of ParseUint below

if len(*args) > 0 {
var err error
Expand All @@ -168,12 +153,7 @@ func parseUintPos(args *[]string, argvals *[]reflect.Value, p Param) error {
}

func parseUint64Pos(args *[]string, argvals *[]reflect.Value, p Param) error {
var val uint64
if dflt, ok := p.Default.(uint); ok {
val = uint64(dflt)
} else if dflt, ok := p.Default.(uint64); ok {
val = dflt
}
val := asUint64(p.Default)

if len(*args) > 0 {
var err error
Expand All @@ -198,21 +178,7 @@ func parseStringPos(args *[]string, argvals *[]reflect.Value, p Param) error {
}

func parseFloat64Pos(args *[]string, argvals *[]reflect.Value, p Param) error {
var val float64
switch dflt := p.Default.(type) {
case int:
val = float64(dflt)
case int64:
val = float64(dflt)
case uint:
val = float64(dflt)
case uint64:
val = float64(dflt)
case float32:
val = float64(dflt)
case float64:
val = dflt
}
val := asFloat64(p.Default)

if len(*args) > 0 {
var err error
Expand All @@ -227,15 +193,7 @@ func parseFloat64Pos(args *[]string, argvals *[]reflect.Value, p Param) error {
}

func parseDurationPos(args *[]string, argvals *[]reflect.Value, p Param) error {
var val time.Duration
switch dflt := p.Default.(type) {
case int:
val = time.Duration(dflt)
case int64:
val = time.Duration(dflt)
case time.Duration:
val = dflt
}
val := asDuration(p.Default)

if len(*args) > 0 {
var err error
Expand Down Expand Up @@ -264,6 +222,139 @@ func parseValuePos(args *[]string, argvals *[]reflect.Value, p Param) error {
return nil
}

func asInt(val interface{}) int {
switch v := val.(type) {
case int:
return v

case int8:
return int(v)
case int16:
return int(v)
case int32:
return int(v)

case uint8:
return int(v)
case uint16:
return int(v)
}
return 0
}

func asInt64(val interface{}) int64 {
switch v := val.(type) {
case int:
return int64(v)
case int64:
return v

case int8:
return int64(v)
case int16:
return int64(v)
case int32:
return int64(v)

case uint8:
return int64(v)
case uint16:
return int64(v)
case uint32:
return int64(v)
}
return 0
}

func asUint(val interface{}) uint {
switch v := val.(type) {
case uint:
return v

case int8:
return uint(v)
case int16:
return uint(v)

case uint8:
return uint(v)
case uint16:
return uint(v)
case uint32:
return uint(v)
}
return 0
}

func asUint64(val interface{}) uint64 {
switch v := val.(type) {
case uint:
return uint64(v)
case uint64:
return v

case int8:
return uint64(v)
case int16:
return uint64(v)
case int32:
return uint64(v)

case uint8:
return uint64(v)
case uint16:
return uint64(v)
case uint32:
return uint64(v)
}
return 0
}

func asFloat64(val interface{}) float64 {
switch v := val.(type) {
case int:
return float64(v)
case uint:
return float64(v)

case int8:
return float64(v)
case int16:
return float64(v)
case int32:
return float64(v)
case int64:
return float64(v)

case uint8:
return float64(v)
case uint16:
return float64(v)
case uint32:
return float64(v)
case uint64:
return float64(v)

case float32:
return float64(v)
case float64:
return v
}
return 0
}

func asDuration(val interface{}) time.Duration {
switch v := val.(type) {
case int:
return time.Duration(v)
case int64:
return time.Duration(v)
case time.Duration:
return v
}
return 0
}

// ToFlagSet takes a slice of [Param] and produces:
//
// - a [flag.FlagSet],
Expand Down Expand Up @@ -291,32 +382,26 @@ func ToFlagSet(params []Param) (fs *flag.FlagSet, ptrs []reflect.Value, position
v = fs.Bool(name, dflt, p.Doc)

case Int:
dflt, _ := p.Default.(int)
v = fs.Int(name, dflt, p.Doc)
v = fs.Int(name, asInt(p.Default), p.Doc)

case Int64:
dflt, _ := p.Default.(int64)
v = fs.Int64(name, dflt, p.Doc)
v = fs.Int64(name, asInt64(p.Default), p.Doc)

case Uint:
dflt, _ := p.Default.(uint)
v = fs.Uint(name, dflt, p.Doc)
v = fs.Uint(name, asUint(p.Default), p.Doc)

case Uint64:
dflt, _ := p.Default.(uint64)
v = fs.Uint64(name, dflt, p.Doc)
v = fs.Uint64(name, asUint64(p.Default), p.Doc)

case String:
dflt, _ := p.Default.(string)
v = fs.String(name, dflt, p.Doc)

case Float64:
dflt, _ := p.Default.(float64)
v = fs.Float64(name, dflt, p.Doc)
v = fs.Float64(name, asFloat64(p.Default), p.Doc)

case Duration:
dflt, _ := p.Default.(time.Duration)
v = fs.Duration(name, dflt, p.Doc)
v = fs.Duration(name, asDuration(p.Default), p.Doc)

case Value:
val, ok := p.Default.(flag.Value)
Expand Down
45 changes: 45 additions & 0 deletions parse_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package subcmd

import (
"flag"
"testing"
)

func TestToFlagSet(t *testing.T) {
cases := []struct {
name string
params []Param
wantfs map[string]interface{}
}{{
name: "float64_with_int_default",
params: []Param{{
Name: "-float64",
Type: Float64,
Default: 1,
}},
wantfs: map[string]interface{}{
"float64": 1.0,
},
}}

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
fs, _, _, err := ToFlagSet(c.params)
if err != nil {
t.Fatal(err)
}

for k, v := range c.wantfs {
val := fs.Lookup(k)
if val == nil {
t.Fatalf("flag %s not found", k)
}
getter := val.Value.(flag.Getter)
got := getter.Get()
if got != v {
t.Errorf("got %v (%T), want %v (%T)", got, got, v, v)
}
}
})
}
}

0 comments on commit 301b833

Please sign in to comment.