Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: monotonicity maintained with high precision #6020

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions osmomath/sqrt.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ import (
)

var smallestDec = sdk.SmallestDec()
var smallestBigDec = SmallestDec()
var tenTo18 = big.NewInt(1e18)
var tenTo36 = big.NewInt(0).Mul(tenTo18, tenTo18)
var oneBigInt = big.NewInt(1)

// Returns square root of d
Expand Down Expand Up @@ -49,6 +51,37 @@ func MonotonicSqrt(d sdk.Dec) (sdk.Dec, error) {
return root, nil
}

func MonotonicSqrtBigDec(d BigDec) (BigDec, error) {
if d.IsNegative() {
return d, errors.New("cannot take square root of negative number")
}

// A decimal value of d, is represented as an integer of value v = 10^18 * d.
// We have an integer square root function, and we'd like to get the square root of d.
// recall integer square root is floor(sqrt(x)), hence its accurate up to 1 integer.
// we want sqrt d accurate to 18 decimal places.
// So first we multiply our current value by 10^18, then we take the integer square root.
// since sqrt(10^18 * v) = 10^9 * sqrt(v) = 10^18 * sqrt(d), we get the answer we want.
//
// We can than interpret sqrt(10^18 * v) as our resulting decimal and return it.
// monotonicity is guaranteed by correctness of integer square root.
dBi := d.BigInt()
r := big.NewInt(0).Mul(dBi, tenTo36)
r.Sqrt(r)
// However this square root r is s.t. r^2 <= d. We want to flip this to be r^2 >= d.
// To do so, we check that if r^2 < d, do r += 1. Then by correctness we will be in the case we want.
// To compare r^2 and d, we can just compare r^2 and 10^18 * v. (recall r = 10^18 * sqrt(d), v = 10^18 * d)
check := big.NewInt(0).Mul(r, r)
// dBi is a copy of d, so we can modify it.
shiftedD := dBi.Mul(dBi, tenTo36)
if check.Cmp(shiftedD) == -1 {
r.Add(r, oneBigInt)
}
root := NewDecFromBigIntWithPrec(r, 36)

return root, nil
}

// MustMonotonicSqrt returns the output of MonotonicSqrt, panicking on error.
func MustMonotonicSqrt(d sdk.Dec) sdk.Dec {
sqrt, err := MonotonicSqrt(d)
Expand All @@ -57,3 +90,12 @@ func MustMonotonicSqrt(d sdk.Dec) sdk.Dec {
}
return sqrt
}

// MustMonotonicSqrt returns the output of MonotonicSqrt, panicking on error.
func MustMonotonicSqrtBigDec(d BigDec) BigDec {
sqrt, err := MonotonicSqrtBigDec(d)
if err != nil {
panic(err)
}
return sqrt
}
160 changes: 160 additions & 0 deletions osmomath/sqrt_big_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
package osmomath

