Skip to content

Commit

Permalink
Gamm stableswap improvements (#3839)
Browse files Browse the repository at this point in the history
* Stableswap: README pseudocode fixes

* Stableswap: Binary search code improvement

* Stableswap: minor code improvements

* Stableswap: minor code improvements 2

* Binary search: Check potential division by zero

(cherry picked from commit 2ac5d35)

# Conflicts:
#	x/gamm/pool-models/stableswap/util_test.go
  • Loading branch information
ljah8 authored and mergify[bot] committed Jan 6, 2023
1 parent 961738a commit 1c14a18
Show file tree
Hide file tree
Showing 8 changed files with 147 additions and 134 deletions.
56 changes: 31 additions & 25 deletions osmomath/binary_search.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,12 @@ func (e ErrTolerance) Compare(expected sdk.Int, actual sdk.Int) int {
}
// Check multiplicative tolerance equations
if !e.MultiplicativeTolerance.IsNil() && !e.MultiplicativeTolerance.IsZero() {
errTerm := diff.Quo(sdk.MinInt(expected.Abs(), actual.Abs()).ToDec())
minValue := sdk.MinInt(expected.Abs(), actual.Abs())
if minValue.IsZero() {
return comparisonSign
}

errTerm := diff.Quo(minValue.ToDec())
if errTerm.GT(e.MultiplicativeTolerance) {
return comparisonSign
}
Expand Down Expand Up @@ -121,7 +126,12 @@ func (e ErrTolerance) CompareBigDec(expected BigDec, actual BigDec) int {
}
// Check multiplicative tolerance equations
if !e.MultiplicativeTolerance.IsNil() && !e.MultiplicativeTolerance.IsZero() {
errTerm := diff.Quo(MinDec(expected.Abs(), actual.Abs()))
minValue := MinDec(expected.Abs(), actual.Abs())
if minValue.IsZero() {
return comparisonSign
}

errTerm := diff.Quo(minValue)
// fmt.Printf("err term %v\n", errTerm)
if errTerm.GT(BigDecFromSDKDec(e.MultiplicativeTolerance)) {
return comparisonSign
Expand All @@ -141,14 +151,19 @@ func BinarySearch(f func(input sdk.Int) (sdk.Int, error),
errTolerance ErrTolerance,
maxIterations int,
) (sdk.Int, error) {
// Setup base case of loop
curEstimate := lowerbound.Add(upperbound).QuoRaw(2)
curOutput, err := f(curEstimate)
if err != nil {
return sdk.Int{}, err
}
var (
curEstimate, curOutput sdk.Int
err error
)

curIteration := 0
for ; curIteration < maxIterations; curIteration += 1 {
curEstimate = lowerbound.Add(upperbound).QuoRaw(2)
curOutput, err = f(curEstimate)
if err != nil {
return sdk.Int{}, err
}

compRes := errTolerance.Compare(targetOutput, curOutput)
if compRes < 0 {
upperbound = curEstimate
Expand All @@ -157,11 +172,6 @@ func BinarySearch(f func(input sdk.Int) (sdk.Int, error),
} else {
return curEstimate, nil
}
curEstimate = lowerbound.Add(upperbound).QuoRaw(2)
curOutput, err = f(curEstimate)
if err != nil {
return sdk.Int{}, err
}
}

return sdk.Int{}, errors.New("hit maximum iterations, did not converge fast enough")
Expand All @@ -182,21 +192,22 @@ type SdkDec[D any] interface {
//
// It binary searches on the input range, until it finds an input y s.t. f(y) meets the err tolerance constraints for how close it is to x.
// If we perform more than maxIterations (or equivalently lowerbound = upperbound), we return an error.
func BinarySearchBigDec(f func(input BigDec) (BigDec, error),
func BinarySearchBigDec(f func(input BigDec) BigDec,
lowerbound BigDec,
upperbound BigDec,
targetOutput BigDec,
errTolerance ErrTolerance,
maxIterations int,
) (BigDec, error) {
// Setup base case of loop
curEstimate := lowerbound.Add(upperbound).Quo(NewBigDec(2))
curOutput, err := f(curEstimate)
if err != nil {
return BigDec{}, err
}
var (
curEstimate, curOutput BigDec
)

curIteration := 0
for ; curIteration < maxIterations; curIteration += 1 {
curEstimate = lowerbound.Add(upperbound).Quo(NewBigDec(2))
curOutput = f(curEstimate)

// fmt.Println("binary search, input, target output, cur output", curEstimate, targetOutput, curOutput)
compRes := errTolerance.CompareBigDec(targetOutput, curOutput)
if compRes < 0 {
Expand All @@ -206,11 +217,6 @@ func BinarySearchBigDec(f func(input BigDec) (BigDec, error),
} else {
return curEstimate, nil
}
curEstimate = lowerbound.Add(upperbound).Quo(NewBigDec(2))
curOutput, err = f(curEstimate)
if err != nil {
return BigDec{}, err
}
}

return BigDec{}, errors.New("hit maximum iterations, did not converge fast enough")
Expand Down
16 changes: 8 additions & 8 deletions osmomath/binary_search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,11 @@ func TestBinarySearch(t *testing.T) {

// straight line function that returns input. Simplest to binary search on,
// binary search directly reveals one bit of the answer in each iteration with this function.
func lineF(a BigDec) (BigDec, error) {
return a, nil
func lineF(a BigDec) BigDec {
return a
}
func cubicF(a BigDec) (BigDec, error) {
return a.PowerInteger(3), nil
func cubicF(a BigDec) BigDec {
return a.PowerInteger(3)
}

var negCubicFConstant BigDec
Expand All @@ -89,11 +89,11 @@ func init() {
negCubicFConstant = NewBigDec(1 << 62).PowerInteger(3).Neg()
}

func negCubicF(a BigDec) (BigDec, error) {
return a.PowerInteger(3).Add(negCubicFConstant), nil
func negCubicF(a BigDec) BigDec {
return a.PowerInteger(3).Add(negCubicFConstant)
}

type searchFn func(BigDec) (BigDec, error)
type searchFn func(BigDec) BigDec

type binarySearchTestCase struct {
f searchFn
Expand Down Expand Up @@ -163,7 +163,7 @@ func TestIterationDepthRandValue(t *testing.T) {
errTolerance ErrTolerance, maxNumIters int, errToleranceName string) {
targetF := fnMap[fnName]
targetX := int64(rand.Intn(int(upperbound-lowerbound-1))) + lowerbound + 1
target, _ := targetF(NewBigDec(targetX))
target := targetF(NewBigDec(targetX))
testCase := binarySearchTestCase{
f: lineF,
lowerbound: NewBigDec(lowerbound), upperbound: NewBigDec(upperbound),
Expand Down
4 changes: 2 additions & 2 deletions x/gamm/pool-models/stableswap/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def CalcOutAmountGivenExactAmountIn(pool, in_coin, out_denom, swap_fee):
in_reserve, out_reserve, rem_reserves = pool.ScaledLiquidity(in_coin, out_denom, RoundingMode.RoundDown)
in_amt_scaled = pool.ScaleToken(in_coin, RoundingMode.RoundDown)
amm_in = in_amt_scaled * (1 - swap_fee)
out_amt_scaled = solve_y(in_reserve, out_reserve, remReserves, in_amt_scaled)
out_amt_scaled = solve_y(in_reserve, out_reserve, remReserves, amm_in)
out_amt = pool.DescaleToken(out_amt_scaled, out_denom)
return out_amt
```
Expand All @@ -308,7 +308,7 @@ We do this by having `token_in = amm_in / (1 - swapfee)`.
```python
def CalcInAmountGivenExactAmountOut(pool, out_coin, in_denom, swap_fee):
in_reserve, out_reserve, rem_reserves = pool.ScaledLiquidity(in_denom, out_coin, RoundingMode.RoundDown)
out_amt_scaled = pool.ScaleToken(in_coin, RoundingMode.RoundUp)
out_amt_scaled = pool.ScaleToken(out_coin, RoundingMode.RoundUp)

amm_in_scaled = solve_y(out_reserve, in_reserve, remReserves, -out_amt_scaled)
swap_in_scaled = ceil(amm_in_scaled / (1 - swapfee))
Expand Down
35 changes: 13 additions & 22 deletions x/gamm/pool-models/stableswap/amm.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,6 @@ func cfmmConstant(xReserve, yReserve osmomath.BigDec) osmomath.BigDec {
// We use this version for calculations since the u
// term in the full CFMM is constant.
func cfmmConstantMultiNoV(xReserve, yReserve, wSumSquares osmomath.BigDec) osmomath.BigDec {
if !xReserve.IsPositive() || !yReserve.IsPositive() || wSumSquares.IsNegative() {
panic("invalid input: reserves must be positive")
}

return cfmmConstantMultiNoVY(xReserve, yReserve, wSumSquares).Mul(yReserve)
}

Expand Down Expand Up @@ -263,19 +259,19 @@ func targetKCalculator(x0, y0, w, yf osmomath.BigDec) osmomath.BigDec {

// $$k_{iter}(x_f) = -x_{out}^3 + 3 x_0 x_{out}^2 - (y_f^2 + w + 3x_0^2)x_{out}$$
// where x_out = x_0 - x_f
func iterKCalculator(x0, w, yf osmomath.BigDec) func(osmomath.BigDec) (osmomath.BigDec, error) {
func iterKCalculator(x0, w, yf osmomath.BigDec) func(osmomath.BigDec) osmomath.BigDec {
// compute coefficients first
cubicCoeff := osmomath.OneDec().Neg()
quadraticCoeff := x0.MulInt64(3)
linearCoeff := quadraticCoeff.Mul(x0).Add(w).Add(yf.Mul(yf)).Neg()
return func(xf osmomath.BigDec) (osmomath.BigDec, error) {
return func(xf osmomath.BigDec) osmomath.BigDec {
xOut := x0.Sub(xf)
// horners method
// ax^3 + bx^2 + cx = x(c + x(b + ax))
res := cubicCoeff.Mul(xOut)
res = res.Add(quadraticCoeff).Mul(xOut)
res = res.Add(linearCoeff).Mul(xOut)
return res, nil
return res
}
}

Expand Down Expand Up @@ -488,34 +484,29 @@ func (p *Pool) joinPoolSharesInternal(ctx sdk.Context, tokensIn sdk.Coins, swapF
if !tokensIn.DenomsSubsetOf(p.GetTotalPoolLiquidity(ctx)) {
return sdk.ZeroInt(), sdk.NewCoins(), errors.New("attempted joining pool with assets that do not exist in pool")
}

if len(tokensIn) == 1 && tokensIn[0].Amount.GT(sdk.OneInt()) {
numShares, err = p.calcSingleAssetJoinShares(tokensIn[0], swapFee)
if err != nil {
return sdk.ZeroInt(), sdk.NewCoins(), err
}

tokensJoined = tokensIn

p.updatePoolForJoin(tokensJoined, numShares)

if err = validatePoolLiquidity(p.PoolLiquidity, p.ScalingFactors); err != nil {
return sdk.ZeroInt(), sdk.NewCoins(), err
}

return numShares, tokensJoined, nil
} else if len(tokensIn) != p.NumAssets() {
return sdk.ZeroInt(), sdk.NewCoins(), errors.New(
"stableswap pool only supports LP'ing with one asset, or all assets in pool")
}
} else {
// Add all exact coins we can (no swap). ctx arg doesn't matter for Stableswap
var remCoins sdk.Coins
numShares, remCoins, err = cfmm_common.MaximalExactRatioJoin(p, sdk.Context{}, tokensIn)
if err != nil {
return sdk.ZeroInt(), sdk.NewCoins(), err
}

// Add all exact coins we can (no swap). ctx arg doesn't matter for Stableswap
numShares, remCoins, err := cfmm_common.MaximalExactRatioJoin(p, sdk.Context{}, tokensIn)
if err != nil {
return sdk.ZeroInt(), sdk.NewCoins(), err
tokensJoined = tokensIn.Sub(remCoins)
}
p.updatePoolForJoin(tokensIn.Sub(remCoins), numShares)

tokensJoined = tokensIn.Sub(remCoins)
p.updatePoolForJoin(tokensJoined, numShares)

if err = validatePoolLiquidity(p.PoolLiquidity, p.ScalingFactors); err != nil {
return sdk.ZeroInt(), sdk.NewCoins(), err
Expand Down
Loading

0 comments on commit 1c14a18

Please sign in to comment.