Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support compare functions with SortSlices and SortMaps #367

Merged
merged 1 commit into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ jobs:
test:
strategy:
matrix:
go-version: [1.18.x, 1.19.x, 1.20.x, 1.21.x]
go-version: [1.21.x]
os: [ubuntu-latest, macos-latest]
runs-on: ${{ matrix.os }}
steps:
Expand Down
64 changes: 44 additions & 20 deletions cmp/cmpopts/sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,29 @@ import (
)

// SortSlices returns a [cmp.Transformer] option that sorts all []V.
// The less function must be of the form "func(T, T) bool" which is used to
// sort any slice with element type V that is assignable to T.
// The lessOrCompareFunc function must be either
// a less function of the form "func(T, T) bool" or
// a compare function of the format "func(T, T) int"
// which is used to sort any slice with element type V that is assignable to T.
//
// The less function must be:
// A less function must be:
// - Deterministic: less(x, y) == less(x, y)
// - Irreflexive: !less(x, x)
// - Transitive: if !less(x, y) and !less(y, z), then !less(x, z)
//
// The less function does not have to be "total". That is, if !less(x, y) and
// !less(y, x) for two elements x and y, their relative order is maintained.
// A compare function must be:
// - Deterministic: compare(x, y) == compare(x, y)
// - Irreflexive: compare(x, x) == 0
// - Transitive: if !less(x, y) and !less(y, z), then !less(x, z)
//
// The function does not have to be "total". That is, if x != y, but
// less or compare report inequality, their relative order is maintained.
//
// SortSlices can be used in conjunction with [EquateEmpty].
func SortSlices(lessFunc interface{}) cmp.Option {
vf := reflect.ValueOf(lessFunc)
if !function.IsType(vf.Type(), function.Less) || vf.IsNil() {
panic(fmt.Sprintf("invalid less function: %T", lessFunc))
func SortSlices(lessOrCompareFunc interface{}) cmp.Option {
vf := reflect.ValueOf(lessOrCompareFunc)
if (!function.IsType(vf.Type(), function.Less) && !function.IsType(vf.Type(), function.Compare)) || vf.IsNil() {
panic(fmt.Sprintf("invalid less or compare function: %T", lessOrCompareFunc))
}
ss := sliceSorter{vf.Type().In(0), vf}
return cmp.FilterValues(ss.filter, cmp.Transformer("cmpopts.SortSlices", ss.sort))
Expand Down Expand Up @@ -79,28 +86,40 @@ func (ss sliceSorter) checkSort(v reflect.Value) {
}
func (ss sliceSorter) less(v reflect.Value, i, j int) bool {
vx, vy := v.Index(i), v.Index(j)
return ss.fnc.Call([]reflect.Value{vx, vy})[0].Bool()
vo := ss.fnc.Call([]reflect.Value{vx, vy})[0]
if vo.Kind() == reflect.Bool {
return vo.Bool()
} else {
return vo.Int() < 0
}
}

// SortMaps returns a [cmp.Transformer] option that flattens map[K]V types to be a
// sorted []struct{K, V}. The less function must be of the form
// "func(T, T) bool" which is used to sort any map with key K that is
// assignable to T.
// SortMaps returns a [cmp.Transformer] option that flattens map[K]V types to be
// a sorted []struct{K, V}. The lessOrCompareFunc function must be either
// a less function of the form "func(T, T) bool" or
// a compare function of the format "func(T, T) int"
// which is used to sort any map with key K that is assignable to T.
//
// Flattening the map into a slice has the property that [cmp.Equal] is able to
// use [cmp.Comparer] options on K or the K.Equal method if it exists.
//
// The less function must be:
// A less function must be:
// - Deterministic: less(x, y) == less(x, y)
// - Irreflexive: !less(x, x)
// - Transitive: if !less(x, y) and !less(y, z), then !less(x, z)
// - Total: if x != y, then either less(x, y) or less(y, x)
//
// A compare function must be:
// - Deterministic: compare(x, y) == compare(x, y)
// - Irreflexive: compare(x, x) == 0
// - Transitive: if compare(x, y) < 0 and compare(y, z) < 0, then compare(x, z) < 0
// - Total: if x != y, then compare(x, y) != 0
//
// SortMaps can be used in conjunction with [EquateEmpty].
func SortMaps(lessFunc interface{}) cmp.Option {
vf := reflect.ValueOf(lessFunc)
if !function.IsType(vf.Type(), function.Less) || vf.IsNil() {
panic(fmt.Sprintf("invalid less function: %T", lessFunc))
func SortMaps(lessOrCompareFunc interface{}) cmp.Option {
vf := reflect.ValueOf(lessOrCompareFunc)
if (!function.IsType(vf.Type(), function.Less) && !function.IsType(vf.Type(), function.Compare)) || vf.IsNil() {
panic(fmt.Sprintf("invalid less or compare function: %T", lessOrCompareFunc))
}
ms := mapSorter{vf.Type().In(0), vf}
return cmp.FilterValues(ms.filter, cmp.Transformer("cmpopts.SortMaps", ms.sort))
Expand Down Expand Up @@ -143,5 +162,10 @@ func (ms mapSorter) checkSort(v reflect.Value) {
}
func (ms mapSorter) less(v reflect.Value, i, j int) bool {
vx, vy := v.Index(i).Field(0), v.Index(j).Field(0)
return ms.fnc.Call([]reflect.Value{vx, vy})[0].Bool()
vo := ms.fnc.Call([]reflect.Value{vx, vy})[0]
if vo.Kind() == reflect.Bool {
return vo.Bool()
} else {
return vo.Int() < 0
}
}
48 changes: 34 additions & 14 deletions cmp/cmpopts/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,23 @@ func TestOptions(t *testing.T) {
opts: []cmp.Option{SortSlices(func(x, y int) bool { return x < y })},
wantEqual: true,
reason: "equal because SortSlices sorts the slices",
}, {
label: "SortSlices",
x: []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9},
y: []int{1, 0, 5, 2, 8, 9, 4, 3, 6, 7},
opts: []cmp.Option{SortSlices(func(x, y int) int {
// TODO(Go1.22): Use cmp.Compare.
switch {
case x < y:
return -1
case y > x:
return +1
default:
return 0
}
})},
wantEqual: true,
reason: "equal because SortSlices sorts the slices",
}, {
label: "SortSlices",
x: []MyInt{0, 1, 2, 3, 4, 5, 6, 7, 8, 9},
Expand Down Expand Up @@ -201,6 +218,21 @@ func TestOptions(t *testing.T) {
opts: []cmp.Option{SortMaps(func(x, y time.Time) bool { return x.Before(y) })},
wantEqual: true,
reason: "equal because SortMaps flattens to a slice where Time.Equal can be used",
}, {
label: "SortMaps",
x: map[time.Time]string{
time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC): "0th birthday",
time.Date(2010, time.November, 10, 23, 0, 0, 0, time.UTC): "1st birthday",
time.Date(2011, time.November, 10, 23, 0, 0, 0, time.UTC): "2nd birthday",
},
y: map[time.Time]string{
time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC).In(time.Local): "0th birthday",
time.Date(2010, time.November, 10, 23, 0, 0, 0, time.UTC).In(time.Local): "1st birthday",
time.Date(2011, time.November, 10, 23, 0, 0, 0, time.UTC).In(time.Local): "2nd birthday",
},
opts: []cmp.Option{SortMaps(func(x, y time.Time) int { return time.Time.Compare(x, y) })},
wantEqual: true,
reason: "equal because SortMaps flattens to a slice where Time.Equal can be used",
}, {
label: "SortMaps",
x: map[MyTime]string{
Expand Down Expand Up @@ -1184,29 +1216,17 @@ func TestPanic(t *testing.T) {
args: args(time.Duration(-1)),
wantPanic: "margin must be a non-negative number",
reason: "negative duration is invalid",
}, {
label: "SortSlices",
fnc: SortSlices,
args: args(strings.Compare),
wantPanic: "invalid less function",
reason: "func(x, y string) int is wrong signature for less",
}, {
label: "SortSlices",
fnc: SortSlices,
args: args((func(_, _ int) bool)(nil)),
wantPanic: "invalid less function",
wantPanic: "invalid less or compare function",
reason: "nil value is not valid",
}, {
label: "SortMaps",
fnc: SortMaps,
args: args(strings.Compare),
wantPanic: "invalid less function",
reason: "func(x, y string) int is wrong signature for less",
}, {
label: "SortMaps",
fnc: SortMaps,
args: args((func(_, _ int) bool)(nil)),
wantPanic: "invalid less function",
wantPanic: "invalid less or compare function",
reason: "nil value is not valid",
}, {
label: "IgnoreFields",
Expand Down
7 changes: 7 additions & 0 deletions cmp/internal/function/func.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ const (

tbFunc // func(T) bool
ttbFunc // func(T, T) bool
ttiFunc // func(T, T) int
trbFunc // func(T, R) bool
tibFunc // func(T, I) bool
trFunc // func(T) R
Expand All @@ -28,11 +29,13 @@ const (
Transformer = trFunc // func(T) R
ValueFilter = ttbFunc // func(T, T) bool
Less = ttbFunc // func(T, T) bool
Compare = ttiFunc // func(T, T) int
ValuePredicate = tbFunc // func(T) bool
KeyValuePredicate = trbFunc // func(T, R) bool
)

var boolType = reflect.TypeOf(true)
var intType = reflect.TypeOf(0)

// IsType reports whether the reflect.Type is of the specified function type.
func IsType(t reflect.Type, ft funcType) bool {
Expand All @@ -49,6 +52,10 @@ func IsType(t reflect.Type, ft funcType) bool {
if ni == 2 && no == 1 && t.In(0) == t.In(1) && t.Out(0) == boolType {
return true
}
case ttiFunc: // func(T, T) int
if ni == 2 && no == 1 && t.In(0) == t.In(1) && t.Out(0) == intType {
return true
}
case trbFunc: // func(T, R) bool
if ni == 2 && no == 1 && t.Out(0) == boolType {
return true
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
module github.com/google/go-cmp

go 1.13
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wow. From 4 years ago!

go 1.21
Loading