Skip to content

Commit

Permalink
feat: add a BeforeSend hook to the bank module (#278)
Browse files Browse the repository at this point in the history
* in progress

* add bank hooks:

* Remove stale comments

* add nil guards

* add hooks function

* Apply suggestions from code review

Co-authored-by: Roman <[email protected]>

* add tests

* Apply suggestions from code review

Co-authored-by: Roman <[email protected]>

* Apply suggestions from code review

Co-authored-by: Aleksandr Bezobchuk <[email protected]>

* lint

Co-authored-by: Roman <[email protected]>
Co-authored-by: Aleksandr Bezobchuk <[email protected]>
  • Loading branch information
3 people authored and czarcas7ic committed May 9, 2024
1 parent 3a4d0fc commit 058afdd
Show file tree
Hide file tree
Showing 9 changed files with 354 additions and 6 deletions.
110 changes: 110 additions & 0 deletions x/bank/hooks_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package bank_test

import (
"context"
"fmt"
"testing"

"cosmossdk.io/math"
tmproto "github.com/cometbft/cometbft/proto/tendermint/types"
"github.com/stretchr/testify/require"

sdk "github.com/cosmos/cosmos-sdk/types"
authtypes "github.com/cosmos/cosmos-sdk/x/auth/types"
"github.com/cosmos/cosmos-sdk/x/bank/keeper"
banktestutil "github.com/cosmos/cosmos-sdk/x/bank/testutil"
"github.com/cosmos/cosmos-sdk/x/bank/types"

stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types"
)

var _ types.BankHooks = &MockBankHooksReceiver{}

// BankHooks event hooks for bank (noalias)
type MockBankHooksReceiver struct{}

// Mock BeforeSend bank hook that doesn't allow the sending of exactly 100 coins of any denom.
func (h *MockBankHooksReceiver) BeforeSend(ctx context.Context, from, to sdk.AccAddress, amount sdk.Coins) error {
for _, coin := range amount {
if coin.Amount.Equal(math.NewInt(100)) {
return fmt.Errorf("not allowed; expected %v, got: %v", 100, coin.Amount)
}
}
return nil
}

func TestHooks(t *testing.T) {
acc := &authtypes.BaseAccount{
Address: addr1.String(),
}

genAccs := []authtypes.GenesisAccount{acc}
app := createTestSuite(t, genAccs)
baseApp := app.App.BaseApp
ctx := baseApp.NewContextLegacy(false, tmproto.Header{})

require.NoError(t, banktestutil.FundAccount(ctx, app.BankKeeper, addr1, sdk.NewCoins(sdk.NewInt64Coin(sdk.DefaultBondDenom, 10000))))
require.NoError(t, banktestutil.FundAccount(ctx, app.BankKeeper, addr2, sdk.NewCoins(sdk.NewInt64Coin(sdk.DefaultBondDenom, 10000))))
banktestutil.FundModuleAccount(ctx, app.BankKeeper, stakingtypes.BondedPoolName, sdk.NewCoins(sdk.NewCoin(sdk.DefaultBondDenom, math.NewInt(1000))))

// create a valid send amount which is 1 coin, and an invalidSendAmount which is 100 coins
validSendAmount := sdk.NewCoins(sdk.NewCoin(sdk.DefaultBondDenom, math.NewInt(1)))
invalidSendAmount := sdk.NewCoins(sdk.NewCoin(sdk.DefaultBondDenom, math.NewInt(100)))

// setup our mock bank hooks receiver that prevents the send of 100 coins
bankHooksReceiver := MockBankHooksReceiver{}
baseBankKeeper, ok := app.BankKeeper.(keeper.BaseKeeper)
require.True(t, ok)
baseBankKeeper.SetHooks(
types.NewMultiBankHooks(&bankHooksReceiver),
)
app.BankKeeper = baseBankKeeper

// try sending a validSendAmount and it should work
err := app.BankKeeper.SendCoins(ctx, addr1, addr2, validSendAmount)
require.NoError(t, err)

// try sending an invalidSendAmount and it should not work
err = app.BankKeeper.SendCoins(ctx, addr1, addr2, invalidSendAmount)
require.Error(t, err)

// try doing SendManyCoins and make sure if even a single subsend is invalid, the entire function fails
err = app.BankKeeper.SendManyCoins(ctx, addr1, []sdk.AccAddress{addr1, addr2}, []sdk.Coins{invalidSendAmount, validSendAmount})
require.Error(t, err)

// make sure that account to module doesn't bypass hook
err = app.BankKeeper.SendCoinsFromAccountToModule(ctx, addr1, stakingtypes.BondedPoolName, validSendAmount)
require.NoError(t, err)
err = app.BankKeeper.SendCoinsFromAccountToModule(ctx, addr1, stakingtypes.BondedPoolName, invalidSendAmount)
require.Error(t, err)

// make sure that module to account doesn't bypass hook
err = app.BankKeeper.SendCoinsFromModuleToAccount(ctx, stakingtypes.BondedPoolName, addr1, validSendAmount)
require.NoError(t, err)
err = app.BankKeeper.SendCoinsFromModuleToAccount(ctx, stakingtypes.BondedPoolName, addr1, invalidSendAmount)
require.Error(t, err)

// make sure that module to module doesn't bypass hook
err = app.BankKeeper.SendCoinsFromModuleToModule(ctx, stakingtypes.BondedPoolName, stakingtypes.NotBondedPoolName, validSendAmount)
require.NoError(t, err)
err = app.BankKeeper.SendCoinsFromModuleToModule(ctx, stakingtypes.BondedPoolName, stakingtypes.NotBondedPoolName, invalidSendAmount)
require.Error(t, err)

// make sure that module to many accounts doesn't bypass hook
err = app.BankKeeper.SendCoinsFromModuleToManyAccounts(ctx, stakingtypes.BondedPoolName, []sdk.AccAddress{addr1, addr2}, []sdk.Coins{validSendAmount, validSendAmount})
require.NoError(t, err)
err = app.BankKeeper.SendCoinsFromModuleToManyAccounts(ctx, stakingtypes.BondedPoolName, []sdk.AccAddress{addr1, addr2}, []sdk.Coins{validSendAmount, invalidSendAmount})
require.Error(t, err)

// make sure that DelegateCoins doesn't bypass the hook
err = app.BankKeeper.DelegateCoins(ctx, addr1, app.AccountKeeper.GetModuleAddress(stakingtypes.BondedPoolName), validSendAmount)
require.NoError(t, err)
err = app.BankKeeper.DelegateCoins(ctx, addr1, app.AccountKeeper.GetModuleAddress(stakingtypes.BondedPoolName), invalidSendAmount)
require.Error(t, err)

// make sure that UndelegateCoins doesn't bypass the hook
err = app.BankKeeper.UndelegateCoins(ctx, app.AccountKeeper.GetModuleAddress(stakingtypes.BondedPoolName), addr1, validSendAmount)
require.NoError(t, err)
err = app.BankKeeper.UndelegateCoins(ctx, app.AccountKeeper.GetModuleAddress(stakingtypes.BondedPoolName), addr1, invalidSendAmount)
require.Error(t, err)
}
19 changes: 19 additions & 0 deletions x/bank/keeper/hooks.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package keeper

import (
"context"

sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/x/bank/types"
)

// Implements StakingHooks interface
var _ types.BankHooks = BaseSendKeeper{}

// BeforeSend executes the BeforeSend hook if registered.
func (k BaseSendKeeper) BeforeSend(ctx context.Context, from, to sdk.AccAddress, amount sdk.Coins) error {
if k.hooks != nil {
return k.hooks.BeforeSend(ctx, from, to, amount)
}
return nil
}
11 changes: 11 additions & 0 deletions x/bank/keeper/internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package keeper

import "github.com/cosmos/cosmos-sdk/x/bank/types"

// UnsafeSetHooks updates the x/bank keeper's hooks, overriding any potential
// pre-existing hooks.
//
// WARNING: this function should only be used in tests.
func UnsafeSetHooks(k *BaseKeeper, h types.BankHooks) {
k.hooks = h
}
43 changes: 38 additions & 5 deletions x/bank/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ type Keeper interface {
UndelegateCoinsFromModuleToAccount(ctx context.Context, senderModule string, recipientAddr sdk.AccAddress, amt sdk.Coins) error
MintCoins(ctx context.Context, moduleName string, amt sdk.Coins) error
BurnCoins(ctx context.Context, moduleName string, amt sdk.Coins) error
SendCoinsFromModuleToManyAccounts(ctx context.Context, senderModule string, recipientAddrs []sdk.AccAddress, amts []sdk.Coins) error

DelegateCoins(ctx context.Context, delegatorAddr, moduleAccAddr sdk.AccAddress, amt sdk.Coins) error
UndelegateCoins(ctx context.Context, moduleAccAddr, delegatorAddr sdk.AccAddress, amt sdk.Coins) error
Expand Down Expand Up @@ -143,6 +144,12 @@ func (k BaseKeeper) DelegateCoins(ctx context.Context, delegatorAddr, moduleAccA
return errorsmod.Wrap(sdkerrors.ErrInvalidCoins, amt.String())
}

// call the BeforeSend hooks
err := k.BeforeSend(ctx, delegatorAddr, moduleAccAddr, amt)
if err != nil {
return err
}

balances := sdk.NewCoins()

for _, coin := range amt {
Expand All @@ -169,7 +176,7 @@ func (k BaseKeeper) DelegateCoins(ctx context.Context, delegatorAddr, moduleAccA
types.NewCoinSpentEvent(delegatorAddr, amt),
)

err := k.addCoins(ctx, moduleAccAddr, amt)
err = k.addCoins(ctx, moduleAccAddr, amt)
if err != nil {
return err
}
Expand All @@ -192,7 +199,13 @@ func (k BaseKeeper) UndelegateCoins(ctx context.Context, moduleAccAddr, delegato
return errorsmod.Wrap(sdkerrors.ErrInvalidCoins, amt.String())
}

err := k.subUnlockedCoins(ctx, moduleAccAddr, amt)
// call the BeforeSend hooks
err := k.BeforeSend(ctx, moduleAccAddr, delegatorAddr, amt)
if err != nil {
return err
}

err = k.subUnlockedCoins(ctx, moduleAccAddr, amt)
if err != nil {
return err
}
Expand Down Expand Up @@ -268,9 +281,7 @@ func (k BaseKeeper) SetDenomMetaData(ctx context.Context, denomMetaData types.Me
// SendCoinsFromModuleToAccount transfers coins from a ModuleAccount to an AccAddress.
// It will panic if the module account does not exist. An error is returned if
// the recipient address is black-listed or if sending the tokens fails.
func (k BaseKeeper) SendCoinsFromModuleToAccount(
ctx context.Context, senderModule string, recipientAddr sdk.AccAddress, amt sdk.Coins,
) error {
func (k BaseKeeper) SendCoinsFromModuleToAccount(ctx context.Context, senderModule string, recipientAddr sdk.AccAddress, amt sdk.Coins) error {
senderAddr := k.ak.GetModuleAddress(senderModule)
if senderAddr == nil {
panic(errorsmod.Wrapf(sdkerrors.ErrUnknownAddress, "module account %s does not exist", senderModule))
Expand All @@ -283,6 +294,28 @@ func (k BaseKeeper) SendCoinsFromModuleToAccount(
return k.SendCoins(ctx, senderAddr, recipientAddr, amt)
}

// SendCoinsFromModuleToManyAccounts transfers coins from a ModuleAccount to multiple AccAddresses.
// It will panic if the module account does not exist. An error is returned if
// the recipient address is black-listed or if sending the tokens fails.
func (k BaseKeeper) SendCoinsFromModuleToManyAccounts(ctx context.Context, senderModule string, recipientAddrs []sdk.AccAddress, amts []sdk.Coins) error {
if len(recipientAddrs) != len(amts) {
panic(fmt.Errorf("addresses and amounts numbers does not match"))
}

senderAddr := k.ak.GetModuleAddress(senderModule)
if senderAddr == nil {
panic(errorsmod.Wrapf(sdkerrors.ErrUnknownAddress, "module account %s does not exist", senderModule))
}

for _, recipientAddr := range recipientAddrs {
if k.BlockedAddr(recipientAddr) {
return errorsmod.Wrapf(sdkerrors.ErrUnauthorized, "%s is not allowed to receive funds", recipientAddr)
}
}

return k.SendManyCoins(ctx, senderAddr, recipientAddrs, amts)
}

// SendCoinsFromModuleToModule transfers coins from a ModuleAccount to another.
// It will panic if either module account does not exist.
func (k BaseKeeper) SendCoinsFromModuleToModule(
Expand Down
75 changes: 74 additions & 1 deletion x/bank/keeper/send.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type SendKeeper interface {

InputOutputCoins(ctx context.Context, input types.Input, outputs []types.Output) error
SendCoins(ctx context.Context, fromAddr, toAddr sdk.AccAddress, amt sdk.Coins) error
SendManyCoins(ctx context.Context, fromAddr sdk.AccAddress, toAddrs []sdk.AccAddress, amts []sdk.Coins) error

GetParams(ctx context.Context) types.Params
SetParams(ctx context.Context, params types.Params) error
Expand Down Expand Up @@ -60,6 +61,7 @@ type BaseSendKeeper struct {
ak types.AccountKeeper
storeService store.KVStoreService
logger log.Logger
hooks types.BankHooks

// list of addresses that are restricted from receiving transactions
blockedAddrs map[string]bool
Expand Down Expand Up @@ -115,6 +117,17 @@ func (k BaseSendKeeper) GetAuthority() string {
return k.authority
}

// Set the bank hooks
func (k *BaseSendKeeper) SetHooks(bh types.BankHooks) *BaseSendKeeper {
if k.hooks != nil {
panic("cannot set bank hooks twice")
}

k.hooks = bh

return k
}

// GetParams returns the total set of bank parameters.
func (k BaseSendKeeper) GetParams(ctx context.Context) (params types.Params) {
p, _ := k.Params.Get(ctx)
Expand Down Expand Up @@ -206,7 +219,12 @@ func (k BaseSendKeeper) InputOutputCoins(ctx context.Context, input types.Input,
// SendCoins transfers amt coins from a sending account to a receiving account.
// An error is returned upon failure.
func (k BaseSendKeeper) SendCoins(ctx context.Context, fromAddr, toAddr sdk.AccAddress, amt sdk.Coins) error {
var err error
// call the BeforeSend hooks
err := k.BeforeSend(ctx, fromAddr, toAddr, amt)
if err != nil {
return err
}

err = k.subUnlockedCoins(ctx, fromAddr, amt)
if err != nil {
return err
Expand Down Expand Up @@ -251,6 +269,61 @@ func (k BaseSendKeeper) SendCoins(ctx context.Context, fromAddr, toAddr sdk.AccA
return nil
}

// SendManyCoins transfer multiple amt coins from a sending account to multiple receiving accounts.
// An error is returned upon failure.
func (k BaseSendKeeper) SendManyCoins(ctx context.Context, fromAddr sdk.AccAddress, toAddrs []sdk.AccAddress, amts []sdk.Coins) error {
if len(toAddrs) != len(amts) {
return fmt.Errorf("addresses and amounts numbers does not match")
}

totalAmt := sdk.Coins{}
for i, amt := range amts {
// make sure to trigger the BeforeSend hooks for all the sends that are about to occur
err := k.BeforeSend(ctx, fromAddr, toAddrs[i], amts[i])
if err != nil {
return err
}
totalAmt = sdk.Coins.Add(totalAmt, amt...)
}

err := k.subUnlockedCoins(ctx, fromAddr, totalAmt)
if err != nil {
return err
}

fromAddrString := fromAddr.String()
sdkCtx := sdk.UnwrapSDKContext(ctx)
for i, toAddr := range toAddrs {
amt := amts[i]

err := k.addCoins(ctx, toAddr, amt)
if err != nil {
return err
}

acc := k.ak.GetAccount(ctx, toAddr)
if acc == nil {
defer telemetry.IncrCounter(1, "new", "account")
k.ak.SetAccount(ctx, k.ak.NewAccountWithAddress(ctx, toAddr))
}

sdkCtx.EventManager().EmitEvents(sdk.Events{
sdk.NewEvent(
types.EventTypeTransfer,
sdk.NewAttribute(types.AttributeKeyRecipient, toAddr.String()),
sdk.NewAttribute(types.AttributeKeySender, fromAddrString),
sdk.NewAttribute(sdk.AttributeKeyAmount, amt.String()),
),
sdk.NewEvent(
sdk.EventTypeMessage,
sdk.NewAttribute(types.AttributeKeySender, fromAddrString),
),
})
}

return nil
}

// subUnlockedCoins removes the unlocked amt coins of the given account. An error is
// returned if the resulting balance is negative or the initial amount is invalid.
// A coin_spent event is emitted after.
Expand Down
37 changes: 37 additions & 0 deletions x/bank/testutil/expected_keepers_mocks.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions x/bank/types/expected_keepers.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,14 @@ type AccountKeeper interface {
SetModuleAccount(ctx context.Context, macc sdk.ModuleAccountI)
GetModulePermissions() map[string]types.PermissionsForAddress
}

// Event Hooks
// These can be utilized to communicate between a bank keeper and another
// keeper which must take particular actions when sends happen.
// The second keeper must implement this interface, which then the
// bank keeper can call.

// BankHooks event hooks for bank sends
type BankHooks interface {
BeforeSend(ctx context.Context, from sdk.AccAddress, to sdk.AccAddress, amount sdk.Coins) error // Must be before any send is executed
}
Loading

0 comments on commit 058afdd

Please sign in to comment.