Skip to content

Commit

Permalink
Add broadcast API to Go forge scripts
Browse files Browse the repository at this point in the history
Adds a hooks-based API to collect transactions broadcasted via `vm.broadcast(*)` in the Go-based Forge scripts. Users pass an `OnBroadcast` hook to the host, which will be called with a `Broadcast` struct with the following fields whenever a transaction needs to be emitted:

```go
type Broadcast struct {
	From     common.Address
	To       common.Address
	Calldata []byte
	Value    *big.Int
}
```

This API lets us layer on custom transaction management in the future which will be helpful for `op-deployer`.

As part of this PR, I changed the internal `callStack` data structure to contain pointers to `CallFrame`s rather than passing by value. I discovered a bug where the pranked sender was not being cleared in subsequent calls due to an ineffectual assignment error. I took a look at the implementation and there are many places where assignments to call frames within the stack happen after converting the value to a reference, so converting the stack to store pointers in the first place both simplified the code and eliminated a class of errors in the future. I updated the public API methods to return copies of the internal structs to prevent accidental mutation.
  • Loading branch information
mslipper committed Sep 10, 2024
1 parent 8404e91 commit b87c989
Show file tree
Hide file tree
Showing 7 changed files with 171 additions and 13 deletions.
6 changes: 3 additions & 3 deletions op-chain-ops/script/prank.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func (h *Host) Prank(msgSender *common.Address, txOrigin *common.Address, repeat
h.log.Warn("no call stack")
return nil // cannot prank while not in a call.
}
cf := &h.callStack[len(h.callStack)-1]
cf := h.callStack[len(h.callStack)-1]
if cf.Prank != nil {
if cf.Prank.Broadcast && !broadcast {
return errors.New("you have an active broadcast; broadcasting and pranks are not compatible")
Expand Down Expand Up @@ -98,7 +98,7 @@ func (h *Host) StopPrank(broadcast bool) error {
if len(h.callStack) == 0 {
return nil
}
cf := &h.callStack[len(h.callStack)-1]
cf := h.callStack[len(h.callStack)-1]
if cf.Prank == nil {
if broadcast {
return errors.New("no broadcast in progress to stop")
Expand Down Expand Up @@ -127,7 +127,7 @@ func (h *Host) CallerMode() CallerMode {
if len(h.callStack) == 0 {
return CallerModeNone
}
cf := &h.callStack[len(h.callStack)-1]
cf := h.callStack[len(h.callStack)-1]
if cf.Prank != nil {
if cf.Prank.Broadcast {
if cf.Prank.Repeat {
Expand Down
74 changes: 67 additions & 7 deletions op-chain-ops/script/script.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package script
import (
"bytes"
"encoding/binary"
"encoding/hex"
"encoding/json"
"fmt"
"math/big"
Expand Down Expand Up @@ -54,6 +55,33 @@ type CallFrame struct {
Prank *Prank
}

type Broadcast struct {
From common.Address
To common.Address
Calldata []byte
Value *big.Int
}

var zero = big.NewInt(0)

func NewBroadcastFromCtx(ctx *vm.ScopeContext) Broadcast {
value := ctx.CallValue().ToBig()
if value.Cmp(zero) == 0 {
value = nil
}

callInput := ctx.CallInput()
calldata := make([]byte, len(callInput))
copy(calldata, callInput)

return Broadcast{
From: ctx.Caller(),
To: ctx.Address(),
Calldata: calldata,
Value: value,
}
}

// Host is an EVM executor that runs Forge scripts.
type Host struct {
log log.Logger
Expand All @@ -69,7 +97,7 @@ type Host struct {

precompiles map[common.Address]vm.PrecompiledContract

callStack []CallFrame
callStack []*CallFrame

// serializerStates are in-progress JSON payloads by name,
// for the serializeX family of cheat codes, see:
Expand All @@ -86,12 +114,26 @@ type Host struct {
srcMaps map[common.Address]*srcmap.SourceMap

onLabel []func(name string, addr common.Address)

hooks *Hooks
}

type Hooks struct {
OnBroadcast func(broadcast Broadcast)
}

var defaultHooks = &Hooks{
OnBroadcast: func(broadcast Broadcast) {},
}

func NewHost(logger log.Logger, fs *foundry.ArtifactsFS, srcFS *foundry.SourceMapFS, executionContext Context) *Host {
return NewHostWithHooks(logger, fs, srcFS, executionContext, nil)
}

// NewHost creates a Host that can load contracts from the given Artifacts FS,
// NewHostWithHooks creates a Host that can load contracts from the given Artifacts FS,
// and with an EVM initialized to the given executionContext.
// Optionally src-map loading may be enabled, by providing a non-nil srcFS to read sources from.
func NewHost(logger log.Logger, fs *foundry.ArtifactsFS, srcFS *foundry.SourceMapFS, executionContext Context) *Host {
func NewHostWithHooks(logger log.Logger, fs *foundry.ArtifactsFS, srcFS *foundry.SourceMapFS, executionContext Context, hooks *Hooks) *Host {
h := &Host{
log: logger,
af: fs,
Expand All @@ -101,6 +143,11 @@ func NewHost(logger log.Logger, fs *foundry.ArtifactsFS, srcFS *foundry.SourceMa
precompiles: make(map[common.Address]vm.PrecompiledContract),
srcFS: srcFS,
srcMaps: make(map[common.Address]*srcmap.SourceMap),
hooks: defaultHooks,
}

if hooks != nil {
h.hooks = hooks
}

// Init a default chain config, with all the mainnet L1 forks activated
Expand Down Expand Up @@ -361,6 +408,19 @@ func (h *Host) unwindCallstack(depth int) {
if len(h.callStack) > 1 {
parentCallFrame := h.callStack[len(h.callStack)-2]
if parentCallFrame.Prank != nil {
if parentCallFrame.Prank.Broadcast {
currentFrame := h.callStack[len(h.callStack)-1]
bcast := NewBroadcastFromCtx(currentFrame.Ctx)
h.hooks.OnBroadcast(bcast)
h.log.Debug(
"called broadcast hook",
"from", bcast.From,
"to", bcast.To,
"calldata", hex.EncodeToString(bcast.Calldata),
"value", bcast.Value,
)
}

// While going back to the parent, restore the tx.origin.
// It will later be re-applied on sub-calls if the prank persists (if Repeat == true).
if parentCallFrame.Prank.Origin != nil {
Expand All @@ -372,7 +432,7 @@ func (h *Host) unwindCallstack(depth int) {
}
}
// Now pop the call-frame
h.callStack[len(h.callStack)-1] = CallFrame{} // don't hold on to the underlying call-frame resources
h.callStack[len(h.callStack)-1] = nil // don't hold on to the underlying call-frame resources
h.callStack = h.callStack[:len(h.callStack)-1]
}
}
Expand All @@ -384,7 +444,7 @@ func (h *Host) onOpcode(pc uint64, op byte, gas, cost uint64, scope tracing.OpCo
// Check if we are entering a new depth, add it to the call-stack if so.
// We do this here, instead of onEnter, to capture an initialized scope.
if len(h.callStack) == 0 || h.callStack[len(h.callStack)-1].Depth < depth {
h.callStack = append(h.callStack, CallFrame{
h.callStack = append(h.callStack, &CallFrame{
Depth: depth,
LastOp: vm.OpCode(op),
LastPC: pc,
Expand All @@ -395,7 +455,7 @@ func (h *Host) onOpcode(pc uint64, op byte, gas, cost uint64, scope tracing.OpCo
if len(h.callStack) == 0 || h.callStack[len(h.callStack)-1].Ctx != scopeCtx {
panic("scope context changed without call-frame pop/push")
}
cf := &h.callStack[len(h.callStack)-1]
cf := h.callStack[len(h.callStack)-1]
if vm.OpCode(op) == vm.JUMPDEST { // remember the last PC before successful jump
cf.LastJumps = append(cf.LastJumps, cf.LastPC)
if len(cf.LastJumps) > jumpHistory {
Expand Down Expand Up @@ -429,7 +489,7 @@ func (h *Host) CurrentCall() CallFrame {
if len(h.callStack) == 0 {
return CallFrame{}
}
return h.callStack[len(h.callStack)-1]
return *h.callStack[len(h.callStack)-1]
}

// MsgSender returns the msg.sender of the current active EVM call-frame,
Expand Down
59 changes: 59 additions & 0 deletions op-chain-ops/script/script_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package script

import (
"fmt"
"github.com/ethereum/go-ethereum/accounts/abi"
"github.com/ethereum/go-ethereum/common"
"strings"
"testing"

"github.com/holiman/uint256"
Expand Down Expand Up @@ -35,3 +39,58 @@ func TestScript(t *testing.T) {
// and a second time, to see if we can revisit the host state.
require.NoError(t, h.cheatcodes.Precompile.DumpState("noop"))
}

func TestScriptBroadcast(t *testing.T) {
logger := testlog.Logger(t, log.LevelDebug)
af := foundry.OpenArtifactsDir("./testdata/test-artifacts")

mustEncodeCalldata := func(method, input string) []byte {
packer, err := abi.JSON(strings.NewReader(fmt.Sprintf(`[{"type":"function","name":"%s","inputs":[{"type":"string","name":"input"}]}]`, method)))
require.NoError(t, err)

data, err := packer.Pack(method, input)
require.NoError(t, err)
return data
}

senderAddr := common.HexToAddress("0x5b73C5498c1E3b4dbA84de0F1833c4a029d90519")
expBroadcasts := []Broadcast{
{
From: senderAddr,
To: senderAddr,
Calldata: mustEncodeCalldata("call1", "single_call1"),
},
{
From: senderAddr,
To: senderAddr,
Calldata: mustEncodeCalldata("call1", "startstop_call1"),
},
{
From: senderAddr,
To: senderAddr,
Calldata: mustEncodeCalldata("call2", "startstop_call2"),
},
{
From: senderAddr,
To: senderAddr,
Calldata: mustEncodeCalldata("nested1", "nested"),
},
}

scriptContext := DefaultContext
var broadcasts []Broadcast
h := NewHostWithHooks(logger, af, nil, scriptContext, &Hooks{
OnBroadcast: func(broadcast Broadcast) {
broadcasts = append(broadcasts, broadcast)
},
})
addr, err := h.LoadContract("ScriptExample.s.sol", "ScriptExample")
require.NoError(t, err)

require.NoError(t, h.EnableCheats())

input := bytes4("runBroadcast()")
returnData, _, err := h.Call(scriptContext.Sender, addr, input[:], DefaultFoundryGasLimit, uint256.NewInt(0))
require.NoError(t, err, "call failed: %x", string(returnData))
require.EqualValues(t, expBroadcasts, broadcasts)
}
39 changes: 39 additions & 0 deletions op-chain-ops/script/testdata/scripts/ScriptExample.s.sol
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ interface Vm {
function parseJsonKeys(string calldata json, string calldata key) external pure returns (string[] memory keys);
function startPrank(address msgSender) external;
function stopPrank() external;
function broadcast() external;
function startBroadcast() external;
function stopBroadcast() external;
}

// console is a minimal version of the console2 lib.
Expand Down Expand Up @@ -90,9 +93,45 @@ contract ScriptExample {
console.log("done!");
}

/// @notice example function, to test vm.broadcast with.
function runBroadcast() public {
console.log("testing single");
vm.broadcast();
this.call1("single_call1");
this.call2("single_call2");

console.log("testing start/stop");
vm.startBroadcast();
this.call1("startstop_call1");
this.call2("startstop_call2");
vm.stopBroadcast();
this.call1("startstop_call3");

console.log("testing nested");
vm.startBroadcast();
this.nested1("nested");
vm.stopBroadcast();
}

/// @notice example external function, to force a CALL, and test vm.startPrank with.
function hello(string calldata _v) external view {
console.log(_v);
console.log("hello msg.sender", address(msg.sender));
}

function call1(string calldata _v) external pure {
console.log(_v);
}

function call2(string calldata _v) external pure {
console.log(_v);
}

function nested1(string calldata _v) external view {
this.nested2(_v);
}

function nested2(string calldata _v) external pure {
console.log(_v);
}
}

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

0 comments on commit b87c989

Please sign in to comment.