import (
"math/big"
"math/rand"
"testing"

sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func generateRandomDecForEachBitlenBigDec(r *rand.Rand, numPerBitlen int) []BigDec {
res := make([]BigDec, (255+sdk.DecimalPrecisionBits)*numPerBitlen)
for i := 0; i < 255+sdk.DecimalPrecisionBits; i++ {
upperbound := big.NewInt(1)
upperbound.Lsh(upperbound, uint(i))
for j := 0; j < numPerBitlen; j++ {
v := big.NewInt(0).Rand(r, upperbound)
res[i*numPerBitlen+j] = NewDecFromBigIntWithPrec(v, 36)
}
}
return res
}

func TestSdkApproxSqrtVectors_BigDec(t *testing.T) {
testCases := []struct {
input BigDec
expected BigDec
}{
{OneDec(), OneDec()}, // 1.0 => 1.0
{NewDecWithPrec(25, 2), NewDecWithPrec(5, 1)}, // 0.25 => 0.5
{NewDecWithPrec(4, 2), NewDecWithPrec(2, 1)}, // 0.09 => 0.3
{NewDecFromInt(NewInt(9)), NewDecFromInt(NewInt(3))}, // 9 => 3
{NewDecFromInt(NewInt(2)), MustNewDecFromStr("1.414213562373095048801688724209698079")}, // 2 => 1.414213562373095048801688724209698079
{smallestBigDec, NewDecWithPrec(1, 18)}, // 10^-36 => 10^-18
{smallestBigDec.MulInt64(3), NewDecWithPrec(1732050807568877294, 36)}, // 3*10^-36 => sqrt(3)*10^-18
}

for i, tc := range testCases {
res, err := MonotonicSqrtBigDec(tc.input)
require.NoError(t, err)
require.Equal(t, tc.expected, res, "unexpected result for test case %d, input: %v", i, tc.input)
}
}

func testMonotonicityAroundBigDec(t *testing.T, x BigDec) {
// test that sqrt(x) is monotonic around x
// i.e. sqrt(x-1) <= sqrt(x) <= sqrt(x+1)
sqrtX, err := MonotonicSqrtBigDec(x)
require.NoError(t, err)
sqrtXMinusOne, err := MonotonicSqrtBigDec(x.Sub(smallestBigDec))
require.NoError(t, err)
sqrtXPlusOne, err := MonotonicSqrtBigDec(x.Add(smallestBigDec))
require.NoError(t, err)
assert.True(t, sqrtXMinusOne.LTE(sqrtX), "sqrtXMinusOne: %s, sqrtX: %s", sqrtXMinusOne, sqrtX)
assert.True(t, sqrtX.LTE(sqrtXPlusOne), "sqrtX: %s, sqrtXPlusOne: %s", sqrtX, sqrtXPlusOne)
}

func TestSqrtMonotinicity_BigDec(t *testing.T) {
type testcase struct {
smaller BigDec
bigger BigDec
}
testCases := []testcase{
{MustNewDecFromStr("120.120060020005000000"), MustNewDecFromStr("120.120060020005000001")},
{smallestBigDec, smallestBigDec.MulInt64(2)},
}
// create random test vectors for every bit-length
r := rand.New(rand.NewSource(rand.Int63()))
for i := 0; i < 255+sdk.DecimalPrecisionBits; i++ {
upperbound := big.NewInt(1)
upperbound.Lsh(upperbound, uint(i))
for j := 0; j < 100; j++ {
v := big.NewInt(0).Rand(r, upperbound)
d := NewDecFromBigIntWithPrec(v, 36)
testCases = append(testCases, testcase{d, d.Add(smallestBigDec)})
}
}
for i := 0; i < 1024; i++ {
d := NewDecWithPrec(int64(i), 18)
testCases = append(testCases, testcase{d, d.Add(smallestBigDec)})
}

for _, i := range testCases {
sqrtSmaller, err := MonotonicSqrtBigDec(i.smaller)
require.NoError(t, err, "smaller: %s", i.smaller)
sqrtBigger, err := MonotonicSqrtBigDec(i.bigger)
require.NoError(t, err, "bigger: %s", i.bigger)
assert.True(t, sqrtSmaller.LTE(sqrtBigger), "sqrtSmaller: %s, sqrtBigger: %s", sqrtSmaller, sqrtBigger)

// separately sanity check that sqrt * sqrt >= input
sqrtSmallerSquared := sqrtSmaller.Mul(sqrtSmaller)
assert.True(t, sqrtSmallerSquared.GTE(i.smaller), "sqrt %s, sqrtSmallerSquared: %s, smaller: %s", sqrtSmaller, sqrtSmallerSquared, i.smaller)
}
}

// Test that square(sqrt(x)) = x when x is a perfect square.
// We do this by sampling sqrt(v) from the set of numbers `a.b`, where a in [0, 2^128], b in [0, 10^9].
// and then setting x = sqrt(v)
// this is because this is the set of values whose squares are perfectly representable.
func TestPerfectSquares_BigDec(t *testing.T) {
cases := []sdk.Dec{
sdk.NewDec(100),
}
r := rand.New(rand.NewSource(rand.Int63()))
tenToMin9 := big.NewInt(1_000_000_000)
for i := 0; i < 128; i++ {
upperbound := big.NewInt(1)
upperbound.Lsh(upperbound, uint(i))
for j := 0; j < 100; j++ {
v := big.NewInt(0).Rand(r, upperbound)
dec := big.NewInt(0).Rand(r, tenToMin9)
d := sdk.NewDecFromBigInt(v).Add(sdk.NewDecFromBigIntWithPrec(dec, 9))
cases = append(cases, d.MulMut(d))
}
}

for _, i := range cases {
sqrt, err := MonotonicSqrt(i)
require.NoError(t, err, "smaller: %s", i)
assert.Equal(t, i, sqrt.MulMut(sqrt))
if !i.IsZero() {
testMonotonicityAround(t, i)
}
}
}

func TestSqrtRounding_BigDec(t *testing.T) {
testCases := []sdk.Dec{
// TODO: uncomment when SDK supports dec from str with bigger bitlenghths.
// it works if you override the sdk panic locally.
// sdk.MustNewDecFromStr("11662930532952632574132537947829685675668532938920838254939577167671385459971.396347723368091000"),
}
r := rand.New(rand.NewSource(rand.Int63()))
testCases = append(testCases, generateRandomDecForEachBitlen(r, 10)...)
for _, i := range testCases {
sqrt, err := MonotonicSqrt(i)
require.NoError(t, err, "smaller: %s", i)
// Sanity check that sqrt * sqrt >= input
sqrtSquared := sqrt.Mul(sqrt)
assert.True(t, sqrtSquared.GTE(i), "sqrt %s, sqrtSquared: %s, original: %s", sqrt, sqrtSquared, i)
// (aside) check that (sqrt - 1ulp)^2 <= input
sqrtMin1 := sqrt.Sub(smallestDec)
sqrtSquared = sqrtMin1.Mul(sqrtMin1)
assert.True(t, sqrtSquared.LTE(i), "sqrtMin1ULP %s, sqrtSquared: %s, original: %s", sqrt, sqrtSquared, i)
}
}

// benchmarks the new square root across bit-lengths, for comparison with the SDK square root.
func BenchmarkMonotonicSqrt_BigDec(b *testing.B) {
r := rand.New(rand.NewSource(1))
vectors := generateRandomDecForEachBitlen(r, 1)
for i := 0; i < b.N; i++ {
for j := 0; j < len(vectors); j++ {
a, _ := MonotonicSqrt(vectors[j])
_ = a
}
}
}
20 changes: 10 additions & 10 deletions x/concentrated-liquidity/math/precompute.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ var (
// -1 => (0.1, 10^(types.ExponentAtPriceOne - 1), 9 * (types.ExponentAtPriceOne - 1))
type tickExpIndexData struct {
// if price < initialPrice, we are not in this exponent range.
initialPrice sdk.Dec
initialPrice osmomath.BigDec
// if price >= maxPrice, we are not in this exponent range.
maxPrice sdk.Dec
maxPrice osmomath.BigDec
// TODO: Change to normal Dec, if min spot price increases.
// additive increment per tick here.
additiveIncrementPerTick osmomath.BigDec
Expand All @@ -50,27 +50,27 @@ var tickExpCache map[int64]*tickExpIndexData = make(map[int64]*tickExpIndexData)

func buildTickExpCache() {
// build positive indices first
maxPrice := sdkOneDec
maxPrice := osmomathBigOneDec
curExpIndex := int64(0)
for maxPrice.LT(types.MaxSpotPrice) {
for maxPrice.LT(osmomath.BigDecFromSDKDec(types.MaxSpotPrice)) {
tickExpCache[curExpIndex] = &tickExpIndexData{
// price range 10^curExpIndex to 10^(curExpIndex + 1). (10, 100)
initialPrice: sdkTenDec.Power(uint64(curExpIndex)),
maxPrice: sdkTenDec.Power(uint64(curExpIndex + 1)),
initialPrice: osmomathBigTenDec.PowerInteger(uint64(curExpIndex)),
maxPrice: osmomathBigTenDec.PowerInteger(uint64(curExpIndex + 1)),
additiveIncrementPerTick: powTenBigDec(types.ExponentAtPriceOne + curExpIndex),
initialTick: geometricExponentIncrementDistanceInTicks * curExpIndex,
}
maxPrice = tickExpCache[curExpIndex].maxPrice
curExpIndex += 1
}

minPrice := sdkOneDec
minPrice := osmomathBigOneDec
curExpIndex = -1
for minPrice.GT(types.MinSpotPrice) {
for minPrice.GT(osmomath.NewDecWithPrec(1, 30)) {
tickExpCache[curExpIndex] = &tickExpIndexData{
// price range 10^curExpIndex to 10^(curExpIndex + 1). (0.001, 0.01)
initialPrice: powTenBigDec(curExpIndex).SDKDec(),
maxPrice: powTenBigDec(curExpIndex + 1).SDKDec(),
initialPrice: powTenBigDec(curExpIndex),
maxPrice: powTenBigDec(curExpIndex + 1),
additiveIncrementPerTick: powTenBigDec(types.ExponentAtPriceOne + curExpIndex),
initialTick: geometricExponentIncrementDistanceInTicks * curExpIndex,
}
Expand Down
Loading