Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Monotonic square root #5543

Merged
merged 10 commits into from
Jun 17, 2023
50 changes: 50 additions & 0 deletions osmomath/sqrt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package osmomath

import (
"errors"
"math/big"

sdk "github.com/cosmos/cosmos-sdk/types"
)

var smallestDec = sdk.SmallestDec()
var tenTo18 = big.NewInt(1e18)
var oneBigInt = big.NewInt(1)

// Returns square root of d
// returns an error iff one of the following conditions is met:
// - d is negative
// - d is too small to have a representable square root.
// This function guarantees:
// the returned root r, will be such that r^2 >= d
// This function is monotonic, i.e. if d1 >= d2, then sqrt(d1) >= sqrt(d2)
func MonotonicSqrt(d sdk.Dec) (sdk.Dec, error) {
if d.IsNegative() {
return d, errors.New("cannot take square root of negative number")
}

// A decimal value of d, is represented as an integer of value v = 10^18 * d.
// We have an integer square root function, and we'd like to get the square root of d.
// recall integer square root is floor(sqrt(x)), hence its accuract up to 1 integer.
ValarDragon marked this conversation as resolved.
Show resolved Hide resolved
// we want sqrt d accurate to 18 decimal places.
// So first we multiply our current value by 10^18, then we take the integer square root.
// since sqrt(10^18 * v) = 10^9 * sqrt(v) = 10^18 * sqrt(d), we get the answer we want.
//
// We can than interpret sqrt(10^18 * v) as our resulting decimal and return it.
// monotonicity is guaranteed by correctness of integer square root.
dBi := d.BigInt()
r := big.NewInt(0).Mul(dBi, tenTo18)
r.Sqrt(r)
// However this square root r is s.t. r^2 <= d. We want to flip this to be r^2 >= d.
// To do so, we check that if r^2 < d, do r += 1. Then by correctness we will be in the case we want.
// To compare r^2 and d, we can just compare r^2 and 10^18 * v. (recall r = 10^18 * sqrt(d), v = 10^18 * d)
check := big.NewInt(0).Mul(r, r)
// dBi is a copy of d, so we can modify it.
shiftedD := dBi.Mul(dBi, tenTo18)
if check.Cmp(shiftedD) == -1 {
r.Add(r, oneBigInt)
}
root := sdk.NewDecFromBigIntWithPrec(r, 18)

return root, nil
}
168 changes: 168 additions & 0 deletions osmomath/sqrt_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
package osmomath

