Skip to content

Commit

Permalink
bank hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
pr0n00gler committed Feb 26, 2024
1 parent 7dbed2f commit 9d0653a
Show file tree
Hide file tree
Showing 8 changed files with 303 additions and 4 deletions.
121 changes: 121 additions & 0 deletions x/bank/app_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package bank_test

import (
"context"
"fmt"
"testing"

abci "github.com/cometbft/cometbft/abci/types"
cmtproto "github.com/cometbft/cometbft/proto/tendermint/types"
stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -447,3 +450,121 @@ func TestMsgSetSendEnabled(t *testing.T) {
})
}
}

var _ types.BankHooks = &MockBankHooksReceiver{}

// BankHooks event hooks for bank (noalias)
type MockBankHooksReceiver struct{}

// Mock BlockBeforeSend bank hook that doesn't allow the sending of exactly 100 coins of any denom.
func (h *MockBankHooksReceiver) BlockBeforeSend(ctx context.Context, from, to sdk.AccAddress, amount sdk.Coins) error {
for _, coin := range amount {
if coin.Amount.Equal(sdkmath.NewInt(100)) {
return fmt.Errorf("not allowed; expected %v, got: %v", 100, coin.Amount)
}
}
return nil
}

// variable for counting `TrackBeforeSend`
var (
countTrackBeforeSend = 0
expNextCount = 1
)

// Mock TrackBeforeSend bank hook that simply tracks the sending of exactly 50 coins of any denom.
func (h *MockBankHooksReceiver) TrackBeforeSend(ctx context.Context, from, to sdk.AccAddress, amount sdk.Coins) {
for _, coin := range amount {
if coin.Amount.Equal(sdkmath.NewInt(50)) {
countTrackBeforeSend += 1

Check failure on line 479 in x/bank/app_test.go

View workflow job for this annotation

GitHub Actions / Analyze

increment-decrement: should replace countTrackBeforeSend += 1 with countTrackBeforeSend++ (revive)
}
}
}

