Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add broadcast API to Go forge scripts #11826

Merged
merged 4 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions op-chain-ops/script/prank.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func (h *Host) Prank(msgSender *common.Address, txOrigin *common.Address, repeat
h.log.Warn("no call stack")
return nil // cannot prank while not in a call.
}
cf := &h.callStack[len(h.callStack)-1]
cf := h.callStack[len(h.callStack)-1]
if cf.Prank != nil {
if cf.Prank.Broadcast && !broadcast {
return errors.New("you have an active broadcast; broadcasting and pranks are not compatible")
Expand Down Expand Up @@ -98,7 +98,7 @@ func (h *Host) StopPrank(broadcast bool) error {
if len(h.callStack) == 0 {
return nil
}
cf := &h.callStack[len(h.callStack)-1]
cf := h.callStack[len(h.callStack)-1]
if cf.Prank == nil {
if broadcast {
return errors.New("no broadcast in progress to stop")
Expand Down Expand Up @@ -127,7 +127,7 @@ func (h *Host) CallerMode() CallerMode {
if len(h.callStack) == 0 {
return CallerModeNone
}
cf := &h.callStack[len(h.callStack)-1]
cf := h.callStack[len(h.callStack)-1]
if cf.Prank != nil {
if cf.Prank.Broadcast {
if cf.Prank.Repeat {
Expand Down
74 changes: 67 additions & 7 deletions op-chain-ops/script/script.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package script
import (
"bytes"
"encoding/binary"
"encoding/hex"
"encoding/json"
"fmt"
"math/big"
Expand Down Expand Up @@ -54,6 +55,33 @@ type CallFrame struct {
Prank *Prank
}

type Broadcast struct {
mslipper marked this conversation as resolved.
Show resolved Hide resolved
From common.Address
To common.Address
Calldata []byte
Value *big.Int
}

var zero = big.NewInt(0)
mslipper marked this conversation as resolved.
Show resolved Hide resolved

func NewBroadcastFromCtx(ctx *vm.ScopeContext) Broadcast {
value := ctx.CallValue().ToBig()
if value.Cmp(zero) == 0 {
value = nil
}

callInput := ctx.CallInput()
calldata := make([]byte, len(callInput))
copy(calldata, callInput)

return Broadcast{
From: ctx.Caller(),
To: ctx.Address(),
Calldata: calldata,
Value: value,
}
}

// Host is an EVM executor that runs Forge scripts.
type Host struct {
log log.Logger
Expand All @@ -69,7 +97,7 @@ type Host struct {

precompiles map[common.Address]vm.PrecompiledContract

callStack []CallFrame
callStack []*CallFrame

// serializerStates are in-progress JSON payloads by name,
// for the serializeX family of cheat codes, see:
Expand All @@ -86,12 +114,26 @@ type Host struct {
srcMaps map[common.Address]*srcmap.SourceMap

onLabel []func(name string, addr common.Address)

hooks *Hooks
}

type Hooks struct {
OnBroadcast func(broadcast Broadcast)
}

var defaultHooks = &Hooks{
OnBroadcast: func(broadcast Broadcast) {},
}

func NewHost(logger log.Logger, fs *foundry.ArtifactsFS, srcFS *foundry.SourceMapFS, executionContext Context) *Host {
return NewHostWithHooks(logger, fs, srcFS, executionContext, nil)
}

// NewHost creates a Host that can load contracts from the given Artifacts FS,
// NewHostWithHooks creates a Host that can load contracts from the given Artifacts FS,
// and with an EVM initialized to the given executionContext.
// Optionally src-map loading may be enabled, by providing a non-nil srcFS to read sources from.
func NewHost(logger log.Logger, fs *foundry.ArtifactsFS, srcFS *foundry.SourceMapFS, executionContext Context) *Host {
func NewHostWithHooks(logger log.Logger, fs *foundry.ArtifactsFS, srcFS *foundry.SourceMapFS, executionContext Context, hooks *Hooks) *Host {
mslipper marked this conversation as resolved.
Show resolved Hide resolved
h := &Host{
log: logger,
af: fs,
Expand All @@ -101,6 +143,11 @@ func NewHost(logger log.Logger, fs *foundry.ArtifactsFS, srcFS *foundry.SourceMa
precompiles: make(map[common.Address]vm.PrecompiledContract),
srcFS: srcFS,
srcMaps: make(map[common.Address]*srcmap.SourceMap),
hooks: defaultHooks,
}

if hooks != nil {
h.hooks = hooks
}

// Init a default chain config, with all the mainnet L1 forks activated
Expand Down Expand Up @@ -361,6 +408,19 @@ func (h *Host) unwindCallstack(depth int) {
if len(h.callStack) > 1 {
parentCallFrame := h.callStack[len(h.callStack)-2]
if parentCallFrame.Prank != nil {
if parentCallFrame.Prank.Broadcast {
currentFrame := h.callStack[len(h.callStack)-1]
bcast := NewBroadcastFromCtx(currentFrame.Ctx)
mslipper marked this conversation as resolved.
Show resolved Hide resolved
h.hooks.OnBroadcast(bcast)
h.log.Debug(
"called broadcast hook",
"from", bcast.From,
"to", bcast.To,
"calldata", hex.EncodeToString(bcast.Calldata),
"value", bcast.Value,
)
}

// While going back to the parent, restore the tx.origin.
// It will later be re-applied on sub-calls if the prank persists (if Repeat == true).
if parentCallFrame.Prank.Origin != nil {
Expand All @@ -372,7 +432,7 @@ func (h *Host) unwindCallstack(depth int) {
}
}
// Now pop the call-frame
h.callStack[len(h.callStack)-1] = CallFrame{} // don't hold on to the underlying call-frame resources
h.callStack[len(h.callStack)-1] = nil // don't hold on to the underlying call-frame resources
h.callStack = h.callStack[:len(h.callStack)-1]
}
}
Expand All @@ -384,7 +444,7 @@ func (h *Host) onOpcode(pc uint64, op byte, gas, cost uint64, scope tracing.OpCo
// Check if we are entering a new depth, add it to the call-stack if so.
// We do this here, instead of onEnter, to capture an initialized scope.
if len(h.callStack) == 0 || h.callStack[len(h.callStack)-1].Depth < depth {
h.callStack = append(h.callStack, CallFrame{
h.callStack = append(h.callStack, &CallFrame{
Depth: depth,
LastOp: vm.OpCode(op),
LastPC: pc,
Expand All @@ -395,7 +455,7 @@ func (h *Host) onOpcode(pc uint64, op byte, gas, cost uint64, scope tracing.OpCo
if len(h.callStack) == 0 || h.callStack[len(h.callStack)-1].Ctx != scopeCtx {
panic("scope context changed without call-frame pop/push")
}
cf := &h.callStack[len(h.callStack)-1]
cf := h.callStack[len(h.callStack)-1]
if vm.OpCode(op) == vm.JUMPDEST { // remember the last PC before successful jump
cf.LastJumps = append(cf.LastJumps, cf.LastPC)
if len(cf.LastJumps) > jumpHistory {
Expand Down Expand Up @@ -429,7 +489,7 @@ func (h *Host) CurrentCall() CallFrame {
if len(h.callStack) == 0 {
return CallFrame{}
}
return h.callStack[len(h.callStack)-1]
return *h.callStack[len(h.callStack)-1]
}

// MsgSender returns the msg.sender of the current active EVM call-frame,
Expand Down
60 changes: 60 additions & 0 deletions op-chain-ops/script/script_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
package script

import (
"fmt"
"strings"
"testing"

"github.com/ethereum/go-ethereum/accounts/abi"
"github.com/ethereum/go-ethereum/common"

"github.com/holiman/uint256"
"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -35,3 +40,58 @@ func TestScript(t *testing.T) {
// and a second time, to see if we can revisit the host state.
require.NoError(t, h.cheatcodes.Precompile.DumpState("noop"))
}

func TestScriptBroadcast(t *testing.T) {
logger := testlog.Logger(t, log.LevelDebug)
af := foundry.OpenArtifactsDir("./testdata/test-artifacts")

mustEncodeCalldata := func(method, input string) []byte {
packer, err := abi.JSON(strings.NewReader(fmt.Sprintf(`[{"type":"function","name":"%s","inputs":[{"type":"string","name":"input"}]}]`, method)))
require.NoError(t, err)

data, err := packer.Pack(method, input)
require.NoError(t, err)
return data
}

senderAddr := common.HexToAddress("0x5b73C5498c1E3b4dbA84de0F1833c4a029d90519")
expBroadcasts := []Broadcast{
{
From: senderAddr,
To: senderAddr,
Calldata: mustEncodeCalldata("call1", "single_call1"),
},
{
From: senderAddr,
To: senderAddr,
Calldata: mustEncodeCalldata("call1", "startstop_call1"),
},
{
From: senderAddr,
To: senderAddr,
Calldata: mustEncodeCalldata("call2", "startstop_call2"),
},
{
From: senderAddr,
To: senderAddr,
Calldata: mustEncodeCalldata("nested1", "nested"),
},
}

scriptContext := DefaultContext
var broadcasts []Broadcast
h := NewHostWithHooks(logger, af, nil, scriptContext, &Hooks{
OnBroadcast: func(broadcast Broadcast) {
broadcasts = append(broadcasts, broadcast)
},
})
addr, err := h.LoadContract("ScriptExample.s.sol", "ScriptExample")
require.NoError(t, err)

require.NoError(t, h.EnableCheats())

input := bytes4("runBroadcast()")
returnData, _, err := h.Call(scriptContext.Sender, addr, input[:], DefaultFoundryGasLimit, uint256.NewInt(0))
require.NoError(t, err, "call failed: %x", string(returnData))
require.EqualValues(t, expBroadcasts, broadcasts)
}
39 changes: 39 additions & 0 deletions op-chain-ops/script/testdata/scripts/ScriptExample.s.sol
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ interface Vm {
function parseJsonKeys(string calldata json, string calldata key) external pure returns (string[] memory keys);
function startPrank(address msgSender) external;
function stopPrank() external;
function broadcast() external;
function startBroadcast() external;
function stopBroadcast() external;
}

// console is a minimal version of the console2 lib.
Expand Down Expand Up @@ -90,9 +93,45 @@ contract ScriptExample {
console.log("done!");
}

/// @notice example function, to test vm.broadcast with.
function runBroadcast() public {
console.log("testing single");
vm.broadcast();
this.call1("single_call1");
this.call2("single_call2");

console.log("testing start/stop");
vm.startBroadcast();
this.call1("startstop_call1");
this.call2("startstop_call2");
vm.stopBroadcast();
this.call1("startstop_call3");

console.log("testing nested");
vm.startBroadcast();
this.nested1("nested");
vm.stopBroadcast();
}

/// @notice example external function, to force a CALL, and test vm.startPrank with.
function hello(string calldata _v) external view {
console.log(_v);
console.log("hello msg.sender", address(msg.sender));
}

function call1(string calldata _v) external pure {
console.log(_v);
}

function call2(string calldata _v) external pure {
console.log(_v);
}

function nested1(string calldata _v) external view {
this.nested2(_v);
}

function nested2(string calldata _v) external pure {
console.log(_v);
}
}

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.