Skip to content

Commit

Permalink
perf: improve logic of withdraw
Browse files Browse the repository at this point in the history
  • Loading branch information
smol-ninja committed Sep 22, 2024
1 parent 2b3eca7 commit 7a80f4d
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 25 deletions.
27 changes: 7 additions & 20 deletions src/SablierFlow.sol
Original file line number Diff line number Diff line change
Expand Up @@ -801,8 +801,6 @@ contract SablierFlow is
revert Errors.SablierFlow_Overdraw(streamId, withdrawAmount, withdrawableAmount);
}

uint128 ongoingDebt;

// If the withdraw amount is less than the snapshot debt, use the snapshot debt as a funding source for the
// withdrawal and leave both the withdraw amount and the ongoing debt unchanged.
//
Expand All @@ -824,13 +822,8 @@ contract SablierFlow is
// Steps:
// - Calculate the difference between the withdraw amount the snapshot debt.
// - Scale the difference up to 18 decimals.
// - Divide it by the rate per second, which is also an 18-decimal number, and obtain the time it would take to
// stream the difference at the current rate per second.
// - Add the resultant value to the snapshot time.
// - Calculate the scaled ongoing debt and the new snapshot time.
// - Set the snapshot debt to zero.
// - Recalculate the ongoing debt based on the new snapshot time.
// - Set the withdraw amount to the initial total debt minus the ongoing debt. This may result in a value less
// than the initial withdraw amount.
//
// Note: the rate per second cannot be zero because this can only happen when the stream is paused. In that
// case, the `if` condition will be executed.
Expand All @@ -840,14 +833,14 @@ contract SablierFlow is
difference = withdrawAmount - _streams[streamId].snapshotDebt;
}
uint128 scaledDifference = difference * scaleFactor;
_streams[streamId].snapshotTime += uint40(scaledDifference / rps);
uint128 scaledOngoingDebt =
rps * (uint40(block.timestamp) - _streams[streamId].snapshotTime) - scaledDifference;
_streams[streamId].snapshotTime = uint40(block.timestamp) - uint40(scaledOngoingDebt / rps);

// Set the snapshot debt to zero.
_streams[streamId].snapshotDebt = 0;
uint256 remainderDebt =
scaledOngoingDebt + rps * _streams[streamId].snapshotTime - rps * uint40(block.timestamp);

// Adjust the withdraw amount. At this point, new total debt == ongoing debt.
ongoingDebt = _ongoingDebtOf(streamId);
withdrawAmount = initialTotalDebt - ongoingDebt;
_streams[streamId].snapshotDebt = (remainderDebt / scaleFactor).toUint128();
}

// Effect: update the stream balance.
Expand Down Expand Up @@ -876,12 +869,6 @@ contract SablierFlow is
// Interaction: perform the ERC-20 transfer.
token.safeTransfer({ to: to, value: netWithdrawnAmount });

unchecked {
// Protocol Invariant: the difference between total debts should be equal to the difference between stream
// balances.
assert(initialTotalDebt - _totalDebtOf(streamId) == initialBalance - _streams[streamId].balance);
}

// Log the withdrawal.
emit ISablierFlow.WithdrawFromFlowStream({
streamId: streamId,
Expand Down
6 changes: 5 additions & 1 deletion test/fork/Flow.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,11 @@ contract Flow_Fork_Test is Fork_Test {
amountWithdrawn,
"token balance == amount withdrawn"
);
assertEq(vars.initialTotalDebt - flow.totalDebtOf(streamId), amountWithdrawn, "total debt == amount withdrawn");
assertLe(
(vars.initialTotalDebt - flow.totalDebtOf(streamId)) - amountWithdrawn,
1,
"total debt - amount withdrawn <= 1"
);
assertEq(
vars.initialStreamBalance - flow.getBalance(streamId), amountWithdrawn, "stream balance == amount withdrawn"
);
Expand Down
6 changes: 4 additions & 2 deletions test/integration/concrete/withdraw/withdraw.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,10 @@ contract Withdraw_Integration_Concrete_Test is Integration_Test {
actualWithdrawnAmount - vars.feeAmount,
"token balance == amount withdrawn - fee amount"
);
assertEq(
vars.initialTotalDebt - flow.totalDebtOf(streamId), actualWithdrawnAmount, "total debt == amount withdrawn"
assertLe(
(vars.initialTotalDebt - flow.totalDebtOf(streamId)) - actualWithdrawnAmount,
1,
"total debt - amount withdrawn <= 1"
);
assertEq(
vars.initialStreamBalance - flow.getBalance(streamId),
Expand Down
6 changes: 5 additions & 1 deletion test/integration/fuzz/withdraw.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,11 @@ contract Withdraw_Integration_Fuzz_Test is Shared_Integration_Fuzz_Test {
amountWithdrawn - vars.feeAmount,
"token balance == amount withdrawn - fee amount"
);
assertEq(vars.initialTotalDebt - flow.totalDebtOf(streamId), amountWithdrawn, "total debt == amount withdrawn");
assertLe(
(vars.initialTotalDebt - flow.totalDebtOf(streamId)) - amountWithdrawn,
1,
"total debt - amount withdrawn <= 1"
);
assertEq(
vars.initialStreamBalance - flow.getBalance(streamId), amountWithdrawn, "stream balance == amount withdrawn"
);
Expand Down
2 changes: 1 addition & 1 deletion test/integration/fuzz/withdrawMax.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ contract WithdrawMax_Integration_Fuzz_Test is Shared_Integration_Fuzz_Test {

// Check the states after the withdrawal.
assertEq(tokenBalance - token.balanceOf(address(flow)), amountWithdrawn, "token balance == amount withdrawn");
assertEq(totalDebt - flow.totalDebtOf(streamId), amountWithdrawn, "total debt == amount withdrawn");
assertLe((totalDebt - flow.totalDebtOf(streamId)) - amountWithdrawn, 1, "total debt - amount withdrawn <= 1");
assertEq(streamBalance - flow.getBalance(streamId), amountWithdrawn, "stream balance == amount withdrawn");
assertEq(token.balanceOf(withdrawTo) - userBalance, amountWithdrawn, "user balance == token balance");

Expand Down

0 comments on commit 7a80f4d

Please sign in to comment.