diff --git a/dot/state/slot.go b/dot/state/slot.go index f926edbf03..267a8ada56 100644 --- a/dot/state/slot.go +++ b/dot/state/slot.go @@ -11,6 +11,7 @@ import ( "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/internal/database" + "github.com/ChainSafe/gossamer/lib/primitives" "github.com/ChainSafe/gossamer/pkg/scale" ) @@ -45,10 +46,10 @@ type headerAndSigner struct { } func (s *SlotState) CheckEquivocation(slotNow, slot uint64, header *types.Header, - signer types.AuthorityID) (*types.BabeEquivocationProof, error) { + signer types.AuthorityID) (*types.BabeEquivocationProof, error) { //skipcq: GO-R1005 // We don't check equivocations for old headers out of our capacity. // checking slotNow is greater than slot to avoid overflow, same as saturating_sub - if saturatingSub(slotNow, slot) > maxSlotCapacity { + if primitives.SaturatingSub(slotNow, slot) > maxSlotCapacity { return nil, nil } @@ -127,7 +128,7 @@ func (s *SlotState) CheckEquivocation(slotNow, slot uint64, header *types.Header newFirstSavedSlot := firstSavedSlot if slotNow-firstSavedSlot >= pruningBound { - newFirstSavedSlot = saturatingSub(slotNow, maxSlotCapacity) + newFirstSavedSlot = primitives.SaturatingSub(slotNow, maxSlotCapacity) for s := firstSavedSlot; s < newFirstSavedSlot; s++ { slotEncoded := make([]byte, 8) @@ -184,10 +185,3 @@ func (s *SlotState) CheckEquivocation(slotNow, slot uint64, header *types.Header return nil, nil } - -func saturatingSub(a, b uint64) uint64 { - if a > b { - return a - b - } - return 0 -} diff --git a/lib/primitives/math.go b/lib/primitives/math.go new file mode 100644 index 0000000000..d44a0c1293 --- /dev/null +++ b/lib/primitives/math.go @@ -0,0 +1,89 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package primitives + +import ( + "fmt" + "unsafe" + + "golang.org/x/exp/constraints" +) + +// saturatingOperations applies the correct operation +// given the input types +func saturatingOperations[T constraints.Integer](a, b T, + signedSaturatingOperation func(T, T, T, T) T, + unsignedSaturatingOperation func(T, T) T, +) T { + switch any(a).(type) { + case int, int8, int16, int32, int64: + // #nosec G103 + sizeOf := (unsafe.Sizeof(a) * 8) - 1 + + var ( + maxValueOfSignedType T = 1< 0 && a > max-b { + return max + } + + if b < 0 && a < min-b { + return min + } + + return a + b +} + +func saturatingAddUnsigned[T constraints.Integer](a, b T) T { + // the operation ^T(0) gives us the max value of type T + // eg. if T is uint8 then it gives us 255 + max := ^T(0) + + if a > max-b { + return max + } + return a + b +} + +// SaturatingSub computes a - b saturating at the numeric bounds instead of overflowing +func SaturatingSub[T constraints.Integer](a, b T) T { + return saturatingOperations(a, b, saturatingSubSigned, saturatingSubUnsigned) +} + +func saturatingSubSigned[T constraints.Integer](a, b, max, min T) T { + if b < 0 && a > max+b { + return max + } + + if b > 0 && a < min+b { + return min + } + + return a - b +} + +func saturatingSubUnsigned[T constraints.Integer](a, b T) T { + if a > b { + return a - b + } + return 0 +} diff --git a/lib/primitives/math_test.go b/lib/primitives/math_test.go new file mode 100644 index 0000000000..52cfb78db2 --- /dev/null +++ b/lib/primitives/math_test.go @@ -0,0 +1,37 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package primitives + +import ( + "testing" + + "github.com/ethereum/go-ethereum/common/math" + "github.com/stretchr/testify/require" +) + +func TestSaturatingAdd(t *testing.T) { + require.Equal(t, uint8(2), SaturatingAdd(uint8(1), uint8(1))) + require.Equal(t, uint8(math.MaxUint8), SaturatingAdd(uint8(math.MaxUint8), 100)) + + require.Equal(t, uint32(math.MaxUint32), SaturatingAdd(uint32(math.MaxUint32), 100)) + require.Equal(t, uint32(100), SaturatingAdd(uint32(0), 100)) + + // should not be able to overflow in the opposite direction as well + require.Equal(t, int64(math.MinInt64), SaturatingAdd(int64(math.MinInt64), -100)) + require.Equal(t, int8(127), SaturatingAdd(int8(120), 7)) + require.Equal(t, int8(127), SaturatingAdd(int8(120), 8)) +} + +func TestSaturatingSub(t *testing.T) { + // -128 - 100 overflows, so it should return just -128 + require.Equal(t, int8(math.MinInt8), SaturatingSub(int8(math.MinInt8), 100)) + require.Equal(t, int8(0), SaturatingSub(int8(100), 100)) + + // max - (-1) = max + 1 = overflows, so it should return just max + require.Equal(t, int64(math.MaxInt64), SaturatingSub(int64(math.MaxInt64), -1)) + + // 2 - 10 = -8 which overflows, then should return just 0 + require.Equal(t, uint32(0), SaturatingSub(uint32(2), uint32(10))) + require.Equal(t, uint64(math.MaxUint64), SaturatingSub(uint64(math.MaxUint64), uint64(0))) +}