diff --git a/cmp/cmpopts/equate.go b/cmp/cmpopts/equate.go index 90974e6..3d8d0cd 100644 --- a/cmp/cmpopts/equate.go +++ b/cmp/cmpopts/equate.go @@ -7,6 +7,7 @@ package cmpopts import ( "errors" + "fmt" "math" "reflect" "time" @@ -154,3 +155,31 @@ func compareErrors(x, y interface{}) bool { ye := y.(error) return errors.Is(xe, ye) || errors.Is(ye, xe) } + +// EquateComparable returns a [cmp.Option] that determines equality +// of comparable types by directly comparing them using the == operator in Go. +// The types to compare are specified by passing a value of that type. +// This option should only be used on types that are documented as being +// safe for direct == comparison. For example, [net/netip.Addr] is documented +// as being semantically safe to use with ==, while [time.Time] is documented +// to discourage the use of == on time values. +func EquateComparable(typs ...interface{}) cmp.Option { + types := make(typesFilter) + for _, typ := range typs { + switch t := reflect.TypeOf(typ); { + case !t.Comparable(): + panic(fmt.Sprintf("%T is not a comparable Go type", typ)) + case types[t]: + panic(fmt.Sprintf("%T is already specified", typ)) + default: + types[t] = true + } + } + return cmp.FilterPath(types.filter, cmp.Comparer(equateAny)) +} + +type typesFilter map[reflect.Type]bool + +func (tf typesFilter) filter(p cmp.Path) bool { return tf[p.Last().Type()] } + +func equateAny(x, y interface{}) bool { return x == y } diff --git a/cmp/cmpopts/util_test.go b/cmp/cmpopts/util_test.go index 7adeb9b..6a7c300 100644 --- a/cmp/cmpopts/util_test.go +++ b/cmp/cmpopts/util_test.go @@ -10,6 +10,7 @@ import ( "fmt" "io" "math" + "net/netip" "reflect" "strings" "sync" @@ -676,6 +677,36 @@ func TestOptions(t *testing.T) { opts: []cmp.Option{EquateErrors()}, wantEqual: false, reason: "AnyError is not equal to nil value", + }, { + label: "EquateComparable", + x: []struct{ P netip.Addr }{ + {netip.AddrFrom4([4]byte{1, 2, 3, 4})}, + {netip.AddrFrom4([4]byte{1, 2, 3, 5})}, + {netip.AddrFrom4([4]byte{1, 2, 3, 6})}, + }, + y: []struct{ P netip.Addr }{ + {netip.AddrFrom4([4]byte{1, 2, 3, 4})}, + {netip.AddrFrom4([4]byte{1, 2, 3, 5})}, + {netip.AddrFrom4([4]byte{1, 2, 3, 6})}, + }, + opts: []cmp.Option{EquateComparable(netip.Addr{})}, + wantEqual: true, + reason: "equal because all IP addresses are the same", + }, { + label: "EquateComparable", + x: []struct{ P netip.Addr }{ + {netip.AddrFrom4([4]byte{1, 2, 3, 4})}, + {netip.AddrFrom4([4]byte{1, 2, 3, 5})}, + {netip.AddrFrom4([4]byte{1, 2, 3, 6})}, + }, + y: []struct{ P netip.Addr }{ + {netip.AddrFrom4([4]byte{1, 2, 3, 4})}, + {netip.AddrFrom4([4]byte{1, 2, 3, 7})}, + {netip.AddrFrom4([4]byte{1, 2, 3, 6})}, + }, + opts: []cmp.Option{EquateComparable(netip.Addr{})}, + wantEqual: false, + reason: "not equal because second IP address is different", }, { label: "IgnoreFields", x: Bar1{Foo3{&Foo2{&Foo1{Alpha: 5}}}}, diff --git a/cmp/options.go b/cmp/options.go index 518b6ac..392a1ce 100644 --- a/cmp/options.go +++ b/cmp/options.go @@ -232,7 +232,9 @@ func (validator) apply(s *state, vx, vy reflect.Value) { if t := s.curPath.Index(-2).Type(); t.Name() != "" { // Named type with unexported fields. name = fmt.Sprintf("%q.%v", t.PkgPath(), t.Name()) // e.g., "path/to/package".MyType - if _, ok := reflect.New(t).Interface().(error); ok { + if t.Comparable() { + help = "consider using cmpopts.EquateComparable to compare comparable Go types" + } else if _, ok := reflect.New(t).Interface().(error); ok { help = "consider using cmpopts.EquateErrors to compare error values" } } else {