func TestHooks(t *testing.T) {
acc1 := authtypes.NewBaseAccountWithAddress(addr1)

genAccs := []authtypes.GenesisAccount{acc1}
s := createTestSuite(t, genAccs)

ctx := s.App.BaseApp.NewContext(false)

require.NoError(t, testutil.FundAccount(ctx, s.BankKeeper, addr1, sdk.NewCoins(sdk.NewCoin("stake", sdkmath.NewInt(1000)))))
require.NoError(t, testutil.FundModuleAccount(ctx, s.BankKeeper, stakingtypes.BondedPoolName, sdk.NewCoins(sdk.NewCoin("stake", sdkmath.NewInt(1000)))))

// create a valid send amount which is 1 coin, and an invalidSendAmount which is 100 coins
validSendAmount := sdk.NewCoins(sdk.NewCoin("stake", sdkmath.NewInt(1)))
triggerTrackSendAmount := sdk.NewCoins(sdk.NewCoin("stake", sdkmath.NewInt(50)))
invalidBlockSendAmount := sdk.NewCoins(sdk.NewCoin("stake", sdkmath.NewInt(100)))

// setup our mock bank hooks receiver that prevents the send of 100 coins
bankHooksReceiver := MockBankHooksReceiver{}
baseBankKeeper, ok := s.BankKeeper.(bankkeeper.BaseKeeper)
require.True(t, ok)
bankkeeper.UnsafeSetHooks(
&baseBankKeeper, types.NewMultiBankHooks(&bankHooksReceiver),
)
s.BankKeeper = baseBankKeeper

// try sending a validSendAmount and it should work
err := s.BankKeeper.SendCoins(ctx, addr1, addr2, validSendAmount)
require.NoError(t, err)

// try sending an trigger track send amount and it should work
err = s.BankKeeper.SendCoins(ctx, addr1, addr2, triggerTrackSendAmount)
require.NoError(t, err)

require.Equal(t, countTrackBeforeSend, expNextCount)
expNextCount++

// try sending an invalidSendAmount and it should not work
err = s.BankKeeper.SendCoins(ctx, addr1, addr2, invalidBlockSendAmount)
require.Error(t, err)

// make sure that account to module doesn't bypass hook
err = s.BankKeeper.SendCoinsFromAccountToModule(ctx, addr1, stakingtypes.BondedPoolName, validSendAmount)
require.NoError(t, err)
err = s.BankKeeper.SendCoinsFromAccountToModule(ctx, addr1, stakingtypes.BondedPoolName, invalidBlockSendAmount)
require.Error(t, err)
err = s.BankKeeper.SendCoinsFromAccountToModule(ctx, addr1, stakingtypes.BondedPoolName, triggerTrackSendAmount)

Check failure on line 529 in x/bank/app_test.go

View workflow job for this annotation

GitHub Actions / Analyze

ineffectual assignment to err (ineffassign)
require.Equal(t, countTrackBeforeSend, expNextCount)
expNextCount++

// make sure that module to account doesn't bypass hook
err = s.BankKeeper.SendCoinsFromModuleToAccount(ctx, stakingtypes.BondedPoolName, addr1, validSendAmount)
require.NoError(t, err)
err = s.BankKeeper.SendCoinsFromModuleToAccount(ctx, stakingtypes.BondedPoolName, addr1, invalidBlockSendAmount)
require.Error(t, err)
err = s.BankKeeper.SendCoinsFromModuleToAccount(ctx, stakingtypes.BondedPoolName, addr1, triggerTrackSendAmount)

Check failure on line 538 in x/bank/app_test.go

View workflow job for this annotation

GitHub Actions / Analyze

ineffectual assignment to err (ineffassign)
require.Equal(t, countTrackBeforeSend, expNextCount)
expNextCount++

// make sure that module to module doesn't bypass hook
err = s.BankKeeper.SendCoinsFromModuleToModule(ctx, stakingtypes.BondedPoolName, stakingtypes.NotBondedPoolName, validSendAmount)
require.NoError(t, err)
err = s.BankKeeper.SendCoinsFromModuleToModule(ctx, stakingtypes.BondedPoolName, stakingtypes.NotBondedPoolName, invalidBlockSendAmount)
// there should be no error since module to module does not call block before send hooks
require.NoError(t, err)
err = s.BankKeeper.SendCoinsFromModuleToModule(ctx, stakingtypes.BondedPoolName, stakingtypes.NotBondedPoolName, triggerTrackSendAmount)

Check failure on line 548 in x/bank/app_test.go

View workflow job for this annotation

GitHub Actions / Analyze

ineffectual assignment to err (ineffassign)
require.Equal(t, countTrackBeforeSend, expNextCount)
expNextCount++

// make sure that DelegateCoins doesn't bypass the hook
err = s.BankKeeper.DelegateCoins(ctx, addr1, s.AccountKeeper.GetModuleAddress(stakingtypes.BondedPoolName), validSendAmount)
require.NoError(t, err)
err = s.BankKeeper.DelegateCoins(ctx, addr1, s.AccountKeeper.GetModuleAddress(stakingtypes.BondedPoolName), invalidBlockSendAmount)
require.Error(t, err)
err = s.BankKeeper.DelegateCoins(ctx, addr1, s.AccountKeeper.GetModuleAddress(stakingtypes.BondedPoolName), triggerTrackSendAmount)

Check failure on line 557 in x/bank/app_test.go

View workflow job for this annotation

GitHub Actions / Analyze

ineffectual assignment to err (ineffassign)
require.Equal(t, countTrackBeforeSend, expNextCount)
expNextCount++

// make sure that UndelegateCoins doesn't bypass the hook
err = s.BankKeeper.UndelegateCoins(ctx, s.AccountKeeper.GetModuleAddress(stakingtypes.BondedPoolName), addr1, validSendAmount)
require.NoError(t, err)
err = s.BankKeeper.UndelegateCoins(ctx, s.AccountKeeper.GetModuleAddress(stakingtypes.BondedPoolName), addr1, invalidBlockSendAmount)
require.Error(t, err)

err = s.BankKeeper.UndelegateCoins(ctx, s.AccountKeeper.GetModuleAddress(stakingtypes.BondedPoolName), addr1, triggerTrackSendAmount)

Check failure on line 567 in x/bank/app_test.go

View workflow job for this annotation

GitHub Actions / Analyze

ineffectual assignment to err (ineffassign)
require.Equal(t, countTrackBeforeSend, expNextCount)
expNextCount++
}
26 changes: 26 additions & 0 deletions x/bank/keeper/hooks.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package keeper

