diff --git a/x/bank/keeper/hooks.go b/x/bank/keeper/hooks.go index c0006c94da93..00f1d7ef9c84 100644 --- a/x/bank/keeper/hooks.go +++ b/x/bank/keeper/hooks.go @@ -10,10 +10,17 @@ import ( // Implements StakingHooks interface var _ types.BankHooks = BaseSendKeeper{} -// BeforeSend executes the BeforeSend hook if registered. -func (k BaseSendKeeper) BeforeSend(ctx context.Context, from, to sdk.AccAddress, amount sdk.Coins) error { +// TrackBeforeSend executes the TrackBeforeSend hook if registered. +func (k BaseSendKeeper) TrackBeforeSend(ctx context.Context, from, to sdk.AccAddress, amount sdk.Coins) { if k.hooks != nil { - return k.hooks.BeforeSend(ctx, from, to, amount) + k.hooks.TrackBeforeSend(ctx, from, to, amount) + } +} + +// BlockBeforeSend executes the BlockBeforeSend hook if registered. +func (k BaseSendKeeper) BlockBeforeSend(ctx context.Context, from, to sdk.AccAddress, amount sdk.Coins) error { + if k.hooks != nil { + return k.hooks.BlockBeforeSend(ctx, from, to, amount) } return nil } diff --git a/x/bank/keeper/hooks_test.go b/x/bank/keeper/hooks_test.go index 25a17c5d30e8..3164ef603744 100644 --- a/x/bank/keeper/hooks_test.go +++ b/x/bank/keeper/hooks_test.go @@ -77,16 +77,30 @@ func createTestSuite(t *testing.T, genesisAccounts []authtypes.GenesisAccount) t type MockBankHooksReceiver struct{} // Mock BeforeSend bank hook that doesn't allow the sending of exactly 100 coins of any denom. -func (h *MockBankHooksReceiver) BeforeSend(ctx context.Context, from, to sdk.AccAddress, amount sdk.Coins) error { +func (h *MockBankHooksReceiver) BlockBeforeSend(ctx context.Context, from, to sdk.AccAddress, amount sdk.Coins) error { for _, coin := range amount { if coin.Amount.Equal(math.NewInt(100)) { return fmt.Errorf("not allowed; expected %v, got: %v", 100, coin.Amount) } } - return nil } +// variable for counting `TrackBeforeSend` +var ( + countTrackBeforeSend = 0 + expNextCount = 1 +) + +// Mock TrackBeforeSend bank hook that simply tracks the sending of exactly 50 coins of any denom. +func (h *MockBankHooksReceiver) TrackBeforeSend(ctx context.Context, from, to sdk.AccAddress, amount sdk.Coins) { + for _, coin := range amount { + if coin.Amount.Equal(math.NewInt(50)) { + countTrackBeforeSend += 1 + } + } +} + func TestHooks(t *testing.T) { acc := &authtypes.BaseAccount{ Address: addr1.String(), @@ -104,7 +118,8 @@ func TestHooks(t *testing.T) { // create a valid send amount which is 1 coin, and an invalidSendAmount which is 100 coins validSendAmount := sdk.NewCoins(sdk.NewCoin(bondDenom, math.NewInt(1))) - invalidSendAmount := sdk.NewCoins(sdk.NewCoin(bondDenom, math.NewInt(100))) + triggerTrackSendAmount := sdk.NewCoins(sdk.NewCoin(bondDenom, math.NewInt(50))) + invalidBlockSendAmount := sdk.NewCoins(sdk.NewCoin(bondDenom, math.NewInt(100))) // setup our mock bank hooks receiver that prevents the send of 100 coins bankHooksReceiver := MockBankHooksReceiver{} @@ -119,47 +134,78 @@ func TestHooks(t *testing.T) { err = app.BankKeeper.SendCoins(ctx, addrs[0], addrs[1], validSendAmount) require.NoError(t, err) + // try sending an trigger track send amount and it should work + err = app.BankKeeper.SendCoins(ctx, addrs[0], addrs[1], triggerTrackSendAmount) + require.NoError(t, err) + + require.Equal(t, countTrackBeforeSend, expNextCount) + expNextCount++ + // try sending an invalidSendAmount and it should not work - err = app.BankKeeper.SendCoins(ctx, addrs[0], addrs[1], invalidSendAmount) + err = app.BankKeeper.SendCoins(ctx, addrs[0], addrs[1], invalidBlockSendAmount) require.Error(t, err) // try doing SendManyCoins and make sure if even a single subsend is invalid, the entire function fails - err = app.BankKeeper.SendManyCoins(ctx, addrs[0], []sdk.AccAddress{addrs[0], addrs[1]}, []sdk.Coins{invalidSendAmount, validSendAmount}) + err = app.BankKeeper.SendManyCoins(ctx, addrs[0], []sdk.AccAddress{addrs[0], addrs[1]}, []sdk.Coins{invalidBlockSendAmount, validSendAmount}) require.Error(t, err) + err = app.BankKeeper.SendManyCoins(ctx, addrs[0], []sdk.AccAddress{addrs[0], addrs[1]}, []sdk.Coins{triggerTrackSendAmount, validSendAmount}) + require.Equal(t, countTrackBeforeSend, expNextCount) + expNextCount++ + // make sure that account to module doesn't bypass hook err = app.BankKeeper.SendCoinsFromAccountToModule(ctx, addrs[0], stakingtypes.BondedPoolName, validSendAmount) require.NoError(t, err) - err = app.BankKeeper.SendCoinsFromAccountToModule(ctx, addrs[0], stakingtypes.BondedPoolName, invalidSendAmount) + err = app.BankKeeper.SendCoinsFromAccountToModule(ctx, addrs[0], stakingtypes.BondedPoolName, invalidBlockSendAmount) require.Error(t, err) + err = app.BankKeeper.SendCoinsFromAccountToModule(ctx, addrs[0], stakingtypes.BondedPoolName, triggerTrackSendAmount) + require.Equal(t, countTrackBeforeSend, expNextCount) + expNextCount++ // make sure that module to account doesn't bypass hook err = app.BankKeeper.SendCoinsFromModuleToAccount(ctx, stakingtypes.BondedPoolName, addrs[0], validSendAmount) require.NoError(t, err) - err = app.BankKeeper.SendCoinsFromModuleToAccount(ctx, stakingtypes.BondedPoolName, addrs[0], invalidSendAmount) + err = app.BankKeeper.SendCoinsFromModuleToAccount(ctx, stakingtypes.BondedPoolName, addrs[0], invalidBlockSendAmount) require.Error(t, err) + err = app.BankKeeper.SendCoinsFromModuleToAccount(ctx, stakingtypes.BondedPoolName, addrs[0], triggerTrackSendAmount) + require.Equal(t, countTrackBeforeSend, expNextCount) + expNextCount++ // make sure that module to module doesn't bypass hook err = app.BankKeeper.SendCoinsFromModuleToModule(ctx, stakingtypes.BondedPoolName, stakingtypes.NotBondedPoolName, validSendAmount) require.NoError(t, err) - err = app.BankKeeper.SendCoinsFromModuleToModule(ctx, stakingtypes.BondedPoolName, stakingtypes.NotBondedPoolName, invalidSendAmount) - require.Error(t, err) + err = app.BankKeeper.SendCoinsFromModuleToModule(ctx, stakingtypes.BondedPoolName, stakingtypes.NotBondedPoolName, invalidBlockSendAmount) + // there should be no error since module to module does not call block before send hooks + require.NoError(t, err) + err = app.BankKeeper.SendCoinsFromModuleToModule(ctx, stakingtypes.BondedPoolName, stakingtypes.NotBondedPoolName, triggerTrackSendAmount) + require.Equal(t, countTrackBeforeSend, expNextCount) + expNextCount++ // make sure that module to many accounts doesn't bypass hook err = app.BankKeeper.SendCoinsFromModuleToManyAccounts(ctx, stakingtypes.BondedPoolName, []sdk.AccAddress{addrs[0], addrs[1]}, []sdk.Coins{validSendAmount, validSendAmount}) require.NoError(t, err) - err = app.BankKeeper.SendCoinsFromModuleToManyAccounts(ctx, stakingtypes.BondedPoolName, []sdk.AccAddress{addrs[0], addrs[1]}, []sdk.Coins{validSendAmount, invalidSendAmount}) + err = app.BankKeeper.SendCoinsFromModuleToManyAccounts(ctx, stakingtypes.BondedPoolName, []sdk.AccAddress{addrs[0], addrs[1]}, []sdk.Coins{validSendAmount, invalidBlockSendAmount}) require.Error(t, err) + err = app.BankKeeper.SendCoinsFromModuleToManyAccounts(ctx, stakingtypes.BondedPoolName, []sdk.AccAddress{addrs[0], addrs[1]}, []sdk.Coins{validSendAmount, triggerTrackSendAmount}) + require.Equal(t, countTrackBeforeSend, expNextCount) + expNextCount++ // make sure that DelegateCoins doesn't bypass the hook err = app.BankKeeper.DelegateCoins(ctx, addrs[0], app.AccountKeeper.GetModuleAddress(stakingtypes.BondedPoolName), validSendAmount) require.NoError(t, err) - err = app.BankKeeper.DelegateCoins(ctx, addrs[0], app.AccountKeeper.GetModuleAddress(stakingtypes.BondedPoolName), invalidSendAmount) + err = app.BankKeeper.DelegateCoins(ctx, addrs[0], app.AccountKeeper.GetModuleAddress(stakingtypes.BondedPoolName), invalidBlockSendAmount) require.Error(t, err) + err = app.BankKeeper.DelegateCoins(ctx, addrs[0], app.AccountKeeper.GetModuleAddress(stakingtypes.BondedPoolName), triggerTrackSendAmount) + require.Equal(t, countTrackBeforeSend, expNextCount) + expNextCount++ // make sure that UndelegateCoins doesn't bypass the hook err = app.BankKeeper.UndelegateCoins(ctx, app.AccountKeeper.GetModuleAddress(stakingtypes.BondedPoolName), addrs[0], validSendAmount) require.NoError(t, err) - err = app.BankKeeper.UndelegateCoins(ctx, app.AccountKeeper.GetModuleAddress(stakingtypes.BondedPoolName), addrs[0], invalidSendAmount) + err = app.BankKeeper.UndelegateCoins(ctx, app.AccountKeeper.GetModuleAddress(stakingtypes.BondedPoolName), addrs[0], invalidBlockSendAmount) require.Error(t, err) + + err = app.BankKeeper.UndelegateCoins(ctx, app.AccountKeeper.GetModuleAddress(stakingtypes.BondedPoolName), addrs[0], triggerTrackSendAmount) + require.Equal(t, countTrackBeforeSend, expNextCount) + expNextCount++ } diff --git a/x/bank/keeper/keeper.go b/x/bank/keeper/keeper.go index 45d6dec04c79..701a5fdfbd67 100644 --- a/x/bank/keeper/keeper.go +++ b/x/bank/keeper/keeper.go @@ -146,11 +146,12 @@ func (k BaseKeeper) DelegateCoins(ctx context.Context, delegatorAddr, moduleAccA return errorsmod.Wrap(sdkerrors.ErrInvalidCoins, amt.String()) } - // call the BeforeSend hooks - err := k.BeforeSend(ctx, delegatorAddr, moduleAccAddr, amt) + err := k.BlockBeforeSend(ctx, delegatorAddr, moduleAccAddr, amt) if err != nil { return err } + // call the TrackBeforeSend hooks and the BlockBeforeSend hooks + k.TrackBeforeSend(ctx, delegatorAddr, moduleAccAddr, amt) balances := sdk.NewCoins() @@ -201,12 +202,14 @@ func (k BaseKeeper) UndelegateCoins(ctx context.Context, moduleAccAddr, delegato return errorsmod.Wrap(sdkerrors.ErrInvalidCoins, amt.String()) } - // call the BeforeSend hooks - err := k.BeforeSend(ctx, moduleAccAddr, delegatorAddr, amt) + // call the TrackBeforeSend hooks and the BlockBeforeSend hooks + err := k.BlockBeforeSend(ctx, moduleAccAddr, delegatorAddr, amt) if err != nil { return err } + k.TrackBeforeSend(ctx, moduleAccAddr, delegatorAddr, amt) + err = k.subUnlockedCoins(ctx, moduleAccAddr, amt) if err != nil { return err @@ -324,6 +327,8 @@ func (k BaseKeeper) SendCoinsFromModuleToManyAccounts( // SendCoinsFromModuleToModule transfers coins from a ModuleAccount to another. // It will panic if either module account does not exist. +// SendCoinsFromModuleToModule is the only send method that does not call both BlockBeforeSend and TrackBeforeSend hook. +// It only calls the TrackBeforeSend hook. func (k BaseKeeper) SendCoinsFromModuleToModule( ctx context.Context, senderModule, recipientModule string, amt sdk.Coins, ) error { @@ -337,7 +342,7 @@ func (k BaseKeeper) SendCoinsFromModuleToModule( panic(errorsmod.Wrapf(sdkerrors.ErrUnknownAddress, "module account %s does not exist", recipientModule)) } - return k.SendCoins(ctx, senderAddr, recipientAcc.GetAddress(), amt) + return k.SendCoinsWithoutBlockHook(ctx, senderAddr, recipientAcc.GetAddress(), amt) } // SendCoinsFromAccountToModule transfers coins from an AccAddress to a ModuleAccount. diff --git a/x/bank/keeper/send.go b/x/bank/keeper/send.go index 7e17b057c999..4368bc007941 100644 --- a/x/bank/keeper/send.go +++ b/x/bank/keeper/send.go @@ -216,17 +216,29 @@ func (k BaseSendKeeper) InputOutputCoins(ctx context.Context, input types.Input, return nil } +// SendCoinsWithoutBlockHook calls sendCoins without calling the `BlockBeforeSend` hook. +func (k BaseSendKeeper) SendCoinsWithoutBlockHook(ctx context.Context, fromAddr sdk.AccAddress, toAddr sdk.AccAddress, amt sdk.Coins) error { + return k.sendCoins(ctx, fromAddr, toAddr, amt) +} + // SendCoins transfers amt coins from a sending account to a receiving account. // An error is returned upon failure. func (k BaseSendKeeper) SendCoins(ctx context.Context, fromAddr sdk.AccAddress, toAddr sdk.AccAddress, amt sdk.Coins) error { - // call the BeforeSend hooks - sdkCtx := sdk.UnwrapSDKContext(ctx) - err := k.BeforeSend(sdkCtx, fromAddr, toAddr, amt) + // BlockBeforeSend hook should always be called before the TrackBeforeSend hook. + err := k.BlockBeforeSend(ctx, fromAddr, toAddr, amt) if err != nil { return err } - err = k.subUnlockedCoins(ctx, fromAddr, amt) + return k.sendCoins(ctx, fromAddr, toAddr, amt) +} + +// sendCoins has the internal logic for sending coins. +func (k BaseSendKeeper) sendCoins(ctx context.Context, fromAddr sdk.AccAddress, toAddr sdk.AccAddress, amt sdk.Coins) error { + // call the TrackBeforeSend hooks + k.TrackBeforeSend(ctx, fromAddr, toAddr, amt) + + err := k.subUnlockedCoins(ctx, fromAddr, amt) if err != nil { return err } @@ -250,7 +262,7 @@ func (k BaseSendKeeper) SendCoins(ctx context.Context, fromAddr sdk.AccAddress, // bech32 encoding is expensive! Only do it once for fromAddr fromAddrString := fromAddr.String() - sdkCtx = sdk.UnwrapSDKContext(ctx) + sdkCtx := sdk.UnwrapSDKContext(ctx) sdkCtx.EventManager().EmitEvents(sdk.Events{ sdk.NewEvent( types.EventTypeTransfer, @@ -276,12 +288,14 @@ func (k BaseSendKeeper) SendManyCoins(ctx context.Context, fromAddr sdk.AccAddre totalAmt := sdk.Coins{} for i, amt := range amts { - sdkCtx := sdk.UnwrapSDKContext(ctx) // make sure to trigger the BeforeSend hooks for all the sends that are about to occur - err := k.BeforeSend(sdkCtx, fromAddr, toAddrs[i], amts[i]) + k.TrackBeforeSend(ctx, fromAddr, toAddrs[i], amts[i]) + + err := k.BlockBeforeSend(ctx, fromAddr, toAddrs[i], amts[i]) if err != nil { return err } + totalAmt = sdk.Coins.Add(totalAmt, amt...) } diff --git a/x/bank/testutil/expected_keepers_mocks.go b/x/bank/testutil/expected_keepers_mocks.go index 199a07a3f98b..b8df326a09b6 100644 --- a/x/bank/testutil/expected_keepers_mocks.go +++ b/x/bank/testutil/expected_keepers_mocks.go @@ -266,16 +266,28 @@ func (m *MockBankHooks) EXPECT() *MockBankHooksMockRecorder { return m.recorder } -// BeforeSend mocks base method. -func (m *MockBankHooks) BeforeSend(ctx context.Context, from, to types.AccAddress, amount types.Coins) error { +// BlockBeforeSend mocks base method. +func (m *MockBankHooks) BlockBeforeSend(ctx context.Context, from, to types.AccAddress, amount types.Coins) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "BeforeSend", ctx, from, to, amount) + ret := m.ctrl.Call(m, "BlockBeforeSend", ctx, from, to, amount) ret0, _ := ret[0].(error) return ret0 } -// BeforeSend indicates an expected call of BeforeSend. -func (mr *MockBankHooksMockRecorder) BeforeSend(ctx, from, to, amount interface{}) *gomock.Call { +// BlockBeforeSend indicates an expected call of BlockBeforeSend. +func (mr *MockBankHooksMockRecorder) BlockBeforeSend(ctx, from, to, amount interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeforeSend", reflect.TypeOf((*MockBankHooks)(nil).BeforeSend), ctx, from, to, amount) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BlockBeforeSend", reflect.TypeOf((*MockBankHooks)(nil).BlockBeforeSend), ctx, from, to, amount) +} + +// TrackBeforeSend mocks base method. +func (m *MockBankHooks) TrackBeforeSend(ctx context.Context, from, to types.AccAddress, amount types.Coins) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "TrackBeforeSend", ctx, from, to, amount) +} + +// TrackBeforeSend indicates an expected call of TrackBeforeSend. +func (mr *MockBankHooksMockRecorder) TrackBeforeSend(ctx, from, to, amount interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TrackBeforeSend", reflect.TypeOf((*MockBankHooks)(nil).TrackBeforeSend), ctx, from, to, amount) } diff --git a/x/bank/types/expected_keepers.go b/x/bank/types/expected_keepers.go index 099b5edb704f..ef251be18aeb 100644 --- a/x/bank/types/expected_keepers.go +++ b/x/bank/types/expected_keepers.go @@ -42,5 +42,6 @@ type AccountKeeper interface { // BankHooks event hooks for bank sends type BankHooks interface { - BeforeSend(ctx context.Context, from sdk.AccAddress, to sdk.AccAddress, amount sdk.Coins) error // Must be before any send is executed + TrackBeforeSend(ctx context.Context, from, to sdk.AccAddress, amount sdk.Coins) // Must be before any send is executed + BlockBeforeSend(ctx context.Context, from, to sdk.AccAddress, amount sdk.Coins) error // Must be before any send is executed } diff --git a/x/bank/types/hooks.go b/x/bank/types/hooks.go index 05a269b00fbc..6f76208ef1d0 100644 --- a/x/bank/types/hooks.go +++ b/x/bank/types/hooks.go @@ -14,10 +14,17 @@ func NewMultiBankHooks(hooks ...BankHooks) MultiBankHooks { return hooks } -// BeforeSend runs the BeforeSend hooks in order for each BankHook in a MultiBankHooks struct -func (h MultiBankHooks) BeforeSend(ctx context.Context, from, to sdk.AccAddress, amount sdk.Coins) error { +// TrackBeforeSend runs the TrackBeforeSend hooks in order for each BankHook in a MultiBankHooks struct +func (h MultiBankHooks) TrackBeforeSend(ctx context.Context, from, to sdk.AccAddress, amount sdk.Coins) { for i := range h { - err := h[i].BeforeSend(ctx, from, to, amount) + h[i].TrackBeforeSend(ctx, from, to, amount) + } +} + +// BlockBeforeSend runs the BlockBeforeSend hooks in order for each BankHook in a MultiBankHooks struct +func (h MultiBankHooks) BlockBeforeSend(ctx context.Context, from, to sdk.AccAddress, amount sdk.Coins) error { + for i := range h { + err := h[i].BlockBeforeSend(ctx, from, to, amount) if err != nil { return err }