Skip to content

Commit

Permalink
refactor: remove cache context from swap in given out (#5198)
Browse files Browse the repository at this point in the history
* remove cache context from in given out

* Update x/concentrated-liquidity/swaps_test.go

Co-authored-by: Roman <[email protected]>

* Update x/concentrated-liquidity/swaps_test.go

Co-authored-by: Roman <[email protected]>

---------

Co-authored-by: Roman <[email protected]>
  • Loading branch information
2 people authored and pysel committed Jun 6, 2023
1 parent 46b5160 commit 201ffd4
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 48 deletions.
25 changes: 20 additions & 5 deletions x/concentrated-liquidity/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,6 @@ func (k Keeper) SendCoinsBetweenPoolAndUser(ctx sdk.Context, denom0, denom1 stri
return k.sendCoinsBetweenPoolAndUser(ctx, denom0, denom1, amount0, amount1, sender, receiver)
}

func (k Keeper) CalcInAmtGivenOutInternal(ctx sdk.Context, desiredTokenOut sdk.Coin, tokenInDenom string, swapFee sdk.Dec, priceLimit sdk.Dec, poolId uint64) (writeCtx func(), tokenIn, tokenOut sdk.Coin, updatedTick int64, updatedLiquidity, updatedSqrtPrice sdk.Dec, err error) {
return k.calcInAmtGivenOut(ctx, desiredTokenOut, tokenInDenom, swapFee, priceLimit, poolId)
}

func (k Keeper) SwapOutAmtGivenIn(
ctx sdk.Context,
sender sdk.AccAddress,
Expand All @@ -78,10 +74,29 @@ func (k Keeper) ComputeOutAmtGivenIn(
return k.computeOutAmtGivenIn(ctx, poolId, tokenInMin, tokenOutDenom, swapFee, priceLimit)
}

func (k *Keeper) SwapInAmtGivenOut(ctx sdk.Context, sender sdk.AccAddress, pool types.ConcentratedPoolExtension, desiredTokenOut sdk.Coin, tokenInDenom string, swapFee sdk.Dec, priceLimit sdk.Dec) (calcTokenIn, calcTokenOut sdk.Coin, currentTick int64, liquidity, sqrtPrice sdk.Dec, err error) {
func (k Keeper) SwapInAmtGivenOut(
ctx sdk.Context,
sender sdk.AccAddress,
pool types.ConcentratedPoolExtension,
desiredTokenOut sdk.Coin,
tokenInDenom string,
swapFee sdk.Dec,
priceLimit sdk.Dec) (calcTokenIn, calcTokenOut sdk.Coin, currentTick int64, liquidity, sqrtPrice sdk.Dec, err error) {
return k.swapInAmtGivenOut(ctx, sender, pool, desiredTokenOut, tokenInDenom, swapFee, priceLimit)
}

func (k Keeper) ComputeInAmtGivenOut(
ctx sdk.Context,
desiredTokenOut sdk.Coin,
tokenInDenom string,
swapFee sdk.Dec,
priceLimit sdk.Dec,
poolId uint64,

) (calcTokenIn, calcTokenOut sdk.Coin, currentTick int64, liquidity, sqrtPrice sdk.Dec, err error) {
return k.computeInAmtGivenOut(ctx, desiredTokenOut, tokenInDenom, swapFee, priceLimit, poolId)
}

func (k Keeper) InitOrUpdateTick(ctx sdk.Context, poolId uint64, currentTick int64, tickIndex int64, liquidityIn sdk.Dec, upper bool) (err error) {
return k.initOrUpdateTick(ctx, poolId, currentTick, tickIndex, liquidityIn, upper)
}
Expand Down
44 changes: 19 additions & 25 deletions x/concentrated-liquidity/swaps.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ func (k *Keeper) swapInAmtGivenOut(
swapFee sdk.Dec,
priceLimit sdk.Dec,
) (calcTokenIn, calcTokenOut sdk.Coin, currentTick int64, liquidity, sqrtPrice sdk.Dec, err error) {
writeCtx, tokenIn, tokenOut, newCurrentTick, newLiquidity, newSqrtPrice, err := k.calcInAmtGivenOut(ctx, desiredTokenOut, tokenInDenom, swapFee, priceLimit, pool.GetId())
tokenIn, tokenOut, newCurrentTick, newLiquidity, newSqrtPrice, err := k.computeInAmtGivenOut(ctx, desiredTokenOut, tokenInDenom, swapFee, priceLimit, pool.GetId())
if err != nil {
return sdk.Coin{}, sdk.Coin{}, 0, sdk.Dec{}, sdk.Dec{}, err
}
Expand All @@ -197,11 +197,6 @@ func (k *Keeper) swapInAmtGivenOut(
return sdk.Coin{}, sdk.Coin{}, 0, sdk.Dec{}, sdk.Dec{}, types.InvalidAmountCalculatedError{Amount: tokenIn.Amount}
}

// N.B. making the call below ensures that any mutations done inside calcInAmtGivenOut
// are written to store. If this call were skipped, calcInAmtGivenOut would be non-mutative.
// An example of a store write done in calcInAmtGivenOut is updating ticks as we cross them.
writeCtx()

// Settles balances between the tx sender and the pool to match the swap that was executed earlier.
// Also emits swap event and updates related liquidity metrics
if err := k.updatePoolForSwap(ctx, pool, sender, tokenIn, tokenOut, newCurrentTick, newLiquidity, newSqrtPrice); err != nil {
Expand Down Expand Up @@ -234,7 +229,7 @@ func (k Keeper) CalcInAmtGivenOut(
swapFee sdk.Dec,
) (tokenIn sdk.Coin, err error) {
cacheCtx, _ := ctx.CacheContext()
_, tokenIn, _, _, _, _, err = k.calcInAmtGivenOut(cacheCtx, tokenOut, tokenInDenom, swapFee, sdk.ZeroDec(), poolI.GetId())
tokenIn, _, _, _, _, err = k.computeInAmtGivenOut(cacheCtx, tokenOut, tokenInDenom, swapFee, sdk.ZeroDec(), poolI.GetId())
if err != nil {
return sdk.Coin{}, err
}
Expand Down Expand Up @@ -412,22 +407,21 @@ func (k Keeper) computeOutAmtGivenIn(
return tokenIn, tokenOut, swapState.tick, swapState.liquidity, swapState.sqrtPrice, nil
}

// calcInAmtGivenOut calculates tokens to be swapped in given the desired token out and fee deducted. It also returns
// computeInAmtGivenOut calculates tokens to be swapped in given the desired token out and fee deducted. It also returns
// what the updated tick, liquidity, and currentSqrtPrice for the pool would be after this swap.
// Note this method is non-mutative, so the values returned by calcInAmtGivenOut do not get stored
// Instead, we return writeCtx function so that the caller of this method can decide to write the cached ctx to store or not.
func (k Keeper) calcInAmtGivenOut(
// Note this method is mutative, some of the tick and accumulator updates get written to store.
// However, there are no token transfers or pool updates done in this method. These mutations are performed in swapOutAmtGivenIn.
func (k Keeper) computeInAmtGivenOut(
ctx sdk.Context,
desiredTokenOut sdk.Coin,
tokenInDenom string,
swapFee sdk.Dec,
priceLimit sdk.Dec,
poolId uint64,
) (writeCtx func(), tokenIn, tokenOut sdk.Coin, updatedTick int64, updatedLiquidity, updatedSqrtPrice sdk.Dec, err error) {
ctx, writeCtx = ctx.CacheContext()
) (tokenIn, tokenOut sdk.Coin, updatedTick int64, updatedLiquidity, updatedSqrtPrice sdk.Dec, err error) {
p, err := k.getPoolById(ctx, poolId)
if err != nil {
return writeCtx, sdk.Coin{}, sdk.Coin{}, 0, sdk.Dec{}, sdk.Dec{}, err
return sdk.Coin{}, sdk.Coin{}, 0, sdk.Dec{}, sdk.Dec{}, err
}
asset0 := p.GetToken0()
asset1 := p.GetToken1()
Expand All @@ -445,7 +439,7 @@ func (k Keeper) calcInAmtGivenOut(
// take provided price limit and turn this into a sqrt price limit since formulas use sqrtPrice
sqrtPriceLimit, err := priceLimit.ApproxSqrt()
if err != nil {
return writeCtx, sdk.Coin{}, sdk.Coin{}, 0, sdk.Dec{}, sdk.Dec{}, fmt.Errorf("issue calculating square root of price limit")
return sdk.Coin{}, sdk.Coin{}, 0, sdk.Dec{}, sdk.Dec{}, fmt.Errorf("issue calculating square root of price limit")
}

// set the swap strategy
Expand All @@ -455,20 +449,20 @@ func (k Keeper) calcInAmtGivenOut(
curSqrtPrice := p.GetCurrentSqrtPrice()

if err := swapStrategy.ValidateSqrtPrice(sqrtPriceLimit, curSqrtPrice); err != nil {
return writeCtx, sdk.Coin{}, sdk.Coin{}, 0, sdk.Dec{}, sdk.Dec{}, err
return sdk.Coin{}, sdk.Coin{}, 0, sdk.Dec{}, sdk.Dec{}, err
}

// check that the specified tokenOut matches one of the assets in the specified pool
if desiredTokenOut.Denom != asset0 && desiredTokenOut.Denom != asset1 {
return writeCtx, sdk.Coin{}, sdk.Coin{}, 0, sdk.Dec{}, sdk.Dec{}, types.TokenOutDenomNotInPoolError{TokenOutDenom: desiredTokenOut.Denom}
return sdk.Coin{}, sdk.Coin{}, 0, sdk.Dec{}, sdk.Dec{}, types.TokenOutDenomNotInPoolError{TokenOutDenom: desiredTokenOut.Denom}
}
// check that the specified tokenIn matches one of the assets in the specified pool
if tokenInDenom != asset0 && tokenInDenom != asset1 {
return writeCtx, sdk.Coin{}, sdk.Coin{}, 0, sdk.Dec{}, sdk.Dec{}, types.TokenInDenomNotInPoolError{TokenInDenom: tokenInDenom}
return sdk.Coin{}, sdk.Coin{}, 0, sdk.Dec{}, sdk.Dec{}, types.TokenInDenomNotInPoolError{TokenInDenom: tokenInDenom}
}
// check that token in and token out are different denominations
if desiredTokenOut.Denom == tokenInDenom {
return writeCtx, sdk.Coin{}, sdk.Coin{}, 0, sdk.Dec{}, sdk.Dec{}, types.DenomDuplicatedError{TokenInDenom: tokenInDenom, TokenOutDenom: desiredTokenOut.Denom}
return sdk.Coin{}, sdk.Coin{}, 0, sdk.Dec{}, sdk.Dec{}, types.DenomDuplicatedError{TokenInDenom: tokenInDenom, TokenOutDenom: desiredTokenOut.Denom}
}

// initialize swap state with the following parameters:
Expand All @@ -494,13 +488,13 @@ func (k Keeper) calcInAmtGivenOut(
// if no ticks are initialized (no users have created liquidity positions) then we return an error
nextTick, ok := swapStrategy.NextInitializedTick(ctx, poolId, swapState.tick)
if !ok {
return writeCtx, sdk.Coin{}, sdk.Coin{}, 0, sdk.Dec{}, sdk.Dec{}, fmt.Errorf("there are no more ticks initialized to fill the swap")
return sdk.Coin{}, sdk.Coin{}, 0, sdk.Dec{}, sdk.Dec{}, fmt.Errorf("there are no more ticks initialized to fill the swap")
}

// utilizing the next initialized tick, we find the corresponding nextPrice (the target price)
_, sqrtPriceNextTick, err := math.TickToSqrtPrice(nextTick)
if err != nil {
return writeCtx, sdk.Coin{}, sdk.Coin{}, 0, sdk.Dec{}, sdk.Dec{}, fmt.Errorf("could not convert next tick (%v) to nextSqrtPrice", nextTick)
return sdk.Coin{}, sdk.Coin{}, 0, sdk.Dec{}, sdk.Dec{}, fmt.Errorf("could not convert next tick (%v) to nextSqrtPrice", nextTick)
}

sqrtPriceTarget := swapStrategy.GetSqrtTargetPrice(sqrtPriceNextTick)
Expand Down Expand Up @@ -535,7 +529,7 @@ func (k Keeper) calcInAmtGivenOut(
// retrieve the liquidity held in the next closest initialized tick
liquidityNet, err := k.crossTick(ctx, p.GetId(), nextTick, sdk.NewDecCoinFromDec(desiredTokenOut.Denom, swapState.feeGrowthGlobal))
if err != nil {
return writeCtx, sdk.Coin{}, sdk.Coin{}, 0, sdk.Dec{}, sdk.Dec{}, err
return sdk.Coin{}, sdk.Coin{}, 0, sdk.Dec{}, sdk.Dec{}, err
}
liquidityNet = swapStrategy.SetLiquidityDeltaSign(liquidityNet)
// update the swapState's liquidity with the new tick's liquidity
Expand All @@ -550,13 +544,13 @@ func (k Keeper) calcInAmtGivenOut(
price := sqrtPrice.Mul(sqrtPrice)
swapState.tick, err = math.PriceToTickRoundDown(price, p.GetTickSpacing())
if err != nil {
return writeCtx, sdk.Coin{}, sdk.Coin{}, 0, sdk.Dec{}, sdk.Dec{}, err
return sdk.Coin{}, sdk.Coin{}, 0, sdk.Dec{}, sdk.Dec{}, err
}
}
}

if err := k.chargeFee(ctx, poolId, sdk.NewDecCoinFromDec(tokenInDenom, swapState.feeGrowthGlobal)); err != nil {
return writeCtx, sdk.Coin{}, sdk.Coin{}, 0, sdk.Dec{}, sdk.Dec{}, err
return sdk.Coin{}, sdk.Coin{}, 0, sdk.Dec{}, sdk.Dec{}, err
}

// coin amounts require int values
Expand All @@ -571,7 +565,7 @@ func (k Keeper) calcInAmtGivenOut(
tokenIn = sdk.NewCoin(tokenInDenom, amt0)
tokenOut = sdk.NewCoin(desiredTokenOut.Denom, amt1)

return writeCtx, tokenIn, tokenOut, swapState.tick, swapState.liquidity, swapState.sqrtPrice, nil
return tokenIn, tokenOut, swapState.tick, swapState.liquidity, swapState.sqrtPrice, nil
}

// updatePoolForSwap updates the given pool object with the results of a swap operation.
Expand Down
89 changes: 71 additions & 18 deletions x/concentrated-liquidity/swaps_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1680,7 +1680,7 @@ func (s *KeeperTestSuite) TestSwapOutAmtGivenIn_TickUpdates() {
}
}

func (s *KeeperTestSuite) TestCalcAndSwapInAmtGivenOut() {
func (s *KeeperTestSuite) TestComputeAndSwapInAmtGivenOut() {
tests := make(map[string]SwapTest, len(swapInGivenOutTestCases)+len(swapInGivenOutFeeTestCases)+len(swapInGivenOutErrorTestCases))
for name, test := range swapInGivenOutTestCases {
tests[name] = test
Expand Down Expand Up @@ -1722,9 +1722,10 @@ func (s *KeeperTestSuite) TestCalcAndSwapInAmtGivenOut() {
poolBeforeCalc, err := s.App.ConcentratedLiquidityKeeper.GetPoolById(s.Ctx, pool.GetId())
s.Require().NoError(err)

// perform calc
_, tokenIn, tokenOut, updatedTick, updatedLiquidity, sqrtPrice, err := s.App.ConcentratedLiquidityKeeper.CalcInAmtGivenOutInternal(
s.Ctx,
// perform compute
cacheCtx, _ := s.Ctx.CacheContext()
tokenIn, tokenOut, updatedTick, updatedLiquidity, sqrtPrice, err := s.App.ConcentratedLiquidityKeeper.ComputeInAmtGivenOut(
cacheCtx,
test.tokenOut, test.tokenInDenom,
test.swapFee, test.priceLimit, pool.GetId())
if test.expectErr {
Expand Down Expand Up @@ -2421,13 +2422,12 @@ func (s *KeeperTestSuite) TestCalcOutAmtGivenIn_NonMutative() {
}
}

// TestCalcInAmtGivenOutWriteCtx tests that writeCtx successfully performs state changes as expected.
// We expect writeCtx to only change fee accum state, since pool state change is not handled via writeCtx function.
func (s *KeeperTestSuite) TestCalcInAmtGivenOutWriteCtx() {
// TestCalcInAmtGivenOut_NonMutative tests that CalcInAmtGivenOut is non-mutative.
func (s *KeeperTestSuite) TestCalcInAmtGivenOut_NonMutative() {
// we only use fee cases here since write Ctx only takes effect in the fee accumulator
tests := make(map[string]SwapTest, len(swapInGivenOutFeeTestCases))
tests := make(map[string]SwapTest, len(swapOutGivenInFeeCases))

for name, test := range swapInGivenOutFeeTestCases {
for name, test := range swapOutGivenInFeeCases {
tests[name] = test
}

Expand Down Expand Up @@ -2459,10 +2459,11 @@ func (s *KeeperTestSuite) TestCalcInAmtGivenOutWriteCtx() {
s.Require().NoError(err)

// perform calc
writeCtx, _, _, _, _, _, err := s.App.ConcentratedLiquidityKeeper.CalcInAmtGivenOutInternal(
_, err = s.App.ConcentratedLiquidityKeeper.CalcOutAmtGivenIn(
s.Ctx,
test.tokenOut, test.tokenInDenom,
test.swapFee, test.priceLimit, pool.GetId())
poolBeforeCalc,
test.tokenIn, test.tokenOutDenom,
test.swapFee)
s.Require().NoError(err)

// check that the pool has not been modified after performing calc
Expand All @@ -2482,18 +2483,70 @@ func (s *KeeperTestSuite) TestCalcInAmtGivenOutWriteCtx() {
s.Require().Equal(1,
additiveFeeGrowthGlobalErrTolerance.CompareBigDec(
osmomath.BigDecFromSDKDec(test.expectedFeeGrowthAccumulatorValue),
osmomath.BigDecFromSDKDec(feeAccum.GetValue().AmountOf(test.tokenInDenom)),
osmomath.BigDecFromSDKDec(feeAccum.GetValue().AmountOf(test.tokenIn.Denom)),
),
)
})
}
}

// System under test
writeCtx()
// TestComputeInAmtGivenOut tests that ComputeInAmtGivenOut successfully performs state changes as expected.
func (s *KeeperTestSuite) TestComputeInAmtGivenOut() {
// we only use fee cases here since write Ctx only takes effect in the fee accumulator
tests := make(map[string]SwapTest, len(swapInGivenOutFeeTestCases))

// now we check that fee accum has been correctly updated upon writeCtx
feeAccum, err = s.App.ConcentratedLiquidityKeeper.GetFeeAccumulator(s.Ctx, 1)
for name, test := range swapInGivenOutFeeTestCases {
tests[name] = test
}

for name, test := range tests {
test := test
s.Run(name, func() {
s.SetupTest()
s.FundAcc(s.TestAccs[0], sdk.NewCoins(sdk.NewCoin(ETH, sdk.NewInt(10000000000000)), sdk.NewCoin(USDC, sdk.NewInt(1000000000000))))
s.FundAcc(s.TestAccs[1], sdk.NewCoins(sdk.NewCoin(ETH, sdk.NewInt(10000000000000)), sdk.NewCoin(USDC, sdk.NewInt(1000000000000))))

// Create default CL pool
pool := s.PrepareConcentratedPool()

// add default position
s.SetupDefaultPosition(pool.GetId())

// add second position depending on the test
if !test.secondPositionLowerPrice.IsNil() {
newLowerTick, err := math.PriceToTickRoundDown(test.secondPositionLowerPrice, pool.GetTickSpacing())
s.Require().NoError(err)
newUpperTick, err := math.PriceToTickRoundDown(test.secondPositionUpperPrice, pool.GetTickSpacing())
s.Require().NoError(err)

_, _, _, _, _, _, _, err = s.App.ConcentratedLiquidityKeeper.CreatePosition(s.Ctx, pool.GetId(), s.TestAccs[1], DefaultCoins, sdk.ZeroInt(), sdk.ZeroInt(), newLowerTick, newUpperTick)
s.Require().NoError(err)
}

poolBeforeCalc, err := s.App.ConcentratedLiquidityKeeper.GetPoolById(s.Ctx, pool.GetId())
s.Require().NoError(err)

feeAccumValue = feeAccum.GetValue()
// perform calc
_, _, _, _, _, err = s.App.ConcentratedLiquidityKeeper.ComputeInAmtGivenOut(
s.Ctx,
test.tokenOut, test.tokenInDenom,
test.swapFee, test.priceLimit, pool.GetId())
s.Require().NoError(err)

// check that the pool has not been modified after performing calc
poolAfterCalc, err := s.App.ConcentratedLiquidityKeeper.GetPoolById(s.Ctx, pool.GetId())
s.Require().NoError(err)

s.Require().Equal(poolBeforeCalc.GetCurrentSqrtPrice(), poolAfterCalc.GetCurrentSqrtPrice())
s.Require().Equal(poolBeforeCalc.GetCurrentTick(), poolAfterCalc.GetCurrentTick())
s.Require().Equal(poolBeforeCalc.GetLiquidity(), poolAfterCalc.GetLiquidity())
s.Require().Equal(poolBeforeCalc.GetTickSpacing(), poolAfterCalc.GetTickSpacing())

// check that fee accum has been correctly updated.
feeAccum, err := s.App.ConcentratedLiquidityKeeper.GetFeeAccumulator(s.Ctx, 1)
s.Require().NoError(err)

feeAccumValue := feeAccum.GetValue()
s.Require().Equal(1, feeAccumValue.Len())
s.Require().Equal(0,
additiveFeeGrowthGlobalErrTolerance.CompareBigDec(
Expand Down

0 comments on commit 201ffd4

Please sign in to comment.