Skip to content

Commit

Permalink
Add broadcast API to Go forge scripts (#11826)
Browse files Browse the repository at this point in the history
* Add broadcast API to Go forge scripts

Adds a hooks-based API to collect transactions broadcasted via `vm.broadcast(*)` in the Go-based Forge scripts. Users pass an `OnBroadcast` hook to the host, which will be called with a `Broadcast` struct with the following fields whenever a transaction needs to be emitted:

```go
type Broadcast struct {
	From     common.Address
	To       common.Address
	Calldata []byte
	Value    *big.Int
}
```

This API lets us layer on custom transaction management in the future which will be helpful for `op-deployer`.

As part of this PR, I changed the internal `callStack` data structure to contain pointers to `CallFrame`s rather than passing by value. I discovered a bug where the pranked sender was not being cleared in subsequent calls due to an ineffectual assignment error. I took a look at the implementation and there are many places where assignments to call frames within the stack happen after converting the value to a reference, so converting the stack to store pointers in the first place both simplified the code and eliminated a class of errors in the future. I updated the public API methods to return copies of the internal structs to prevent accidental mutation.

* Code review updates

* moar review updates

* fix bug with staticcall
  • Loading branch information
mslipper authored Sep 10, 2024
1 parent e2356c3 commit 219ebe0
Show file tree
Hide file tree
Showing 7 changed files with 199 additions and 12 deletions.
40 changes: 37 additions & 3 deletions op-chain-ops/script/prank.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package script

import (
"bytes"
"errors"
"math/big"

Expand Down Expand Up @@ -69,7 +70,7 @@ func (h *Host) Prank(msgSender *common.Address, txOrigin *common.Address, repeat
h.log.Warn("no call stack")
return nil // cannot prank while not in a call.
}
cf := &h.callStack[len(h.callStack)-1]
cf := h.callStack[len(h.callStack)-1]
if cf.Prank != nil {
if cf.Prank.Broadcast && !broadcast {
return errors.New("you have an active broadcast; broadcasting and pranks are not compatible")
Expand Down Expand Up @@ -98,7 +99,7 @@ func (h *Host) StopPrank(broadcast bool) error {
if len(h.callStack) == 0 {
return nil
}
cf := &h.callStack[len(h.callStack)-1]
cf := h.callStack[len(h.callStack)-1]
if cf.Prank == nil {
if broadcast {
return errors.New("no broadcast in progress to stop")
Expand Down Expand Up @@ -127,7 +128,7 @@ func (h *Host) CallerMode() CallerMode {
if len(h.callStack) == 0 {
return CallerModeNone
}
cf := &h.callStack[len(h.callStack)-1]
cf := h.callStack[len(h.callStack)-1]
if cf.Prank != nil {
if cf.Prank.Broadcast {
if cf.Prank.Repeat {
Expand Down Expand Up @@ -157,3 +158,36 @@ const (
CallerModePrank
CallerModeRecurrentPrank
)

// Broadcast captures a transaction that was selected to be broadcasted
// via vm.broadcast(). Actually submitting the transaction is left up
// to other tools.
type Broadcast struct {
From common.Address
To common.Address
Calldata []byte
Value *big.Int
}

// NewBroadcastFromCtx creates a Broadcast from a VM context. This method
// is preferred to manually creating the struct since it correctly handles
// data that must be copied prior to being returned to prevent accidental
// mutation.
func NewBroadcastFromCtx(ctx *vm.ScopeContext) Broadcast {
// Consistently return nil for zero values in order
// for tests to have a deterministic value to compare
// against.
value := ctx.CallValue().ToBig()
if value.Cmp(common.Big0) == 0 {
value = nil
}

// Need to clone CallInput() below since it's used within
// the VM itself elsewhere.
return Broadcast{
From: ctx.Caller(),
To: ctx.Address(),
Calldata: bytes.Clone(ctx.CallInput()),
Value: value,
}
}
55 changes: 49 additions & 6 deletions op-chain-ops/script/script.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package script
import (
"bytes"
"encoding/binary"
"encoding/hex"
"encoding/json"
"fmt"
"math/big"
Expand Down Expand Up @@ -69,7 +70,7 @@ type Host struct {

precompiles map[common.Address]vm.PrecompiledContract

callStack []CallFrame
callStack []*CallFrame

// serializerStates are in-progress JSON payloads by name,
// for the serializeX family of cheat codes, see:
Expand All @@ -86,12 +87,34 @@ type Host struct {
srcMaps map[common.Address]*srcmap.SourceMap

onLabel []func(name string, addr common.Address)

hooks *Hooks
}

type HostOption func(h *Host)

type BroadcastHook func(broadcast Broadcast)

type Hooks struct {
OnBroadcast BroadcastHook
}

func WithBroadcastHook(hook BroadcastHook) HostOption {
return func(h *Host) {
h.hooks.OnBroadcast = hook
}
}

// NewHost creates a Host that can load contracts from the given Artifacts FS,
// and with an EVM initialized to the given executionContext.
// Optionally src-map loading may be enabled, by providing a non-nil srcFS to read sources from.
func NewHost(logger log.Logger, fs *foundry.ArtifactsFS, srcFS *foundry.SourceMapFS, executionContext Context) *Host {
func NewHost(
logger log.Logger,
fs *foundry.ArtifactsFS,
srcFS *foundry.SourceMapFS,
executionContext Context,
options ...HostOption,
) *Host {
h := &Host{
log: logger,
af: fs,
Expand All @@ -101,6 +124,13 @@ func NewHost(logger log.Logger, fs *foundry.ArtifactsFS, srcFS *foundry.SourceMa
precompiles: make(map[common.Address]vm.PrecompiledContract),
srcFS: srcFS,
srcMaps: make(map[common.Address]*srcmap.SourceMap),
hooks: &Hooks{
OnBroadcast: func(broadcast Broadcast) {},
},
}

for _, opt := range options {
opt(h)
}

// Init a default chain config, with all the mainnet L1 forks activated
Expand Down Expand Up @@ -361,6 +391,19 @@ func (h *Host) unwindCallstack(depth int) {
if len(h.callStack) > 1 {
parentCallFrame := h.callStack[len(h.callStack)-2]
if parentCallFrame.Prank != nil {
if parentCallFrame.Prank.Broadcast && parentCallFrame.LastOp != vm.STATICCALL {
currentFrame := h.callStack[len(h.callStack)-1]
bcast := NewBroadcastFromCtx(currentFrame.Ctx)
h.hooks.OnBroadcast(bcast)
h.log.Debug(
"called broadcast hook",
"from", bcast.From,
"to", bcast.To,
"calldata", hex.EncodeToString(bcast.Calldata),
"value", bcast.Value,
)
}

// While going back to the parent, restore the tx.origin.
// It will later be re-applied on sub-calls if the prank persists (if Repeat == true).
if parentCallFrame.Prank.Origin != nil {
Expand All @@ -372,7 +415,7 @@ func (h *Host) unwindCallstack(depth int) {
}
}
// Now pop the call-frame
h.callStack[len(h.callStack)-1] = CallFrame{} // don't hold on to the underlying call-frame resources
h.callStack[len(h.callStack)-1] = nil // don't hold on to the underlying call-frame resources
h.callStack = h.callStack[:len(h.callStack)-1]
}
}
Expand All @@ -384,7 +427,7 @@ func (h *Host) onOpcode(pc uint64, op byte, gas, cost uint64, scope tracing.OpCo
// Check if we are entering a new depth, add it to the call-stack if so.
// We do this here, instead of onEnter, to capture an initialized scope.
if len(h.callStack) == 0 || h.callStack[len(h.callStack)-1].Depth < depth {
h.callStack = append(h.callStack, CallFrame{
h.callStack = append(h.callStack, &CallFrame{
Depth: depth,
LastOp: vm.OpCode(op),
LastPC: pc,
Expand All @@ -395,7 +438,7 @@ func (h *Host) onOpcode(pc uint64, op byte, gas, cost uint64, scope tracing.OpCo
if len(h.callStack) == 0 || h.callStack[len(h.callStack)-1].Ctx != scopeCtx {
panic("scope context changed without call-frame pop/push")
}
cf := &h.callStack[len(h.callStack)-1]
cf := h.callStack[len(h.callStack)-1]
if vm.OpCode(op) == vm.JUMPDEST { // remember the last PC before successful jump
cf.LastJumps = append(cf.LastJumps, cf.LastPC)
if len(cf.LastJumps) > jumpHistory {
Expand Down Expand Up @@ -429,7 +472,7 @@ func (h *Host) CurrentCall() CallFrame {
if len(h.callStack) == 0 {
return CallFrame{}
}
return h.callStack[len(h.callStack)-1]
return *h.callStack[len(h.callStack)-1]
}

// MsgSender returns the msg.sender of the current active EVM call-frame,
Expand Down
59 changes: 59 additions & 0 deletions op-chain-ops/script/script_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
package script

import (
"fmt"
"strings"
"testing"

"github.com/ethereum/go-ethereum/accounts/abi"
"github.com/ethereum/go-ethereum/common"

"github.com/holiman/uint256"
"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -35,3 +40,57 @@ func TestScript(t *testing.T) {
// and a second time, to see if we can revisit the host state.
require.NoError(t, h.cheatcodes.Precompile.DumpState("noop"))
}

func TestScriptBroadcast(t *testing.T) {
logger := testlog.Logger(t, log.LevelDebug)
af := foundry.OpenArtifactsDir("./testdata/test-artifacts")

mustEncodeCalldata := func(method, input string) []byte {
packer, err := abi.JSON(strings.NewReader(fmt.Sprintf(`[{"type":"function","name":"%s","inputs":[{"type":"string","name":"input"}]}]`, method)))
require.NoError(t, err)

data, err := packer.Pack(method, input)
require.NoError(t, err)
return data
}

senderAddr := common.HexToAddress("0x5b73C5498c1E3b4dbA84de0F1833c4a029d90519")
expBroadcasts := []Broadcast{
{
From: senderAddr,
To: senderAddr,
Calldata: mustEncodeCalldata("call1", "single_call1"),
},
{
From: senderAddr,
To: senderAddr,
Calldata: mustEncodeCalldata("call1", "startstop_call1"),
},
{
From: senderAddr,
To: senderAddr,
Calldata: mustEncodeCalldata("call2", "startstop_call2"),
},
{
From: senderAddr,
To: senderAddr,
Calldata: mustEncodeCalldata("nested1", "nested"),
},
}

scriptContext := DefaultContext
var broadcasts []Broadcast
hook := func(broadcast Broadcast) {
broadcasts = append(broadcasts, broadcast)
}
h := NewHost(logger, af, nil, scriptContext, WithBroadcastHook(hook))
addr, err := h.LoadContract("ScriptExample.s.sol", "ScriptExample")
require.NoError(t, err)

require.NoError(t, h.EnableCheats())

input := bytes4("runBroadcast()")
returnData, _, err := h.Call(scriptContext.Sender, addr, input[:], DefaultFoundryGasLimit, uint256.NewInt(0))
require.NoError(t, err, "call failed: %x", string(returnData))
require.EqualValues(t, expBroadcasts, broadcasts)
}
51 changes: 51 additions & 0 deletions op-chain-ops/script/testdata/scripts/ScriptExample.s.sol
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ interface Vm {
function parseJsonKeys(string calldata json, string calldata key) external pure returns (string[] memory keys);
function startPrank(address msgSender) external;
function stopPrank() external;
function broadcast() external;
function startBroadcast() external;
function stopBroadcast() external;
}

// console is a minimal version of the console2 lib.
Expand Down Expand Up @@ -64,6 +67,9 @@ contract ScriptExample {
address internal constant VM_ADDRESS = address(uint160(uint256(keccak256("hevm cheat code"))));
Vm internal constant vm = Vm(VM_ADDRESS);

// @notice counter variable to force non-pure calls.
uint256 public counter;

/// @notice example function, runs through basic cheat-codes and console logs.
function run() public {
bool x = vm.envOr("EXAMPLE_BOOL", false);
Expand All @@ -90,9 +96,54 @@ contract ScriptExample {
console.log("done!");
}

/// @notice example function, to test vm.broadcast with.
function runBroadcast() public {
console.log("testing single");
vm.broadcast();
this.call1("single_call1");
this.call2("single_call2");

console.log("testing start/stop");
vm.startBroadcast();
this.call1("startstop_call1");
this.call2("startstop_call2");
this.callPure("startstop_pure");
vm.stopBroadcast();
this.call1("startstop_call3");

console.log("testing nested");
vm.startBroadcast();
this.nested1("nested");
vm.stopBroadcast();
}

/// @notice example external function, to force a CALL, and test vm.startPrank with.
function hello(string calldata _v) external view {
console.log(_v);
console.log("hello msg.sender", address(msg.sender));
}

function call1(string calldata _v) external {
counter++;
console.log(_v);
}

function call2(string calldata _v) external {
counter++;
console.log(_v);
}

function nested1(string calldata _v) external {
counter++;
this.nested2(_v);
}

function nested2(string calldata _v) external {
counter++;
console.log(_v);
}

function callPure(string calldata _v) external pure {
console.log(_v);
}
}

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

0 comments on commit 219ebe0

Please sign in to comment.