From 58aedaa03a69173b60a077b63cf49eb72e78998b Mon Sep 17 00:00:00 2001 From: Siqing Zheng Date: Wed, 21 Aug 2024 16:31:24 -0700 Subject: [PATCH] add doNotRecurseSignalErr --- example_test.go | 35 +++++++++++++++++++++++++++++++---- walkstruct.go | 34 +++++++++++++++++++++++++++++----- 2 files changed, 60 insertions(+), 9 deletions(-) diff --git a/example_test.go b/example_test.go index 97e10a9..0bd6c23 100644 --- a/example_test.go +++ b/example_test.go @@ -49,14 +49,34 @@ func makeIntDoublerWithError(t reflect.Type) (func(v reflect.Value), error) { panic("makeIntDoublerWithError only supports pointers to structs") } var ints []reflect.StructField - err := reflectutils.WalkStructElementsWithError(t, func(f reflect.StructField) (bool, error) { + err := reflectutils.WalkStructElementsWithError(t, func(f reflect.StructField) error { if f.Type.Kind() == reflect.Int { ints = append(ints, f) } if f.Type.Kind() == reflect.String { - return false, errors.Errorf("error on string") + return errors.Errorf("error on string, and stop walk") } - return true, nil + return nil // default nil recursive down + }) + return func(v reflect.Value) { + v = v.Elem() + for _, f := range ints { + i := v.FieldByIndex(f.Index) + i.SetInt(int64(i.Interface().(int)) * 2) + } + }, err +} + +func makeIntDoublerNoRecursive(t reflect.Type) (func(v reflect.Value), error) { + if t.Kind() != reflect.Ptr || t.Elem().Kind() != reflect.Struct { + panic("makeIntDoublerNoRecursive only supports pointers to structs") + } + var ints []reflect.StructField + err := reflectutils.WalkStructElementsWithError(t, func(f reflect.StructField) error { + if f.Type.Kind() == reflect.Int { + ints = append(ints, f) + } + return reflectutils.DoNotRecurseSignalErr }) return func(v reflect.Value) { v = v.Elem() @@ -84,7 +104,14 @@ func Example() { doubler(v) fmt.Printf("%v\n", v.Interface()) + doubler, err := makeIntDoublerNoRecursive(v.Type()) + doubler(v) + fmt.Printf("%v\n", v.Interface()) + fmt.Printf("%v", err) + // Output: &{6 {4 12} string {10}} + // &{12 {4 12} string {10}} + // } func ExampleError() { @@ -106,5 +133,5 @@ func ExampleError() { fmt.Printf("%v", err) // Output: &{6 {4 12} string {5}} - // error on string + // error on string, and stop walk } diff --git a/walkstruct.go b/walkstruct.go index 243f51d..7891d4d 100644 --- a/walkstruct.go +++ b/walkstruct.go @@ -1,6 +1,7 @@ package reflectutils import ( + "errors" "reflect" ) @@ -37,7 +38,23 @@ func doWalkStructElements(t reflect.Type, path []int, f func(reflect.StructField } } -func WalkStructElementsWithError(t reflect.Type, f func(reflect.StructField) (bool, error)) error { +// WalkStructElementsWithError recursively visits the fields in a structure calling a +// callback for each field. It modifies the reflect.StructField object +// so that Index is relative to the root object originally passed to +// WalkStructElementsWithError. This allows the FieldByIndex method on a reflect.Value +// matching the original struct to fetch that field. +// +// WalkStructElementsWithError should be called with a reflect.Type whose Kind() is +// reflect.Struct or whose Kind() is reflect.Ptr and Elem.Type() is reflect.Struct. +// All other types will simply be ignored. +// +// A non-nil return value from the called function stops iteration and recursion and becomes +// the return value. +// +// A special error return value, [DoNotRecurseSignalErr] is not considered an error (it will +// not become the return value, and it does not stop iteration) but it will stop recursion if returned +// on a field that is itself a struct. +func WalkStructElementsWithError(t reflect.Type, f func(reflect.StructField) error) error { if t.Kind() == reflect.Struct { return doWalkStructElementsWithError(t, []int{}, f) } @@ -47,16 +64,23 @@ func WalkStructElementsWithError(t reflect.Type, f func(reflect.StructField) (bo return nil } -func doWalkStructElementsWithError(t reflect.Type, path []int, f func(reflect.StructField) (bool, error)) error { +var DoNotRecurseSignalErr = errors.New("walkstruct: do not recurse signal") + +func doWalkStructElementsWithError(t reflect.Type, path []int, f func(reflect.StructField) error) error { for i := 0; i < t.NumField(); i++ { field := t.Field(i) np := copyIntSlice(path) np = append(np, field.Index...) field.Index = np - if walkDown, err := f(field); err != nil { + err := f(field) + if errors.Is(err, DoNotRecurseSignalErr) { + continue + } + if err != nil { return err - } else if walkDown && field.Type.Kind() == reflect.Struct { - err := doWalkStructElementsWithError(field.Type, np, f) + } + if field.Type.Kind() == reflect.Struct { + err = doWalkStructElementsWithError(field.Type, np, f) if err != nil { return err }