import (
"math/big"
"math/rand"
"testing"

sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func generateRandomDecForEachBitlen(r *rand.Rand, numPerBitlen int) []sdk.Dec {
res := make([]sdk.Dec, (255+sdk.DecimalPrecisionBits)*numPerBitlen)
for i := 0; i < 255+sdk.DecimalPrecisionBits; i++ {
upperbound := big.NewInt(1)
upperbound.Lsh(upperbound, uint(i))
for j := 0; j < numPerBitlen; j++ {
v := big.NewInt(0).Rand(r, upperbound)
res[i*numPerBitlen+j] = sdk.NewDecFromBigIntWithPrec(v, 18)
}
}
return res
}

func TestSdkApproxSqrtVectors(t *testing.T) {
testCases := []struct {
input sdk.Dec
expected sdk.Dec
}{
{sdk.OneDec(), sdk.OneDec()}, // 1.0 => 1.0
{sdk.NewDecWithPrec(25, 2), sdk.NewDecWithPrec(5, 1)}, // 0.25 => 0.5
{sdk.NewDecWithPrec(4, 2), sdk.NewDecWithPrec(2, 1)}, // 0.09 => 0.3
{sdk.NewDecFromInt(sdk.NewInt(9)), sdk.NewDecFromInt(sdk.NewInt(3))}, // 9 => 3
{sdk.NewDecFromInt(sdk.NewInt(2)), sdk.NewDecWithPrec(1414213562373095049, 18)}, // 2 => 1.414213562373095049
{smallestDec, sdk.NewDecWithPrec(1, 9)}, // 10^-18 => 10^-9
{smallestDec.MulInt64(3), sdk.NewDecWithPrec(1732050808, 18)}, // 3*10^-18 => sqrt(3)*10^-9
}

for i, tc := range testCases {
res, err := MonotonicSqrt(tc.input)
require.NoError(t, err)
require.Equal(t, tc.expected, res, "unexpected result for test case %d, input: %v", i, tc.input)
}
}

func testMonotonicityAround(t *testing.T, x sdk.Dec) {
// test that sqrt(x) is monotonic around x
// i.e. sqrt(x-1) <= sqrt(x) <= sqrt(x+1)
sqrtX, err := MonotonicSqrt(x)
require.NoError(t, err)
sqrtXMinusOne, err := MonotonicSqrt(x.Sub(smallestDec))
require.NoError(t, err)
sqrtXPlusOne, err := MonotonicSqrt(x.Add(smallestDec))
require.NoError(t, err)
assert.True(t, sqrtXMinusOne.LTE(sqrtX), "sqrtXMinusOne: %s, sqrtX: %s", sqrtXMinusOne, sqrtX)
assert.True(t, sqrtX.LTE(sqrtXPlusOne), "sqrtX: %s, sqrtXPlusOne: %s", sqrtX, sqrtXPlusOne)
}

func TestSqrtMonotinicity(t *testing.T) {
type testcase struct {
smaller sdk.Dec
bigger sdk.Dec
}
testCases := []testcase{
{sdk.MustNewDecFromStr("120.120060020005000000"), sdk.MustNewDecFromStr("120.120060020005000001")},
{smallestDec, smallestDec.MulInt64(2)},
}
// create random test vectors for every bit-length
r := rand.New(rand.NewSource(rand.Int63()))
for i := 0; i < 255+sdk.DecimalPrecisionBits; i++ {
upperbound := big.NewInt(1)
upperbound.Lsh(upperbound, uint(i))
for j := 0; j < 100; j++ {
v := big.NewInt(0).Rand(r, upperbound)
d := sdk.NewDecFromBigIntWithPrec(v, 18)
testCases = append(testCases, testcase{d, d.Add(smallestDec)})
}
}
for i := 0; i < 1024; i++ {
d := sdk.NewDecWithPrec(int64(i), 18)
testCases = append(testCases, testcase{d, d.Add(smallestDec)})
}

for _, i := range testCases {
sqrtSmaller, err := MonotonicSqrt(i.smaller)
require.NoError(t, err, "smaller: %s", i.smaller)
sqrtBigger, err := MonotonicSqrt(i.bigger)
require.NoError(t, err, "bigger: %s", i.bigger)
assert.True(t, sqrtSmaller.LTE(sqrtBigger), "sqrtSmaller: %s, sqrtBigger: %s", sqrtSmaller, sqrtBigger)

// separately sanity check that sqrt * sqrt >= input
sqrtSmallerSquared := sqrtSmaller.Mul(sqrtSmaller)
assert.True(t, sqrtSmallerSquared.GTE(i.smaller), "sqrt %s, sqrtSmallerSquared: %s, smaller: %s", sqrtSmaller, sqrtSmallerSquared, i.smaller)
}
}

// Test that square(sqrt(x)) = x when x is a perfect square.
// We do this by sampling sqrt(v) from the set of numbers `a.b`, where a in [0, 2^128], b in [0, 10^9].
// and then setting x = sqrt(v)
// this is because this is the set of values whose squares are perfectly representable.
func TestPerfectSquares(t *testing.T) {
cases := []sdk.Dec{
sdk.NewDec(100),
}
r := rand.New(rand.NewSource(rand.Int63()))
tenToMin9 := big.NewInt(1_000_000_000)
for i := 0; i < 128; i++ {
upperbound := big.NewInt(1)
upperbound.Lsh(upperbound, uint(i))
for j := 0; j < 100; j++ {
v := big.NewInt(0).Rand(r, upperbound)
dec := big.NewInt(0).Rand(r, tenToMin9)
d := sdk.NewDecFromBigInt(v).Add(sdk.NewDecFromBigIntWithPrec(dec, 9))
cases = append(cases, d.MulMut(d))
}
}

for _, i := range cases {
sqrt, err := MonotonicSqrt(i)
require.NoError(t, err, "smaller: %s", i)
assert.Equal(t, i, sqrt.MulMut(sqrt))
if !i.IsZero() {
testMonotonicityAround(t, i)
}
}
}

func TestSqrtRounding(t *testing.T) {
testCases := []sdk.Dec{
sdk.MustNewDecFromStr("11662930532952632574132537947829685675668532938920838254939577167671385459971.396347723368091000"),
AlpinYukseloglu marked this conversation as resolved.
Show resolved Hide resolved
}
r := rand.New(rand.NewSource(rand.Int63()))
testCases = append(testCases, generateRandomDecForEachBitlen(r, 10)...)
for _, i := range testCases {
sqrt, err := MonotonicSqrt(i)
require.NoError(t, err, "smaller: %s", i)
// Sanity check that sqrt * sqrt >= input
sqrtSquared := sqrt.Mul(sqrt)
assert.True(t, sqrtSquared.GTE(i), "sqrt %s, sqrtSquared: %s, original: %s", sqrt, sqrtSquared, i)
// (aside) check that (sqrt - 1ulp)^2 <= input
sqrtMin1 := sqrt.Sub(smallestDec)
sqrtSquared = sqrtMin1.Mul(sqrtMin1)
assert.True(t, sqrtSquared.LTE(i), "sqrtMin1ULP %s, sqrtSquared: %s, original: %s", sqrt, sqrtSquared, i)
}
}

func BenchmarkSqrt(b *testing.B) {
r := rand.New(rand.NewSource(1))
vectors := generateRandomDecForEachBitlen(r, 1)
for i := 0; i < b.N; i++ {
for j := 0; j < len(vectors); j++ {
a, _ := vectors[j].ApproxSqrt()
_ = a
}
}
}

func BenchmarkMonotonicSqrt(b *testing.B) {
r := rand.New(rand.NewSource(1))
vectors := generateRandomDecForEachBitlen(r, 1)
for i := 0; i < b.N; i++ {
for j := 0; j < len(vectors); j++ {
a, _ := MonotonicSqrt(vectors[j])
_ = a
}
}
}
ValarDragon marked this conversation as resolved.
Show resolved Hide resolved