diff --git a/CHANGELOG.md b/CHANGELOG.md index 425203d5c95..e2f55f4f5d8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -76,6 +76,7 @@ Fixes mainnet bugs w/ incorrect accumulation sumtrees, and CL handling for a bal * [#6071](https://github.com/osmosis-labs/osmosis/pull/6071) reduce number of returns for UpdatePosition and TicksToSqrtPrice functions * [#5906](https://github.com/osmosis-labs/osmosis/pull/5906) Add `AccountLockedCoins` query in lockup module to stargate whitelist. +* [#6053](https://github.com/osmosis-labs/osmosis/pull/6053) monotonic sqrt with 36 decimals ## v17.0.0 diff --git a/osmomath/sqrt.go b/osmomath/sqrt.go index b49779d5039..2f57fba47ae 100644 --- a/osmomath/sqrt.go +++ b/osmomath/sqrt.go @@ -8,7 +8,9 @@ import ( ) var smallestDec = sdk.SmallestDec() +var smallestBigDec = SmallestDec() var tenTo18 = big.NewInt(1e18) +var tenTo36 = big.NewInt(0).Mul(tenTo18, tenTo18) var oneBigInt = big.NewInt(1) // Returns square root of d @@ -49,6 +51,37 @@ func MonotonicSqrt(d sdk.Dec) (sdk.Dec, error) { return root, nil } +func MonotonicSqrtBigDec(d BigDec) (BigDec, error) { + if d.IsNegative() { + return d, errors.New("cannot take square root of negative number") + } + + // A decimal value of d, is represented as an integer of value v = 10^18 * d. + // We have an integer square root function, and we'd like to get the square root of d. + // recall integer square root is floor(sqrt(x)), hence its accurate up to 1 integer. + // we want sqrt d accurate to 18 decimal places. + // So first we multiply our current value by 10^18, then we take the integer square root. + // since sqrt(10^18 * v) = 10^9 * sqrt(v) = 10^18 * sqrt(d), we get the answer we want. + // + // We can than interpret sqrt(10^18 * v) as our resulting decimal and return it. + // monotonicity is guaranteed by correctness of integer square root. + dBi := d.BigInt() + r := big.NewInt(0).Mul(dBi, tenTo36) + r.Sqrt(r) + // However this square root r is s.t. r^2 <= d. We want to flip this to be r^2 >= d. + // To do so, we check that if r^2 < d, do r += 1. Then by correctness we will be in the case we want. + // To compare r^2 and d, we can just compare r^2 and 10^18 * v. (recall r = 10^18 * sqrt(d), v = 10^18 * d) + check := big.NewInt(0).Mul(r, r) + // dBi is a copy of d, so we can modify it. + shiftedD := dBi.Mul(dBi, tenTo36) + if check.Cmp(shiftedD) == -1 { + r.Add(r, oneBigInt) + } + root := NewDecFromBigIntWithPrec(r, 36) + + return root, nil +} + // MustMonotonicSqrt returns the output of MonotonicSqrt, panicking on error. func MustMonotonicSqrt(d sdk.Dec) sdk.Dec { sqrt, err := MonotonicSqrt(d) @@ -57,3 +90,12 @@ func MustMonotonicSqrt(d sdk.Dec) sdk.Dec { } return sqrt } + +// MustMonotonicSqrt returns the output of MonotonicSqrt, panicking on error. +func MustMonotonicSqrtBigDec(d BigDec) BigDec { + sqrt, err := MonotonicSqrtBigDec(d) + if err != nil { + panic(err) + } + return sqrt +} diff --git a/osmomath/sqrt_big_test.go b/osmomath/sqrt_big_test.go new file mode 100644 index 00000000000..83b48d7cdaf --- /dev/null +++ b/osmomath/sqrt_big_test.go @@ -0,0 +1,149 @@ +package osmomath + +import ( + "math/big" + "math/rand" + "testing" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func generateRandomDecForEachBitlenBigDec(r *rand.Rand, numPerBitlen int) []BigDec { + return generateRandomDecForEachBitlen[BigDec](r, numPerBitlen, NewDecFromBigIntWithPrec, Precision) +} + +func TestSdkApproxSqrtVectors_BigDec(t *testing.T) { + testCases := []struct { + input BigDec + expected BigDec + }{ + {OneDec(), OneDec()}, // 1.0 => 1.0 + {NewDecWithPrec(25, 2), NewDecWithPrec(5, 1)}, // 0.25 => 0.5 + {NewDecWithPrec(4, 2), NewDecWithPrec(2, 1)}, // 0.09 => 0.3 + {NewDecFromInt(NewInt(9)), NewDecFromInt(NewInt(3))}, // 9 => 3 + {NewDecFromInt(NewInt(2)), MustNewDecFromStr("1.414213562373095048801688724209698079")}, // 2 => 1.414213562373095048801688724209698079 + {smallestBigDec, NewDecWithPrec(1, 18)}, // 10^-36 => 10^-18 + {smallestBigDec.MulInt64(3), NewDecWithPrec(1732050807568877294, 36)}, // 3*10^-36 => sqrt(3)*10^-18 + } + + for i, tc := range testCases { + res, err := MonotonicSqrtBigDec(tc.input) + require.NoError(t, err) + require.Equal(t, tc.expected, res, "unexpected result for test case %d, input: %v", i, tc.input) + } +} + +func testMonotonicityAroundBigDec(t *testing.T, x BigDec) { + // test that sqrt(x) is monotonic around x + // i.e. sqrt(x-1) <= sqrt(x) <= sqrt(x+1) + sqrtX, err := MonotonicSqrtBigDec(x) + require.NoError(t, err) + sqrtXMinusOne, err := MonotonicSqrtBigDec(x.Sub(smallestBigDec)) + require.NoError(t, err) + sqrtXPlusOne, err := MonotonicSqrtBigDec(x.Add(smallestBigDec)) + require.NoError(t, err) + assert.True(t, sqrtXMinusOne.LTE(sqrtX), "sqrtXMinusOne: %s, sqrtX: %s", sqrtXMinusOne, sqrtX) + assert.True(t, sqrtX.LTE(sqrtXPlusOne), "sqrtX: %s, sqrtXPlusOne: %s", sqrtX, sqrtXPlusOne) +} + +func TestSqrtMonotinicity_BigDec(t *testing.T) { + type testcase struct { + smaller BigDec + bigger BigDec + } + testCases := []testcase{ + {MustNewDecFromStr("120.120060020005000000"), MustNewDecFromStr("120.120060020005000001")}, + {smallestBigDec, smallestBigDec.MulInt64(2)}, + } + // create random test vectors for every bit-length + r := rand.New(rand.NewSource(rand.Int63())) + for i := 0; i < 255+sdk.DecimalPrecisionBits; i++ { + upperbound := big.NewInt(1) + upperbound.Lsh(upperbound, uint(i)) + for j := 0; j < 100; j++ { + v := big.NewInt(0).Rand(r, upperbound) + d := NewDecFromBigIntWithPrec(v, 36) + testCases = append(testCases, testcase{d, d.Add(smallestBigDec)}) + } + } + for i := 0; i < 1024; i++ { + d := NewDecWithPrec(int64(i), 18) + testCases = append(testCases, testcase{d, d.Add(smallestBigDec)}) + } + + for _, i := range testCases { + sqrtSmaller, err := MonotonicSqrtBigDec(i.smaller) + require.NoError(t, err, "smaller: %s", i.smaller) + sqrtBigger, err := MonotonicSqrtBigDec(i.bigger) + require.NoError(t, err, "bigger: %s", i.bigger) + assert.True(t, sqrtSmaller.LTE(sqrtBigger), "sqrtSmaller: %s, sqrtBigger: %s", sqrtSmaller, sqrtBigger) + + // separately sanity check that sqrt * sqrt >= input + sqrtSmallerSquared := sqrtSmaller.Mul(sqrtSmaller) + assert.True(t, sqrtSmallerSquared.GTE(i.smaller), "sqrt %s, sqrtSmallerSquared: %s, smaller: %s", sqrtSmaller, sqrtSmallerSquared, i.smaller) + } +} + +// Test that square(sqrt(x)) = x when x is a perfect square. +// We do this by sampling sqrt(v) from the set of numbers `a.b`, where a in [0, 2^128], b in [0, 10^9]. +// and then setting x = sqrt(v) +// this is because this is the set of values whose squares are perfectly representable. +func TestPerfectSquares_BigDec(t *testing.T) { + cases := []BigDec{ + NewBigDec(100), + } + r := rand.New(rand.NewSource(rand.Int63())) + tenToMin9 := big.NewInt(1_000_000_000) + for i := 0; i < 128; i++ { + upperbound := big.NewInt(1) + upperbound.Lsh(upperbound, uint(i)) + for j := 0; j < 100; j++ { + v := big.NewInt(0).Rand(r, upperbound) + dec := big.NewInt(0).Rand(r, tenToMin9) + d := NewDecFromBigInt(v).Add(NewDecFromBigIntWithPrec(dec, 9)) + cases = append(cases, d.MulMut(d)) + } + } + + for _, i := range cases { + sqrt, err := MonotonicSqrtBigDec(i) + require.NoError(t, err, "smaller: %s", i) + assert.Equal(t, i, sqrt.MulMut(sqrt)) + if !i.IsZero() { + testMonotonicityAroundBigDec(t, i) + } + } +} + +func TestSqrtRounding_BigDec(t *testing.T) { + testCases := []BigDec{ + MustNewDecFromStr("11662930532952632574132537947829685675668532938920838254939577167671385459971.396347723368091000"), + } + r := rand.New(rand.NewSource(rand.Int63())) + testCases = append(testCases, generateRandomDecForEachBitlenBigDec(r, 10)...) + for _, i := range testCases { + sqrt, err := MonotonicSqrtBigDec(i) + require.NoError(t, err, "smaller: %s", i) + // Sanity check that sqrt * sqrt >= input + sqrtSquared := sqrt.Mul(sqrt) + assert.True(t, sqrtSquared.GTE(i), "sqrt %s, sqrtSquared: %s, original: %s", sqrt, sqrtSquared, i) + // (aside) check that (sqrt - 1ulp)^2 <= input + sqrtMin1 := sqrt.Sub(smallestBigDec) + sqrtSquared = sqrtMin1.Mul(sqrtMin1) + assert.True(t, sqrtSquared.LTE(i), "sqrtMin1ULP %s, sqrtSquared: %s, original: %s", sqrt, sqrtSquared, i) + } +} + +// benchmarks the new square root across bit-lengths, for comparison with the SDK square root. +func BenchmarkMonotonicSqrt_BigDec(b *testing.B) { + r := rand.New(rand.NewSource(1)) + vectors := generateRandomDecForEachBitlenBigDec(r, 1) + for i := 0; i < b.N; i++ { + for j := 0; j < len(vectors); j++ { + a, _ := MonotonicSqrtBigDec(vectors[j]) + _ = a + } + } +} diff --git a/osmomath/sqrt_test.go b/osmomath/sqrt_test.go index 22c3ccbfd6f..9dfd9e4b194 100644 --- a/osmomath/sqrt_test.go +++ b/osmomath/sqrt_test.go @@ -10,14 +10,18 @@ import ( "github.com/stretchr/testify/require" ) -func generateRandomDecForEachBitlen(r *rand.Rand, numPerBitlen int) []sdk.Dec { - res := make([]sdk.Dec, (255+sdk.DecimalPrecisionBits)*numPerBitlen) +func generateRandomDecForEachBitlenDec(r *rand.Rand, numPerBitlen int) []sdk.Dec { + return generateRandomDecForEachBitlen[sdk.Dec](r, numPerBitlen, sdk.NewDecFromBigIntWithPrec, sdk.Precision) +} + +func generateRandomDecForEachBitlen[T any](r *rand.Rand, numPerBitlen int, constructor func(*big.Int, int64) T, precision int64) []T { + res := make([]T, (255+sdk.DecimalPrecisionBits)*numPerBitlen) for i := 0; i < 255+sdk.DecimalPrecisionBits; i++ { upperbound := big.NewInt(1) upperbound.Lsh(upperbound, uint(i)) for j := 0; j < numPerBitlen; j++ { v := big.NewInt(0).Rand(r, upperbound) - res[i*numPerBitlen+j] = sdk.NewDecFromBigIntWithPrec(v, 18) + res[i*numPerBitlen+j] = constructor(v, precision) } } return res @@ -133,7 +137,7 @@ func TestSqrtRounding(t *testing.T) { // sdk.MustNewDecFromStr("11662930532952632574132537947829685675668532938920838254939577167671385459971.396347723368091000"), } r := rand.New(rand.NewSource(rand.Int63())) - testCases = append(testCases, generateRandomDecForEachBitlen(r, 10)...) + testCases = append(testCases, generateRandomDecForEachBitlenDec(r, 10)...) for _, i := range testCases { sqrt, err := MonotonicSqrt(i) require.NoError(t, err, "smaller: %s", i) @@ -150,7 +154,7 @@ func TestSqrtRounding(t *testing.T) { // benchmarks the SDK square root across bit-lengths, for comparison with the new square root. func BenchmarkSqrt(b *testing.B) { r := rand.New(rand.NewSource(1)) - vectors := generateRandomDecForEachBitlen(r, 1) + vectors := generateRandomDecForEachBitlenDec(r, 1) for i := 0; i < b.N; i++ { for j := 0; j < len(vectors); j++ { a, _ := vectors[j].ApproxSqrt() @@ -162,7 +166,7 @@ func BenchmarkSqrt(b *testing.B) { // benchmarks the new square root across bit-lengths, for comparison with the SDK square root. func BenchmarkMonotonicSqrt(b *testing.B) { r := rand.New(rand.NewSource(1)) - vectors := generateRandomDecForEachBitlen(r, 1) + vectors := generateRandomDecForEachBitlenDec(r, 1) for i := 0; i < b.N; i++ { for j := 0; j < len(vectors); j++ { a, _ := MonotonicSqrt(vectors[j])