From 9d0653a12303841b82c7e0e82f593bcb5089e7ff Mon Sep 17 00:00:00 2001 From: pr0n00gler Date: Mon, 26 Feb 2024 14:58:23 +0200 Subject: [PATCH] bank hooks --- x/bank/app_test.go | 121 ++++++++++++++++++++++ x/bank/keeper/hooks.go | 26 +++++ x/bank/keeper/internal_unsafe.go | 11 ++ x/bank/keeper/keeper.go | 21 +++- x/bank/keeper/send.go | 34 +++++- x/bank/testutil/expected_keepers_mocks.go | 49 +++++++++ x/bank/types/expected_keepers.go | 12 +++ x/bank/types/hooks.go | 33 ++++++ 8 files changed, 303 insertions(+), 4 deletions(-) create mode 100644 x/bank/keeper/hooks.go create mode 100644 x/bank/keeper/internal_unsafe.go create mode 100644 x/bank/types/hooks.go diff --git a/x/bank/app_test.go b/x/bank/app_test.go index 2099b5981f83..67746b96f19b 100644 --- a/x/bank/app_test.go +++ b/x/bank/app_test.go @@ -1,10 +1,13 @@ package bank_test import ( + "context" + "fmt" "testing" abci "github.com/cometbft/cometbft/abci/types" cmtproto "github.com/cometbft/cometbft/proto/tendermint/types" + stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -447,3 +450,121 @@ func TestMsgSetSendEnabled(t *testing.T) { }) } } + +var _ types.BankHooks = &MockBankHooksReceiver{} + +// BankHooks event hooks for bank (noalias) +type MockBankHooksReceiver struct{} + +// Mock BlockBeforeSend bank hook that doesn't allow the sending of exactly 100 coins of any denom. +func (h *MockBankHooksReceiver) BlockBeforeSend(ctx context.Context, from, to sdk.AccAddress, amount sdk.Coins) error { + for _, coin := range amount { + if coin.Amount.Equal(sdkmath.NewInt(100)) { + return fmt.Errorf("not allowed; expected %v, got: %v", 100, coin.Amount) + } + } + return nil +} + +// variable for counting `TrackBeforeSend` +var ( + countTrackBeforeSend = 0 + expNextCount = 1 +) + +// Mock TrackBeforeSend bank hook that simply tracks the sending of exactly 50 coins of any denom. +func (h *MockBankHooksReceiver) TrackBeforeSend(ctx context.Context, from, to sdk.AccAddress, amount sdk.Coins) { + for _, coin := range amount { + if coin.Amount.Equal(sdkmath.NewInt(50)) { + countTrackBeforeSend += 1 + } + } +} + +func TestHooks(t *testing.T) { + acc1 := authtypes.NewBaseAccountWithAddress(addr1) + + genAccs := []authtypes.GenesisAccount{acc1} + s := createTestSuite(t, genAccs) + + ctx := s.App.BaseApp.NewContext(false) + + require.NoError(t, testutil.FundAccount(ctx, s.BankKeeper, addr1, sdk.NewCoins(sdk.NewCoin("stake", sdkmath.NewInt(1000))))) + require.NoError(t, testutil.FundModuleAccount(ctx, s.BankKeeper, stakingtypes.BondedPoolName, sdk.NewCoins(sdk.NewCoin("stake", sdkmath.NewInt(1000))))) + + // create a valid send amount which is 1 coin, and an invalidSendAmount which is 100 coins + validSendAmount := sdk.NewCoins(sdk.NewCoin("stake", sdkmath.NewInt(1))) + triggerTrackSendAmount := sdk.NewCoins(sdk.NewCoin("stake", sdkmath.NewInt(50))) + invalidBlockSendAmount := sdk.NewCoins(sdk.NewCoin("stake", sdkmath.NewInt(100))) + + // setup our mock bank hooks receiver that prevents the send of 100 coins + bankHooksReceiver := MockBankHooksReceiver{} + baseBankKeeper, ok := s.BankKeeper.(bankkeeper.BaseKeeper) + require.True(t, ok) + bankkeeper.UnsafeSetHooks( + &baseBankKeeper, types.NewMultiBankHooks(&bankHooksReceiver), + ) + s.BankKeeper = baseBankKeeper + + // try sending a validSendAmount and it should work + err := s.BankKeeper.SendCoins(ctx, addr1, addr2, validSendAmount) + require.NoError(t, err) + + // try sending an trigger track send amount and it should work + err = s.BankKeeper.SendCoins(ctx, addr1, addr2, triggerTrackSendAmount) + require.NoError(t, err) + + require.Equal(t, countTrackBeforeSend, expNextCount) + expNextCount++ + + // try sending an invalidSendAmount and it should not work + err = s.BankKeeper.SendCoins(ctx, addr1, addr2, invalidBlockSendAmount) + require.Error(t, err) + + // make sure that account to module doesn't bypass hook + err = s.BankKeeper.SendCoinsFromAccountToModule(ctx, addr1, stakingtypes.BondedPoolName, validSendAmount) + require.NoError(t, err) + err = s.BankKeeper.SendCoinsFromAccountToModule(ctx, addr1, stakingtypes.BondedPoolName, invalidBlockSendAmount) + require.Error(t, err) + err = s.BankKeeper.SendCoinsFromAccountToModule(ctx, addr1, stakingtypes.BondedPoolName, triggerTrackSendAmount) + require.Equal(t, countTrackBeforeSend, expNextCount) + expNextCount++ + + // make sure that module to account doesn't bypass hook + err = s.BankKeeper.SendCoinsFromModuleToAccount(ctx, stakingtypes.BondedPoolName, addr1, validSendAmount) + require.NoError(t, err) + err = s.BankKeeper.SendCoinsFromModuleToAccount(ctx, stakingtypes.BondedPoolName, addr1, invalidBlockSendAmount) + require.Error(t, err) + err = s.BankKeeper.SendCoinsFromModuleToAccount(ctx, stakingtypes.BondedPoolName, addr1, triggerTrackSendAmount) + require.Equal(t, countTrackBeforeSend, expNextCount) + expNextCount++ + + // make sure that module to module doesn't bypass hook + err = s.BankKeeper.SendCoinsFromModuleToModule(ctx, stakingtypes.BondedPoolName, stakingtypes.NotBondedPoolName, validSendAmount) + require.NoError(t, err) + err = s.BankKeeper.SendCoinsFromModuleToModule(ctx, stakingtypes.BondedPoolName, stakingtypes.NotBondedPoolName, invalidBlockSendAmount) + // there should be no error since module to module does not call block before send hooks + require.NoError(t, err) + err = s.BankKeeper.SendCoinsFromModuleToModule(ctx, stakingtypes.BondedPoolName, stakingtypes.NotBondedPoolName, triggerTrackSendAmount) + require.Equal(t, countTrackBeforeSend, expNextCount) + expNextCount++ + + // make sure that DelegateCoins doesn't bypass the hook + err = s.BankKeeper.DelegateCoins(ctx, addr1, s.AccountKeeper.GetModuleAddress(stakingtypes.BondedPoolName), validSendAmount) + require.NoError(t, err) + err = s.BankKeeper.DelegateCoins(ctx, addr1, s.AccountKeeper.GetModuleAddress(stakingtypes.BondedPoolName), invalidBlockSendAmount) + require.Error(t, err) + err = s.BankKeeper.DelegateCoins(ctx, addr1, s.AccountKeeper.GetModuleAddress(stakingtypes.BondedPoolName), triggerTrackSendAmount) + require.Equal(t, countTrackBeforeSend, expNextCount) + expNextCount++ + + // make sure that UndelegateCoins doesn't bypass the hook + err = s.BankKeeper.UndelegateCoins(ctx, s.AccountKeeper.GetModuleAddress(stakingtypes.BondedPoolName), addr1, validSendAmount) + require.NoError(t, err) + err = s.BankKeeper.UndelegateCoins(ctx, s.AccountKeeper.GetModuleAddress(stakingtypes.BondedPoolName), addr1, invalidBlockSendAmount) + require.Error(t, err) + + err = s.BankKeeper.UndelegateCoins(ctx, s.AccountKeeper.GetModuleAddress(stakingtypes.BondedPoolName), addr1, triggerTrackSendAmount) + require.Equal(t, countTrackBeforeSend, expNextCount) + expNextCount++ +} diff --git a/x/bank/keeper/hooks.go b/x/bank/keeper/hooks.go new file mode 100644 index 000000000000..00f1d7ef9c84 --- /dev/null +++ b/x/bank/keeper/hooks.go @@ -0,0 +1,26 @@ +package keeper + +import ( + "context" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/x/bank/types" +) + +// Implements StakingHooks interface +var _ types.BankHooks = BaseSendKeeper{} + +// TrackBeforeSend executes the TrackBeforeSend hook if registered. +func (k BaseSendKeeper) TrackBeforeSend(ctx context.Context, from, to sdk.AccAddress, amount sdk.Coins) { + if k.hooks != nil { + k.hooks.TrackBeforeSend(ctx, from, to, amount) + } +} + +// BlockBeforeSend executes the BlockBeforeSend hook if registered. +func (k BaseSendKeeper) BlockBeforeSend(ctx context.Context, from, to sdk.AccAddress, amount sdk.Coins) error { + if k.hooks != nil { + return k.hooks.BlockBeforeSend(ctx, from, to, amount) + } + return nil +} diff --git a/x/bank/keeper/internal_unsafe.go b/x/bank/keeper/internal_unsafe.go new file mode 100644 index 000000000000..3408248f1a98 --- /dev/null +++ b/x/bank/keeper/internal_unsafe.go @@ -0,0 +1,11 @@ +package keeper + +import "github.com/cosmos/cosmos-sdk/x/bank/types" + +// UnsafeSetHooks updates the x/bank keeper's hooks, overriding any potential +// pre-existing hooks. +// +// WARNING: this function should only be used in tests. +func UnsafeSetHooks(k *BaseKeeper, h types.BankHooks) { + k.hooks = h +} diff --git a/x/bank/keeper/keeper.go b/x/bank/keeper/keeper.go index bfa45d23f64e..431b33b5d0cf 100644 --- a/x/bank/keeper/keeper.go +++ b/x/bank/keeper/keeper.go @@ -131,6 +131,13 @@ func (k BaseKeeper) DelegateCoins(ctx context.Context, delegatorAddr, moduleAccA return errorsmod.Wrap(sdkerrors.ErrInvalidCoins, amt.String()) } + err := k.BlockBeforeSend(ctx, delegatorAddr, moduleAccAddr, amt) + if err != nil { + return err + } + // call the TrackBeforeSend hooks and the BlockBeforeSend hooks + k.TrackBeforeSend(ctx, delegatorAddr, moduleAccAddr, amt) + balances := sdk.NewCoins() for _, coin := range amt { @@ -157,7 +164,7 @@ func (k BaseKeeper) DelegateCoins(ctx context.Context, delegatorAddr, moduleAccA types.NewCoinSpentEvent(delegatorAddr, amt), ) - err := k.addCoins(ctx, moduleAccAddr, amt) + err = k.addCoins(ctx, moduleAccAddr, amt) if err != nil { return err } @@ -180,7 +187,15 @@ func (k BaseKeeper) UndelegateCoins(ctx context.Context, moduleAccAddr, delegato return errorsmod.Wrap(sdkerrors.ErrInvalidCoins, amt.String()) } - err := k.subUnlockedCoins(ctx, moduleAccAddr, amt) + // call the TrackBeforeSend hooks and the BlockBeforeSend hooks + err := k.BlockBeforeSend(ctx, moduleAccAddr, delegatorAddr, amt) + if err != nil { + return err + } + + k.TrackBeforeSend(ctx, moduleAccAddr, delegatorAddr, amt) + + err = k.subUnlockedCoins(ctx, moduleAccAddr, amt) if err != nil { return err } @@ -286,7 +301,7 @@ func (k BaseKeeper) SendCoinsFromModuleToModule( panic(errorsmod.Wrapf(sdkerrors.ErrUnknownAddress, "module account %s does not exist", recipientModule)) } - return k.SendCoins(ctx, senderAddr, recipientAcc.GetAddress(), amt) + return k.SendCoinsWithoutBlockHook(ctx, senderAddr, recipientAcc.GetAddress(), amt) } // SendCoinsFromAccountToModule transfers coins from an AccAddress to a ModuleAccount. diff --git a/x/bank/keeper/send.go b/x/bank/keeper/send.go index 3bcc6d315c33..edf05ef50b6f 100644 --- a/x/bank/keeper/send.go +++ b/x/bank/keeper/send.go @@ -60,6 +60,7 @@ type BaseSendKeeper struct { ak types.AccountKeeper storeService store.KVStoreService logger log.Logger + hooks types.BankHooks // list of addresses that are restricted from receiving transactions blockedAddrs map[string]bool @@ -95,6 +96,17 @@ func NewBaseSendKeeper( } } +// SetHooks Set the bank hooks +func (k BaseSendKeeper) SetHooks(bh types.BankHooks) BaseSendKeeper { + if k.hooks != nil { + panic("cannot set bank hooks twice") + } + + k.hooks = bh + + return k +} + // AppendSendRestriction adds the provided SendRestrictionFn to run after previously provided restrictions. func (k BaseSendKeeper) AppendSendRestriction(restriction types.SendRestrictionFn) { k.sendRestriction.append(restriction) @@ -203,9 +215,29 @@ func (k BaseSendKeeper) InputOutputCoins(ctx context.Context, input types.Input, return nil } +// SendCoinsWithoutBlockHook calls sendCoins without calling the `BlockBeforeSend` hook. +func (k BaseSendKeeper) SendCoinsWithoutBlockHook(ctx context.Context, fromAddr sdk.AccAddress, toAddr sdk.AccAddress, amt sdk.Coins) error { + return k.sendCoins(ctx, fromAddr, toAddr, amt) +} + // SendCoins transfers amt coins from a sending account to a receiving account. // An error is returned upon failure. -func (k BaseSendKeeper) SendCoins(ctx context.Context, fromAddr, toAddr sdk.AccAddress, amt sdk.Coins) error { +func (k BaseSendKeeper) SendCoins(ctx context.Context, fromAddr sdk.AccAddress, toAddr sdk.AccAddress, amt sdk.Coins) error { + // BlockBeforeSend hook should always be called before the TrackBeforeSend hook. + err := k.BlockBeforeSend(ctx, fromAddr, toAddr, amt) + if err != nil { + return err + } + + return k.sendCoins(ctx, fromAddr, toAddr, amt) +} + +// sendCoins transfers amt coins from a sending account to a receiving account. +// An error is returned upon failure. +func (k BaseSendKeeper) sendCoins(ctx context.Context, fromAddr, toAddr sdk.AccAddress, amt sdk.Coins) error { + // call the TrackBeforeSend hooks + k.TrackBeforeSend(ctx, fromAddr, toAddr, amt) + var err error err = k.subUnlockedCoins(ctx, fromAddr, amt) if err != nil { diff --git a/x/bank/testutil/expected_keepers_mocks.go b/x/bank/testutil/expected_keepers_mocks.go index fcdfd8a472e6..b8df326a09b6 100644 --- a/x/bank/testutil/expected_keepers_mocks.go +++ b/x/bank/testutil/expected_keepers_mocks.go @@ -242,3 +242,52 @@ func (mr *MockAccountKeeperMockRecorder) ValidatePermissions(macc interface{}) * mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidatePermissions", reflect.TypeOf((*MockAccountKeeper)(nil).ValidatePermissions), macc) } + +// MockBankHooks is a mock of BankHooks interface. +type MockBankHooks struct { + ctrl *gomock.Controller + recorder *MockBankHooksMockRecorder +} + +// MockBankHooksMockRecorder is the mock recorder for MockBankHooks. +type MockBankHooksMockRecorder struct { + mock *MockBankHooks +} + +// NewMockBankHooks creates a new mock instance. +func NewMockBankHooks(ctrl *gomock.Controller) *MockBankHooks { + mock := &MockBankHooks{ctrl: ctrl} + mock.recorder = &MockBankHooksMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockBankHooks) EXPECT() *MockBankHooksMockRecorder { + return m.recorder +} + +// BlockBeforeSend mocks base method. +func (m *MockBankHooks) BlockBeforeSend(ctx context.Context, from, to types.AccAddress, amount types.Coins) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BlockBeforeSend", ctx, from, to, amount) + ret0, _ := ret[0].(error) + return ret0 +} + +// BlockBeforeSend indicates an expected call of BlockBeforeSend. +func (mr *MockBankHooksMockRecorder) BlockBeforeSend(ctx, from, to, amount interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BlockBeforeSend", reflect.TypeOf((*MockBankHooks)(nil).BlockBeforeSend), ctx, from, to, amount) +} + +// TrackBeforeSend mocks base method. +func (m *MockBankHooks) TrackBeforeSend(ctx context.Context, from, to types.AccAddress, amount types.Coins) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "TrackBeforeSend", ctx, from, to, amount) +} + +// TrackBeforeSend indicates an expected call of TrackBeforeSend. +func (mr *MockBankHooksMockRecorder) TrackBeforeSend(ctx, from, to, amount interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TrackBeforeSend", reflect.TypeOf((*MockBankHooks)(nil).TrackBeforeSend), ctx, from, to, amount) +} diff --git a/x/bank/types/expected_keepers.go b/x/bank/types/expected_keepers.go index ebd5eb9a657e..ef251be18aeb 100644 --- a/x/bank/types/expected_keepers.go +++ b/x/bank/types/expected_keepers.go @@ -33,3 +33,15 @@ type AccountKeeper interface { SetModuleAccount(ctx context.Context, macc sdk.ModuleAccountI) GetModulePermissions() map[string]types.PermissionsForAddress } + +// Event Hooks +// These can be utilized to communicate between a bank keeper and another +// keeper which must take particular actions when sends happen. +// The second keeper must implement this interface, which then the +// bank keeper can call. + +// BankHooks event hooks for bank sends +type BankHooks interface { + TrackBeforeSend(ctx context.Context, from, to sdk.AccAddress, amount sdk.Coins) // Must be before any send is executed + BlockBeforeSend(ctx context.Context, from, to sdk.AccAddress, amount sdk.Coins) error // Must be before any send is executed +} diff --git a/x/bank/types/hooks.go b/x/bank/types/hooks.go new file mode 100644 index 000000000000..c58171c5e9f9 --- /dev/null +++ b/x/bank/types/hooks.go @@ -0,0 +1,33 @@ +package types + +import ( + "context" + + sdk "github.com/cosmos/cosmos-sdk/types" +) + +// MultiBankHooks combine multiple bank hooks, all hook functions are run in array sequence +type MultiBankHooks []BankHooks + +// NewMultiBankHooks takes a list of BankHooks and returns a MultiBankHooks +func NewMultiBankHooks(hooks ...BankHooks) MultiBankHooks { + return hooks +} + +// TrackBeforeSend runs the TrackBeforeSend hooks in order for each BankHook in a MultiBankHooks struct +func (h MultiBankHooks) TrackBeforeSend(ctx context.Context, from, to sdk.AccAddress, amount sdk.Coins) { + for i := range h { + h[i].TrackBeforeSend(ctx, from, to, amount) + } +} + +// BlockBeforeSend runs the BlockBeforeSend hooks in order for each BankHook in a MultiBankHooks struct +func (h MultiBankHooks) BlockBeforeSend(ctx context.Context, from, to sdk.AccAddress, amount sdk.Coins) error { + for i := range h { + err := h[i].BlockBeforeSend(ctx, from, to, amount) + if err != nil { + return err + } + } + return nil +}