Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gamm stableswap improvements #3839

Merged
merged 6 commits into from
Jan 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 31 additions & 25 deletions osmomath/binary_search.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,12 @@ func (e ErrTolerance) Compare(expected sdk.Int, actual sdk.Int) int {
}
// Check multiplicative tolerance equations
if !e.MultiplicativeTolerance.IsNil() && !e.MultiplicativeTolerance.IsZero() {
errTerm := diff.Quo(sdk.MinInt(expected.Abs(), actual.Abs()).ToDec())
minValue := sdk.MinInt(expected.Abs(), actual.Abs())
if minValue.IsZero() {
return comparisonSign
}

errTerm := diff.Quo(minValue.ToDec())
if errTerm.GT(e.MultiplicativeTolerance) {
return comparisonSign
}
Expand Down Expand Up @@ -121,7 +126,12 @@ func (e ErrTolerance) CompareBigDec(expected BigDec, actual BigDec) int {
}
// Check multiplicative tolerance equations
if !e.MultiplicativeTolerance.IsNil() && !e.MultiplicativeTolerance.IsZero() {
errTerm := diff.Quo(MinDec(expected.Abs(), actual.Abs()))
minValue := MinDec(expected.Abs(), actual.Abs())
if minValue.IsZero() {
return comparisonSign
}

errTerm := diff.Quo(minValue)
// fmt.Printf("err term %v\n", errTerm)
if errTerm.GT(BigDecFromSDKDec(e.MultiplicativeTolerance)) {
return comparisonSign
Expand All @@ -141,14 +151,19 @@ func BinarySearch(f func(input sdk.Int) (sdk.Int, error),
errTolerance ErrTolerance,
maxIterations int,
) (sdk.Int, error) {
// Setup base case of loop
curEstimate := lowerbound.Add(upperbound).QuoRaw(2)
curOutput, err := f(curEstimate)
if err != nil {
return sdk.Int{}, err
}
var (
curEstimate, curOutput sdk.Int
err error
)

curIteration := 0
for ; curIteration < maxIterations; curIteration += 1 {
curEstimate = lowerbound.Add(upperbound).QuoRaw(2)
curOutput, err = f(curEstimate)
if err != nil {
return sdk.Int{}, err
}

compRes := errTolerance.Compare(targetOutput, curOutput)
if compRes < 0 {
upperbound = curEstimate
Expand All @@ -157,11 +172,6 @@ func BinarySearch(f func(input sdk.Int) (sdk.Int, error),
} else {
return curEstimate, nil
}
curEstimate = lowerbound.Add(upperbound).QuoRaw(2)
curOutput, err = f(curEstimate)
if err != nil {
return sdk.Int{}, err
}
}

return sdk.Int{}, errors.New("hit maximum iterations, did not converge fast enough")
Expand All @@ -182,21 +192,22 @@ type SdkDec[D any] interface {
//
// It binary searches on the input range, until it finds an input y s.t. f(y) meets the err tolerance constraints for how close it is to x.
// If we perform more than maxIterations (or equivalently lowerbound = upperbound), we return an error.
func BinarySearchBigDec(f func(input BigDec) (BigDec, error),
func BinarySearchBigDec(f func(input BigDec) BigDec,
lowerbound BigDec,
upperbound BigDec,
targetOutput BigDec,
errTolerance ErrTolerance,
maxIterations int,
) (BigDec, error) {
// Setup base case of loop
curEstimate := lowerbound.Add(upperbound).Quo(NewBigDec(2))
curOutput, err := f(curEstimate)
if err != nil {
return BigDec{}, err
}
var (
curEstimate, curOutput BigDec
)

curIteration := 0
for ; curIteration < maxIterations; curIteration += 1 {
curEstimate = lowerbound.Add(upperbound).Quo(NewBigDec(2))
curOutput = f(curEstimate)

// fmt.Println("binary search, input, target output, cur output", curEstimate, targetOutput, curOutput)
compRes := errTolerance.CompareBigDec(targetOutput, curOutput)
if compRes < 0 {
Expand All @@ -206,11 +217,6 @@ func BinarySearchBigDec(f func(input BigDec) (BigDec, error),
} else {
return curEstimate, nil
}
curEstimate = lowerbound.Add(upperbound).Quo(NewBigDec(2))
curOutput, err = f(curEstimate)
if err != nil {
return BigDec{}, err
}
}

return BigDec{}, errors.New("hit maximum iterations, did not converge fast enough")
Expand Down
16 changes: 8 additions & 8 deletions osmomath/binary_search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,11 @@ func TestBinarySearch(t *testing.T) {

// straight line function that returns input. Simplest to binary search on,
// binary search directly reveals one bit of the answer in each iteration with this function.
func lineF(a BigDec) (BigDec, error) {
return a, nil
func lineF(a BigDec) BigDec {
return a
}
func cubicF(a BigDec) (BigDec, error) {
return a.PowerInteger(3), nil
func cubicF(a BigDec) BigDec {
return a.PowerInteger(3)
}

var negCubicFConstant BigDec
Expand All @@ -89,11 +89,11 @@ func init() {
negCubicFConstant = NewBigDec(1 << 62).PowerInteger(3).Neg()
}

func negCubicF(a BigDec) (BigDec, error) {
return a.PowerInteger(3).Add(negCubicFConstant), nil
func negCubicF(a BigDec) BigDec {
return a.PowerInteger(3).Add(negCubicFConstant)
}

type searchFn func(BigDec) (BigDec, error)
type searchFn func(BigDec) BigDec

type binarySearchTestCase struct {
f searchFn
Expand Down Expand Up @@ -163,7 +163,7 @@ func TestIterationDepthRandValue(t *testing.T) {
errTolerance ErrTolerance, maxNumIters int, errToleranceName string) {
targetF := fnMap[fnName]
targetX := int64(rand.Intn(int(upperbound-lowerbound-1))) + lowerbound + 1
target, _ := targetF(NewBigDec(targetX))
target := targetF(NewBigDec(targetX))
testCase := binarySearchTestCase{
f: lineF,
lowerbound: NewBigDec(lowerbound), upperbound: NewBigDec(upperbound),
Expand Down
4 changes: 2 additions & 2 deletions x/gamm/pool-models/stableswap/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def CalcOutAmountGivenExactAmountIn(pool, in_coin, out_denom, swap_fee):
in_reserve, out_reserve, rem_reserves = pool.ScaledLiquidity(in_coin, out_denom, RoundingMode.RoundDown)
in_amt_scaled = pool.ScaleToken(in_coin, RoundingMode.RoundDown)
amm_in = in_amt_scaled * (1 - swap_fee)
out_amt_scaled = solve_y(in_reserve, out_reserve, remReserves, in_amt_scaled)
out_amt_scaled = solve_y(in_reserve, out_reserve, remReserves, amm_in)
out_amt = pool.DescaleToken(out_amt_scaled, out_denom)
return out_amt
```
Expand All @@ -304,7 +304,7 @@ We do this by having `token_in = amm_in / (1 - swapfee)`.
```python
def CalcInAmountGivenExactAmountOut(pool, out_coin, in_denom, swap_fee):
in_reserve, out_reserve, rem_reserves = pool.ScaledLiquidity(in_denom, out_coin, RoundingMode.RoundDown)
out_amt_scaled = pool.ScaleToken(in_coin, RoundingMode.RoundUp)
out_amt_scaled = pool.ScaleToken(out_coin, RoundingMode.RoundUp)

amm_in_scaled = solve_y(out_reserve, in_reserve, remReserves, -out_amt_scaled)
swap_in_scaled = ceil(amm_in_scaled / (1 - swapfee))
Expand Down
35 changes: 13 additions & 22 deletions x/gamm/pool-models/stableswap/amm.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,6 @@ func cfmmConstant(xReserve, yReserve osmomath.BigDec) osmomath.BigDec {
// We use this version for calculations since the u
// term in the full CFMM is constant.
func cfmmConstantMultiNoV(xReserve, yReserve, wSumSquares osmomath.BigDec) osmomath.BigDec {
if !xReserve.IsPositive() || !yReserve.IsPositive() || wSumSquares.IsNegative() {
panic("invalid input: reserves must be positive")
}
AlpinYukseloglu marked this conversation as resolved.
Show resolved Hide resolved

return cfmmConstantMultiNoVY(xReserve, yReserve, wSumSquares).Mul(yReserve)
}

Expand Down Expand Up @@ -263,19 +259,19 @@ func targetKCalculator(x0, y0, w, yf osmomath.BigDec) osmomath.BigDec {

// $$k_{iter}(x_f) = -x_{out}^3 + 3 x_0 x_{out}^2 - (y_f^2 + w + 3x_0^2)x_{out}$$
// where x_out = x_0 - x_f
func iterKCalculator(x0, w, yf osmomath.BigDec) func(osmomath.BigDec) (osmomath.BigDec, error) {
func iterKCalculator(x0, w, yf osmomath.BigDec) func(osmomath.BigDec) osmomath.BigDec {
// compute coefficients first
cubicCoeff := osmomath.OneDec().Neg()
quadraticCoeff := x0.MulInt64(3)
linearCoeff := quadraticCoeff.Mul(x0).Add(w).Add(yf.Mul(yf)).Neg()
return func(xf osmomath.BigDec) (osmomath.BigDec, error) {
return func(xf osmomath.BigDec) osmomath.BigDec {
xOut := x0.Sub(xf)
// horners method
// ax^3 + bx^2 + cx = x(c + x(b + ax))
res := cubicCoeff.Mul(xOut)
res = res.Add(quadraticCoeff).Mul(xOut)
res = res.Add(linearCoeff).Mul(xOut)
return res, nil
return res
}
}

Expand Down Expand Up @@ -487,34 +483,29 @@ func (p *Pool) joinPoolSharesInternal(ctx sdk.Context, tokensIn sdk.Coins, swapF
if !tokensIn.DenomsSubsetOf(p.GetTotalPoolLiquidity(ctx)) {
return sdk.ZeroInt(), sdk.NewCoins(), errors.New("attempted joining pool with assets that do not exist in pool")
}

if len(tokensIn) == 1 && tokensIn[0].Amount.GT(sdk.OneInt()) {
numShares, err = p.calcSingleAssetJoinShares(tokensIn[0], swapFee)
if err != nil {
return sdk.ZeroInt(), sdk.NewCoins(), err
}

tokensJoined = tokensIn

p.updatePoolForJoin(tokensJoined, numShares)

if err = validatePoolLiquidity(p.PoolLiquidity, p.ScalingFactors); err != nil {
return sdk.ZeroInt(), sdk.NewCoins(), err
}

return numShares, tokensJoined, nil
} else if len(tokensIn) != p.NumAssets() {
return sdk.ZeroInt(), sdk.NewCoins(), errors.New(
"stableswap pool only supports LP'ing with one asset, or all assets in pool")
}
} else {
// Add all exact coins we can (no swap). ctx arg doesn't matter for Stableswap
var remCoins sdk.Coins
numShares, remCoins, err = cfmm_common.MaximalExactRatioJoin(p, sdk.Context{}, tokensIn)
if err != nil {
return sdk.ZeroInt(), sdk.NewCoins(), err
}

// Add all exact coins we can (no swap). ctx arg doesn't matter for Stableswap
numShares, remCoins, err := cfmm_common.MaximalExactRatioJoin(p, sdk.Context{}, tokensIn)
if err != nil {
return sdk.ZeroInt(), sdk.NewCoins(), err
tokensJoined = tokensIn.Sub(remCoins)
}
p.updatePoolForJoin(tokensIn.Sub(remCoins), numShares)

tokensJoined = tokensIn.Sub(remCoins)
p.updatePoolForJoin(tokensJoined, numShares)

if err = validatePoolLiquidity(p.PoolLiquidity, p.ScalingFactors); err != nil {
return sdk.ZeroInt(), sdk.NewCoins(), err
Expand Down
Loading