import (
"context"

sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/x/bank/types"
)

// Implements StakingHooks interface
var _ types.BankHooks = BaseSendKeeper{}

// TrackBeforeSend executes the TrackBeforeSend hook if registered.
func (k BaseSendKeeper) TrackBeforeSend(ctx context.Context, from, to sdk.AccAddress, amount sdk.Coins) {
if k.hooks != nil {
k.hooks.TrackBeforeSend(ctx, from, to, amount)
}
}

// BlockBeforeSend executes the BlockBeforeSend hook if registered.
func (k BaseSendKeeper) BlockBeforeSend(ctx context.Context, from, to sdk.AccAddress, amount sdk.Coins) error {
if k.hooks != nil {
return k.hooks.BlockBeforeSend(ctx, from, to, amount)
}
return nil
}
11 changes: 11 additions & 0 deletions x/bank/keeper/internal_unsafe.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package keeper

import "github.com/cosmos/cosmos-sdk/x/bank/types"

// UnsafeSetHooks updates the x/bank keeper's hooks, overriding any potential
// pre-existing hooks.
//
// WARNING: this function should only be used in tests.
func UnsafeSetHooks(k *BaseKeeper, h types.BankHooks) {
k.hooks = h
}
21 changes: 18 additions & 3 deletions x/bank/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,13 @@ func (k BaseKeeper) DelegateCoins(ctx context.Context, delegatorAddr, moduleAccA
return errorsmod.Wrap(sdkerrors.ErrInvalidCoins, amt.String())
}

err := k.BlockBeforeSend(ctx, delegatorAddr, moduleAccAddr, amt)
if err != nil {
return err
}
// call the TrackBeforeSend hooks and the BlockBeforeSend hooks
k.TrackBeforeSend(ctx, delegatorAddr, moduleAccAddr, amt)

balances := sdk.NewCoins()

for _, coin := range amt {
Expand All @@ -157,7 +164,7 @@ func (k BaseKeeper) DelegateCoins(ctx context.Context, delegatorAddr, moduleAccA
types.NewCoinSpentEvent(delegatorAddr, amt),
)

err := k.addCoins(ctx, moduleAccAddr, amt)
err = k.addCoins(ctx, moduleAccAddr, amt)
if err != nil {
return err
}
Expand All @@ -180,7 +187,15 @@ func (k BaseKeeper) UndelegateCoins(ctx context.Context, moduleAccAddr, delegato
return errorsmod.Wrap(sdkerrors.ErrInvalidCoins, amt.String())
}

err := k.subUnlockedCoins(ctx, moduleAccAddr, amt)
// call the TrackBeforeSend hooks and the BlockBeforeSend hooks
err := k.BlockBeforeSend(ctx, moduleAccAddr, delegatorAddr, amt)
if err != nil {
return err
}

k.TrackBeforeSend(ctx, moduleAccAddr, delegatorAddr, amt)

err = k.subUnlockedCoins(ctx, moduleAccAddr, amt)
if err != nil {
return err
}
Expand Down Expand Up @@ -286,7 +301,7 @@ func (k BaseKeeper) SendCoinsFromModuleToModule(
panic(errorsmod.Wrapf(sdkerrors.ErrUnknownAddress, "module account %s does not exist", recipientModule))
}

return k.SendCoins(ctx, senderAddr, recipientAcc.GetAddress(), amt)
return k.SendCoinsWithoutBlockHook(ctx, senderAddr, recipientAcc.GetAddress(), amt)
}

// SendCoinsFromAccountToModule transfers coins from an AccAddress to a ModuleAccount.
Expand Down
34 changes: 33 additions & 1 deletion x/bank/keeper/send.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ type BaseSendKeeper struct {
ak types.AccountKeeper
storeService store.KVStoreService
logger log.Logger
hooks types.BankHooks

// list of addresses that are restricted from receiving transactions
blockedAddrs map[string]bool
Expand Down Expand Up @@ -95,6 +96,17 @@ func NewBaseSendKeeper(
}
}

