From 3f2f5345261904463f5429c9031c3d2185c0f4fe Mon Sep 17 00:00:00 2001 From: ross <92001561+z0r0z@users.noreply.github.com> Date: Sat, 7 Dec 2024 02:04:56 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20STL:=20totalSupply=20(#1212)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/utils/SafeTransferLib.sol | 19 +++++++++++++++++++ test/SafeTransferLib.t.sol | 12 ++++++++++++ 2 files changed, 31 insertions(+) diff --git a/src/utils/SafeTransferLib.sol b/src/utils/SafeTransferLib.sol index cab9f586bc..6944ed3d82 100644 --- a/src/utils/SafeTransferLib.sol +++ b/src/utils/SafeTransferLib.sol @@ -25,6 +25,9 @@ library SafeTransferLib { /// @dev The ERC20 `approve` has failed. error ApproveFailed(); + /// @dev The ERC20 `totalSupply` query has failed. + error TotalSupplyQueryFailed(); + /// @dev The Permit2 operation has failed. error Permit2Failed(); @@ -396,6 +399,22 @@ library SafeTransferLib { } } + /// @dev Returns the total supply of the `token`. + /// Reverts if the token does not exist or does not implement `totalSupply()`. + function totalSupply(address token) internal view returns (uint256 result) { + /// @solidity memory-safe-assembly + assembly { + mstore(0x00, 0x18160ddd) // `totalSupply()`. + if iszero( + and(gt(returndatasize(), 0x1f), staticcall(gas(), token, 0x1c, 0x04, 0x00, 0x20)) + ) { + mstore(0x00, 0x54cd9435) // `TotalSupplyQueryFailed()`. + revert(0x1c, 0x04) + } + result := mload(0x00) + } + } + /// @dev Sends `amount` of ERC20 `token` from `from` to `to`. /// If the initial attempt fails, try to use Permit2 to transfer the token. /// Reverts upon failure. diff --git a/test/SafeTransferLib.t.sol b/test/SafeTransferLib.t.sol index e5c5aaea9a..ab342894f9 100644 --- a/test/SafeTransferLib.t.sol +++ b/test/SafeTransferLib.t.sol @@ -1091,4 +1091,16 @@ contract SafeTransferLibTest is SoladyTest { t.s ); } + + function testTotalSupplyQuery() public { + uint256 totalSupplyBefore = this.totalSupplyQuery(address(erc20)); + erc20.burn(address(this), 123); + assertEq(this.totalSupplyQuery(address(erc20)), totalSupplyBefore - 123); + vm.expectRevert(SafeTransferLib.TotalSupplyQueryFailed.selector); + this.totalSupplyQuery(address(0)); + } + + function totalSupplyQuery(address token) public view returns (uint256) { + return SafeTransferLib.totalSupply(token); + } }