Skip to content

Commit

Permalink
✨ Add directReturn for bytes[] (#1209)
Browse files Browse the repository at this point in the history
  • Loading branch information
Vectorized authored Dec 7, 2024
1 parent 3f2f534 commit e286c5f
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 0 deletions.
37 changes: 37 additions & 0 deletions src/utils/LibBytes.sol
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,43 @@ library LibBytes {
}
}

/// @dev Directly returns `a` with minimal copying.
function directReturn(bytes[] memory a) internal pure {
assembly {
let n := mload(a) // `a.length`.
let o := add(a, 0x20) // Start of elements in `a`.
let u := a // Highest memory slot.
let w := not(0x1f)
for { let i := 0 } iszero(eq(i, n)) { i := add(i, 1) } {
let c := add(o, shl(5, i)) // Location of pointer to `a[i]`.
let s := mload(c) // `a[i]`.
let l := mload(s) // `a[i].length`.
let r := and(l, 0x1f) // `a[i].length % 32`.
let z := add(0x20, and(l, w)) // Offset of last word in `a[i]` from `s`.
// If `s` comes before `o`, or `s` is not zero right padded.
if iszero(lt(lt(s, o), or(iszero(r), iszero(shl(shl(3, r), mload(add(s, z))))))) {
let m := mload(0x40)
mstore(m, l) // Copy `a[i].length`.
for {} 1 {} {
mstore(add(m, z), mload(add(s, z))) // Copy `a[i]`, backwards.
z := add(z, w) // `sub(z, 0x20)`.
if iszero(z) { break }
}
let e := add(add(m, 0x20), l)
mstore(e, 0) // Zeroize the slot after the copied bytes.
mstore(0x40, add(e, 0x20)) // Allocate memory.
s := m
}
mstore(c, sub(s, o)) // Convert to calldata offset.
let t := add(l, add(s, 0x20))
if iszero(lt(t, u)) { u := t }
}
let retStart := add(a, w) // Assumes `a` doesn't start from scratch space.
mstore(retStart, 0x20) // Store the return offset.
return(retStart, add(0x40, sub(u, retStart))) // End the transaction.
}
}

/// @dev Returns the word at `offset`, without any bounds checks.
/// To load an address, you can use `address(bytes20(load(a, offset)))`.
function load(bytes memory a, uint256 offset) internal pure returns (bytes32 result) {
Expand Down
37 changes: 37 additions & 0 deletions src/utils/g/LibBytes.sol
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,43 @@ library LibBytes {
}
}

/// @dev Directly returns `a` with minimal copying.
function directReturn(bytes[] memory a) internal pure {
assembly {
let n := mload(a) // `a.length`.
let o := add(a, 0x20) // Start of elements in `a`.
let u := a // Highest memory slot.
let w := not(0x1f)
for { let i := 0 } iszero(eq(i, n)) { i := add(i, 1) } {
let c := add(o, shl(5, i)) // Location of pointer to `a[i]`.
let s := mload(c) // `a[i]`.
let l := mload(s) // `a[i].length`.
let r := and(l, 0x1f) // `a[i].length % 32`.
let z := add(0x20, and(l, w)) // Offset of last word in `a[i]` from `s`.
// If `s` comes before `o`, or `s` is not zero right padded.
if iszero(lt(lt(s, o), or(iszero(r), iszero(shl(shl(3, r), mload(add(s, z))))))) {
let m := mload(0x40)
mstore(m, l) // Copy `a[i].length`.
for {} 1 {} {
mstore(add(m, z), mload(add(s, z))) // Copy `a[i]`, backwards.
z := add(z, w) // `sub(z, 0x20)`.
if iszero(z) { break }
}
let e := add(add(m, 0x20), l)
mstore(e, 0) // Zeroize the slot after the copied bytes.
mstore(0x40, add(e, 0x20)) // Allocate memory.
s := m
}
mstore(c, sub(s, o)) // Convert to calldata offset.
let t := add(l, add(s, 0x20))
if iszero(lt(t, u)) { u := t }
}
let retStart := add(a, w) // Assumes `a` doesn't start from scratch space.
mstore(retStart, 0x20) // Store the return offset.
return(retStart, add(0x40, sub(u, retStart))) // End the transaction.
}
}

/// @dev Returns the word at `offset`, without any bounds checks.
/// To load an address, you can use `address(bytes20(load(a, offset)))`.
function load(bytes memory a, uint256 offset) internal pure returns (bytes32 result) {
Expand Down
91 changes: 91 additions & 0 deletions test/LibBytes.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,95 @@ contract LibBytesTest is SoladyTest {
function testEmptyCalldata() public {
assertEq(LibBytes.emptyCalldata(), "");
}

function testDirectReturn() public {
uint256 seed = 123;
bytes[] memory expected = _generateBytesArray(seed);
bytes[] memory computed = this.generateBytesArray(seed, false);
unchecked {
for (uint256 i; i != expected.length; ++i) {
_checkMemory(computed[i]);
assertEq(computed[i], expected[i]);
}
assertEq(computed.length, expected.length);
}
}

function testDirectReturn(uint256 seed) public {
bytes[] memory expected = _generateBytesArray(seed);
(bool success, bytes memory encoded) = address(this).call(
abi.encodeWithSignature("generateBytesArray(uint256,bool)", seed, true)
);
assertTrue(success);
bytes[] memory computed;
/// @solidity memory-safe-assembly
assembly {
let o := add(encoded, 0x20)
computed := add(o, mload(o))
for { let i := 0 } lt(i, mload(computed)) { i := add(i, 1) } {
let c := add(add(0x20, computed), shl(5, i))
mstore(c, add(add(0x20, computed), mload(c)))
}
}
unchecked {
for (uint256 i; i != expected.length; ++i) {
_checkMemory(computed[i]);
assertEq(computed[i], expected[i]);
}
assertEq(computed.length, expected.length);
}
if (seed & 0xf == 0) {
assertEq(abi.encode(expected), abi.encode(this.generateBytesArray(seed, true)));
}
}

function generateBytesArray(uint256 seed, bool brutalized)
public
view
returns (bytes[] memory)
{
if (brutalized) {
_misalignFreeMemoryPointer();
_brutalizeMemory();
}
LibBytes.directReturn(_generateBytesArray(seed));
}

function _generateBytesArray(uint256 seed) internal pure returns (bytes[] memory a) {
bytes memory before = "hehe";
/// @solidity memory-safe-assembly
assembly {
mstore(0x00, seed)
mstore(0x20, 0)
function _next() -> _r {
_r := keccak256(0x00, 0x40)
mstore(0x20, _r)
}
function _nextBytes() -> _b {
_b := mload(0x40)
let n_ := and(_next(), 0x7f)
mstore(_b, n_)
for { let i_ := 0 } lt(i_, n_) { i_ := add(i_, 0x20) } {
mstore(add(add(_b, 0x20), i_), _next())
}
if and(1, _next()) {
mstore(0x40, add(n_, add(_b, 0x20)))
leave
}
mstore(add(n_, add(_b, 0x20)), 0)
mstore(0x40, add(n_, add(_b, 0x40)))
}
let n := and(_next(), 7)
a := mload(0x40)
mstore(a, n)
mstore(0x40, add(add(a, 0x20), shl(5, n)))
for { let i := 0 } lt(i, n) { i := add(1, i) } {
if iszero(and(7, _next())) {
mstore(add(add(a, 0x20), shl(5, i)), before)
continue
}
mstore(add(add(a, 0x20), shl(5, i)), _nextBytes())
}
}
}
}

0 comments on commit e286c5f

Please sign in to comment.