Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: monotonic sqrt big dec #6053

Merged
merged 2 commits into from
Aug 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

* [#6071](https://github.com/osmosis-labs/osmosis/pull/6071) reduce number of returns for UpdatePosition and TicksToSqrtPrice functions
* [#5906](https://github.com/osmosis-labs/osmosis/pull/5906) Add `AccountLockedCoins` query in lockup module to stargate whitelist.
* [#6053](https://github.com/osmosis-labs/osmosis/pull/6053) monotonic sqrt with 36 decimals

## v17.0.0

Expand Down
42 changes: 42 additions & 0 deletions osmomath/sqrt.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ import (
)

var smallestDec = sdk.SmallestDec()
var smallestBigDec = SmallestDec()
var tenTo18 = big.NewInt(1e18)
var tenTo36 = big.NewInt(0).Mul(tenTo18, tenTo18)
var oneBigInt = big.NewInt(1)

// Returns square root of d
Expand Down Expand Up @@ -49,6 +51,37 @@ func MonotonicSqrt(d sdk.Dec) (sdk.Dec, error) {
return root, nil
}

func MonotonicSqrtBigDec(d BigDec) (BigDec, error) {
if d.IsNegative() {
return d, errors.New("cannot take square root of negative number")
}

// A decimal value of d, is represented as an integer of value v = 10^18 * d.
// We have an integer square root function, and we'd like to get the square root of d.
// recall integer square root is floor(sqrt(x)), hence its accurate up to 1 integer.
// we want sqrt d accurate to 18 decimal places.
// So first we multiply our current value by 10^18, then we take the integer square root.
// since sqrt(10^18 * v) = 10^9 * sqrt(v) = 10^18 * sqrt(d), we get the answer we want.
//
// We can than interpret sqrt(10^18 * v) as our resulting decimal and return it.
// monotonicity is guaranteed by correctness of integer square root.
dBi := d.BigInt()
r := big.NewInt(0).Mul(dBi, tenTo36)
r.Sqrt(r)
// However this square root r is s.t. r^2 <= d. We want to flip this to be r^2 >= d.
// To do so, we check that if r^2 < d, do r += 1. Then by correctness we will be in the case we want.
// To compare r^2 and d, we can just compare r^2 and 10^18 * v. (recall r = 10^18 * sqrt(d), v = 10^18 * d)
check := big.NewInt(0).Mul(r, r)
// dBi is a copy of d, so we can modify it.
shiftedD := dBi.Mul(dBi, tenTo36)
if check.Cmp(shiftedD) == -1 {
r.Add(r, oneBigInt)
}
root := NewDecFromBigIntWithPrec(r, 36)

return root, nil
}

// MustMonotonicSqrt returns the output of MonotonicSqrt, panicking on error.
func MustMonotonicSqrt(d sdk.Dec) sdk.Dec {
sqrt, err := MonotonicSqrt(d)
Expand All @@ -57,3 +90,12 @@ func MustMonotonicSqrt(d sdk.Dec) sdk.Dec {
}
return sqrt
}

// MustMonotonicSqrt returns the output of MonotonicSqrt, panicking on error.
func MustMonotonicSqrtBigDec(d BigDec) BigDec {
sqrt, err := MonotonicSqrtBigDec(d)
if err != nil {
panic(err)
}
return sqrt
}
149 changes: 149 additions & 0 deletions osmomath/sqrt_big_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
package osmomath

import (
"math/big"
"math/rand"
"testing"

sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func generateRandomDecForEachBitlenBigDec(r *rand.Rand, numPerBitlen int) []BigDec {
return generateRandomDecForEachBitlen[BigDec](r, numPerBitlen, NewDecFromBigIntWithPrec, Precision)
}

func TestSdkApproxSqrtVectors_BigDec(t *testing.T) {
testCases := []struct {
input BigDec
expected BigDec
}{
{OneDec(), OneDec()}, // 1.0 => 1.0
{NewDecWithPrec(25, 2), NewDecWithPrec(5, 1)}, // 0.25 => 0.5
{NewDecWithPrec(4, 2), NewDecWithPrec(2, 1)}, // 0.09 => 0.3
{NewDecFromInt(NewInt(9)), NewDecFromInt(NewInt(3))}, // 9 => 3
{NewDecFromInt(NewInt(2)), MustNewDecFromStr("1.414213562373095048801688724209698079")}, // 2 => 1.414213562373095048801688724209698079
{smallestBigDec, NewDecWithPrec(1, 18)}, // 10^-36 => 10^-18
{smallestBigDec.MulInt64(3), NewDecWithPrec(1732050807568877294, 36)}, // 3*10^-36 => sqrt(3)*10^-18
}

for i, tc := range testCases {
res, err := MonotonicSqrtBigDec(tc.input)
require.NoError(t, err)
require.Equal(t, tc.expected, res, "unexpected result for test case %d, input: %v", i, tc.input)
}
}

func testMonotonicityAroundBigDec(t *testing.T, x BigDec) {
// test that sqrt(x) is monotonic around x
// i.e. sqrt(x-1) <= sqrt(x) <= sqrt(x+1)
sqrtX, err := MonotonicSqrtBigDec(x)
require.NoError(t, err)
sqrtXMinusOne, err := MonotonicSqrtBigDec(x.Sub(smallestBigDec))
require.NoError(t, err)
sqrtXPlusOne, err := MonotonicSqrtBigDec(x.Add(smallestBigDec))
require.NoError(t, err)
assert.True(t, sqrtXMinusOne.LTE(sqrtX), "sqrtXMinusOne: %s, sqrtX: %s", sqrtXMinusOne, sqrtX)
assert.True(t, sqrtX.LTE(sqrtXPlusOne), "sqrtX: %s, sqrtXPlusOne: %s", sqrtX, sqrtXPlusOne)
}

func TestSqrtMonotinicity_BigDec(t *testing.T) {
type testcase struct {
smaller BigDec
bigger BigDec
}
testCases := []testcase{
{MustNewDecFromStr("120.120060020005000000"), MustNewDecFromStr("120.120060020005000001")},
{smallestBigDec, smallestBigDec.MulInt64(2)},
}
// create random test vectors for every bit-length
r := rand.New(rand.NewSource(rand.Int63()))
for i := 0; i < 255+sdk.DecimalPrecisionBits; i++ {
upperbound := big.NewInt(1)
upperbound.Lsh(upperbound, uint(i))
for j := 0; j < 100; j++ {
v := big.NewInt(0).Rand(r, upperbound)
d := NewDecFromBigIntWithPrec(v, 36)
testCases = append(testCases, testcase{d, d.Add(smallestBigDec)})
}
}
for i := 0; i < 1024; i++ {
d := NewDecWithPrec(int64(i), 18)
testCases = append(testCases, testcase{d, d.Add(smallestBigDec)})
}

for _, i := range testCases {
sqrtSmaller, err := MonotonicSqrtBigDec(i.smaller)
require.NoError(t, err, "smaller: %s", i.smaller)
sqrtBigger, err := MonotonicSqrtBigDec(i.bigger)
require.NoError(t, err, "bigger: %s", i.bigger)
assert.True(t, sqrtSmaller.LTE(sqrtBigger), "sqrtSmaller: %s, sqrtBigger: %s", sqrtSmaller, sqrtBigger)

// separately sanity check that sqrt * sqrt >= input
sqrtSmallerSquared := sqrtSmaller.Mul(sqrtSmaller)
assert.True(t, sqrtSmallerSquared.GTE(i.smaller), "sqrt %s, sqrtSmallerSquared: %s, smaller: %s", sqrtSmaller, sqrtSmallerSquared, i.smaller)
}
}

// Test that square(sqrt(x)) = x when x is a perfect square.
// We do this by sampling sqrt(v) from the set of numbers `a.b`, where a in [0, 2^128], b in [0, 10^9].
// and then setting x = sqrt(v)
// this is because this is the set of values whose squares are perfectly representable.
func TestPerfectSquares_BigDec(t *testing.T) {
cases := []BigDec{
NewBigDec(100),
}
r := rand.New(rand.NewSource(rand.Int63()))
tenToMin9 := big.NewInt(1_000_000_000)
for i := 0; i < 128; i++ {
upperbound := big.NewInt(1)
upperbound.Lsh(upperbound, uint(i))
for j := 0; j < 100; j++ {
v := big.NewInt(0).Rand(r, upperbound)
dec := big.NewInt(0).Rand(r, tenToMin9)
d := NewDecFromBigInt(v).Add(NewDecFromBigIntWithPrec(dec, 9))
cases = append(cases, d.MulMut(d))
}
}

for _, i := range cases {
sqrt, err := MonotonicSqrtBigDec(i)
require.NoError(t, err, "smaller: %s", i)
assert.Equal(t, i, sqrt.MulMut(sqrt))
if !i.IsZero() {
testMonotonicityAroundBigDec(t, i)
}
}
}

func TestSqrtRounding_BigDec(t *testing.T) {
testCases := []BigDec{
MustNewDecFromStr("11662930532952632574132537947829685675668532938920838254939577167671385459971.396347723368091000"),
}
r := rand.New(rand.NewSource(rand.Int63()))
testCases = append(testCases, generateRandomDecForEachBitlenBigDec(r, 10)...)
for _, i := range testCases {
sqrt, err := MonotonicSqrtBigDec(i)
require.NoError(t, err, "smaller: %s", i)
// Sanity check that sqrt * sqrt >= input
sqrtSquared := sqrt.Mul(sqrt)
assert.True(t, sqrtSquared.GTE(i), "sqrt %s, sqrtSquared: %s, original: %s", sqrt, sqrtSquared, i)
// (aside) check that (sqrt - 1ulp)^2 <= input
sqrtMin1 := sqrt.Sub(smallestBigDec)
sqrtSquared = sqrtMin1.Mul(sqrtMin1)
assert.True(t, sqrtSquared.LTE(i), "sqrtMin1ULP %s, sqrtSquared: %s, original: %s", sqrt, sqrtSquared, i)
}
}

// benchmarks the new square root across bit-lengths, for comparison with the SDK square root.
func BenchmarkMonotonicSqrt_BigDec(b *testing.B) {
r := rand.New(rand.NewSource(1))
vectors := generateRandomDecForEachBitlenBigDec(r, 1)
for i := 0; i < b.N; i++ {
for j := 0; j < len(vectors); j++ {
a, _ := MonotonicSqrtBigDec(vectors[j])
_ = a
}
}
}
16 changes: 10 additions & 6 deletions osmomath/sqrt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,18 @@ import (
"github.com/stretchr/testify/require"
)

func generateRandomDecForEachBitlen(r *rand.Rand, numPerBitlen int) []sdk.Dec {
res := make([]sdk.Dec, (255+sdk.DecimalPrecisionBits)*numPerBitlen)
func generateRandomDecForEachBitlenDec(r *rand.Rand, numPerBitlen int) []sdk.Dec {
return generateRandomDecForEachBitlen[sdk.Dec](r, numPerBitlen, sdk.NewDecFromBigIntWithPrec, sdk.Precision)
}

func generateRandomDecForEachBitlen[T any](r *rand.Rand, numPerBitlen int, constructor func(*big.Int, int64) T, precision int64) []T {
res := make([]T, (255+sdk.DecimalPrecisionBits)*numPerBitlen)
for i := 0; i < 255+sdk.DecimalPrecisionBits; i++ {
upperbound := big.NewInt(1)
upperbound.Lsh(upperbound, uint(i))
for j := 0; j < numPerBitlen; j++ {
v := big.NewInt(0).Rand(r, upperbound)
res[i*numPerBitlen+j] = sdk.NewDecFromBigIntWithPrec(v, 18)
res[i*numPerBitlen+j] = constructor(v, precision)
}
}
return res
Expand Down Expand Up @@ -133,7 +137,7 @@ func TestSqrtRounding(t *testing.T) {
// sdk.MustNewDecFromStr("11662930532952632574132537947829685675668532938920838254939577167671385459971.396347723368091000"),
}
r := rand.New(rand.NewSource(rand.Int63()))
testCases = append(testCases, generateRandomDecForEachBitlen(r, 10)...)
testCases = append(testCases, generateRandomDecForEachBitlenDec(r, 10)...)
for _, i := range testCases {
sqrt, err := MonotonicSqrt(i)
require.NoError(t, err, "smaller: %s", i)
Expand All @@ -150,7 +154,7 @@ func TestSqrtRounding(t *testing.T) {
// benchmarks the SDK square root across bit-lengths, for comparison with the new square root.
func BenchmarkSqrt(b *testing.B) {
r := rand.New(rand.NewSource(1))
vectors := generateRandomDecForEachBitlen(r, 1)
vectors := generateRandomDecForEachBitlenDec(r, 1)
for i := 0; i < b.N; i++ {
for j := 0; j < len(vectors); j++ {
a, _ := vectors[j].ApproxSqrt()
Expand All @@ -162,7 +166,7 @@ func BenchmarkSqrt(b *testing.B) {
// benchmarks the new square root across bit-lengths, for comparison with the SDK square root.
func BenchmarkMonotonicSqrt(b *testing.B) {
r := rand.New(rand.NewSource(1))
vectors := generateRandomDecForEachBitlen(r, 1)
vectors := generateRandomDecForEachBitlenDec(r, 1)
for i := 0; i < b.N; i++ {
for j := 0; j < len(vectors); j++ {
a, _ := MonotonicSqrt(vectors[j])
Expand Down