// SetHooks Set the bank hooks
func (k BaseSendKeeper) SetHooks(bh types.BankHooks) BaseSendKeeper {
if k.hooks != nil {
panic("cannot set bank hooks twice")
}

k.hooks = bh

return k
}

// AppendSendRestriction adds the provided SendRestrictionFn to run after previously provided restrictions.
func (k BaseSendKeeper) AppendSendRestriction(restriction types.SendRestrictionFn) {
k.sendRestriction.append(restriction)
Expand Down Expand Up @@ -203,9 +215,29 @@ func (k BaseSendKeeper) InputOutputCoins(ctx context.Context, input types.Input,
return nil
}

// SendCoinsWithoutBlockHook calls sendCoins without calling the `BlockBeforeSend` hook.
func (k BaseSendKeeper) SendCoinsWithoutBlockHook(ctx context.Context, fromAddr sdk.AccAddress, toAddr sdk.AccAddress, amt sdk.Coins) error {
return k.sendCoins(ctx, fromAddr, toAddr, amt)
}

// SendCoins transfers amt coins from a sending account to a receiving account.
// An error is returned upon failure.
func (k BaseSendKeeper) SendCoins(ctx context.Context, fromAddr, toAddr sdk.AccAddress, amt sdk.Coins) error {
func (k BaseSendKeeper) SendCoins(ctx context.Context, fromAddr sdk.AccAddress, toAddr sdk.AccAddress, amt sdk.Coins) error {
// BlockBeforeSend hook should always be called before the TrackBeforeSend hook.
err := k.BlockBeforeSend(ctx, fromAddr, toAddr, amt)
if err != nil {
return err
}

return k.sendCoins(ctx, fromAddr, toAddr, amt)
}

// sendCoins transfers amt coins from a sending account to a receiving account.
// An error is returned upon failure.
func (k BaseSendKeeper) sendCoins(ctx context.Context, fromAddr, toAddr sdk.AccAddress, amt sdk.Coins) error {
// call the TrackBeforeSend hooks
k.TrackBeforeSend(ctx, fromAddr, toAddr, amt)

var err error
err = k.subUnlockedCoins(ctx, fromAddr, amt)
if err != nil {
Expand Down
49 changes: 49 additions & 0 deletions x/bank/testutil/expected_keepers_mocks.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 12 additions & 0 deletions x/bank/types/expected_keepers.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,15 @@ type AccountKeeper interface {
SetModuleAccount(ctx context.Context, macc sdk.ModuleAccountI)
GetModulePermissions() map[string]types.PermissionsForAddress
}

// Event Hooks
// These can be utilized to communicate between a bank keeper and another
// keeper which must take particular actions when sends happen.
// The second keeper must implement this interface, which then the
// bank keeper can call.

// BankHooks event hooks for bank sends
type BankHooks interface {
TrackBeforeSend(ctx context.Context, from, to sdk.AccAddress, amount sdk.Coins) // Must be before any send is executed
BlockBeforeSend(ctx context.Context, from, to sdk.AccAddress, amount sdk.Coins) error // Must be before any send is executed
}
33 changes: 33 additions & 0 deletions x/bank/types/hooks.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package types

import (
"context"

sdk "github.com/cosmos/cosmos-sdk/types"
)

// MultiBankHooks combine multiple bank hooks, all hook functions are run in array sequence
type MultiBankHooks []BankHooks

// NewMultiBankHooks takes a list of BankHooks and returns a MultiBankHooks
func NewMultiBankHooks(hooks ...BankHooks) MultiBankHooks {
return hooks
}

// TrackBeforeSend runs the TrackBeforeSend hooks in order for each BankHook in a MultiBankHooks struct
func (h MultiBankHooks) TrackBeforeSend(ctx context.Context, from, to sdk.AccAddress, amount sdk.Coins) {
for i := range h {
h[i].TrackBeforeSend(ctx, from, to, amount)
}
}

// BlockBeforeSend runs the BlockBeforeSend hooks in order for each BankHook in a MultiBankHooks struct
func (h MultiBankHooks) BlockBeforeSend(ctx context.Context, from, to sdk.AccAddress, amount sdk.Coins) error {
for i := range h {
err := h[i].BlockBeforeSend(ctx, from, to, amount)
if err != nil {
return err
}
}
return nil
}

0 comments on commit 9d0653a

Please sign in to comment.