Skip to content

Commit

Permalink
feat: add a BeforeSend hook to the bank module (#278)
Browse files Browse the repository at this point in the history
* in progress

* add bank hooks:

* Remove stale comments

* add nil guards

* add hooks function

* Apply suggestions from code review

Co-authored-by: Roman <[email protected]>

* add tests

* Apply suggestions from code review

Co-authored-by: Roman <[email protected]>

* Apply suggestions from code review

Co-authored-by: Aleksandr Bezobchuk <[email protected]>

* lint

Co-authored-by: Roman <[email protected]>
Co-authored-by: Aleksandr Bezobchuk <[email protected]>
  • Loading branch information
3 people authored and czarcas7ic committed Oct 30, 2023
1 parent 05c92b5 commit b1b5c4a
Show file tree
Hide file tree
Showing 9 changed files with 400 additions and 9 deletions.
17 changes: 17 additions & 0 deletions x/bank/keeper/hooks.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package keeper

import (
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/x/bank/types"
)

// Implements StakingHooks interface
var _ types.BankHooks = BaseSendKeeper{}

// BeforeSend executes the BeforeSend hook if registered.
func (k BaseSendKeeper) BeforeSend(ctx sdk.Context, from, to sdk.AccAddress, amount sdk.Coins) error {
if k.hooks != nil {
return k.hooks.BeforeSend(ctx, from, to, amount)
}
return nil
}
154 changes: 154 additions & 0 deletions x/bank/keeper/hooks_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
package keeper_test

import (
"fmt"
"testing"

tmproto "github.com/cometbft/cometbft/proto/tendermint/types"
"github.com/cosmos/cosmos-sdk/crypto/keys/secp256k1"
"github.com/cosmos/cosmos-sdk/runtime"
"github.com/cosmos/cosmos-sdk/testutil/configurator"
simtestutil "github.com/cosmos/cosmos-sdk/testutil/sims"
"github.com/stretchr/testify/require"

sdk "github.com/cosmos/cosmos-sdk/types"
authtypes "github.com/cosmos/cosmos-sdk/x/auth/types"
"github.com/cosmos/cosmos-sdk/x/bank/keeper"
bankkeeper "github.com/cosmos/cosmos-sdk/x/bank/keeper"
banktestutil "github.com/cosmos/cosmos-sdk/x/bank/testutil"
"github.com/cosmos/cosmos-sdk/x/bank/types"
stakingkeeper "github.com/cosmos/cosmos-sdk/x/staking/keeper"

stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types"
)

var _ types.BankHooks = &MockBankHooksReceiver{}

var (
priv1 = secp256k1.GenPrivKey()
addr1 = sdk.AccAddress(priv1.PubKey().Address())
)

type testingSuite struct {
BankKeeper bankkeeper.Keeper
AccountKeeper types.AccountKeeper
StakingKeeper stakingkeeper.Keeper
App *runtime.App
}

func createTestSuite(t *testing.T, genesisAccounts []authtypes.GenesisAccount) testingSuite {
res := testingSuite{}

var genAccounts []simtestutil.GenesisAccount
for _, acc := range genesisAccounts {
genAccounts = append(genAccounts, simtestutil.GenesisAccount{GenesisAccount: acc})
}

startupCfg := simtestutil.DefaultStartUpConfig()
startupCfg.GenesisAccounts = genAccounts

app, err := simtestutil.SetupWithConfiguration(configurator.NewAppConfig(
configurator.ParamsModule(),
configurator.AuthModule(),
configurator.StakingModule(),
configurator.TxModule(),
configurator.ConsensusModule(),
configurator.BankModule(),
configurator.GovModule(),
),
startupCfg, &res.BankKeeper, &res.AccountKeeper, &res.StakingKeeper)

res.App = app

require.NoError(t, err)
return res
}

// BankHooks event hooks for bank (noalias)
type MockBankHooksReceiver struct{}

// Mock BeforeSend bank hook that doesn't allow the sending of exactly 100 coins of any denom.
func (h *MockBankHooksReceiver) BeforeSend(ctx sdk.Context, from, to sdk.AccAddress, amount sdk.Coins) error {
for _, coin := range amount {
if coin.Amount.Equal(sdk.NewInt(100)) {
return fmt.Errorf("not allowed; expected %v, got: %v", 100, coin.Amount)
}
}

return nil
}

func TestHooks(t *testing.T) {
acc := &authtypes.BaseAccount{
Address: addr1.String(),
}

genAccs := []authtypes.GenesisAccount{acc}
app := createTestSuite(t, genAccs)
baseApp := app.App.BaseApp
ctx := baseApp.NewContext(false, tmproto.Header{})

addrs := simtestutil.AddTestAddrs(app.BankKeeper, app.StakingKeeper, ctx, 2, sdk.NewInt(1000))
banktestutil.FundModuleAccount(app.BankKeeper, ctx, stakingtypes.BondedPoolName, sdk.NewCoins(sdk.NewCoin(app.StakingKeeper.BondDenom(ctx), sdk.NewInt(1000))))

// create a valid send amount which is 1 coin, and an invalidSendAmount which is 100 coins
validSendAmount := sdk.NewCoins(sdk.NewCoin(app.StakingKeeper.BondDenom(ctx), sdk.NewInt(1)))
invalidSendAmount := sdk.NewCoins(sdk.NewCoin(app.StakingKeeper.BondDenom(ctx), sdk.NewInt(100)))

// setup our mock bank hooks receiver that prevents the send of 100 coins
bankHooksReceiver := MockBankHooksReceiver{}
baseBankKeeper, ok := app.BankKeeper.(keeper.BaseKeeper)
require.True(t, ok)
keeper.UnsafeSetHooks(
&baseBankKeeper, types.NewMultiBankHooks(&bankHooksReceiver),
)
app.BankKeeper = baseBankKeeper

// try sending a validSendAmount and it should work
err := app.BankKeeper.SendCoins(ctx, addrs[0], addrs[1], validSendAmount)
require.NoError(t, err)

// try sending an invalidSendAmount and it should not work
err = app.BankKeeper.SendCoins(ctx, addrs[0], addrs[1], invalidSendAmount)
require.Error(t, err)

// try doing SendManyCoins and make sure if even a single subsend is invalid, the entire function fails
err = app.BankKeeper.SendManyCoins(ctx, addrs[0], []sdk.AccAddress{addrs[0], addrs[1]}, []sdk.Coins{invalidSendAmount, validSendAmount})
require.Error(t, err)

// make sure that account to module doesn't bypass hook
err = app.BankKeeper.SendCoinsFromAccountToModule(ctx, addrs[0], stakingtypes.BondedPoolName, validSendAmount)
require.NoError(t, err)
err = app.BankKeeper.SendCoinsFromAccountToModule(ctx, addrs[0], stakingtypes.BondedPoolName, invalidSendAmount)
require.Error(t, err)

// make sure that module to account doesn't bypass hook
err = app.BankKeeper.SendCoinsFromModuleToAccount(ctx, stakingtypes.BondedPoolName, addrs[0], validSendAmount)
require.NoError(t, err)
err = app.BankKeeper.SendCoinsFromModuleToAccount(ctx, stakingtypes.BondedPoolName, addrs[0], invalidSendAmount)
require.Error(t, err)

// make sure that module to module doesn't bypass hook
err = app.BankKeeper.SendCoinsFromModuleToModule(ctx, stakingtypes.BondedPoolName, stakingtypes.NotBondedPoolName, validSendAmount)
require.NoError(t, err)
err = app.BankKeeper.SendCoinsFromModuleToModule(ctx, stakingtypes.BondedPoolName, stakingtypes.NotBondedPoolName, invalidSendAmount)
require.Error(t, err)

// make sure that module to many accounts doesn't bypass hook
err = app.BankKeeper.SendCoinsFromModuleToManyAccounts(ctx, stakingtypes.BondedPoolName, []sdk.AccAddress{addrs[0], addrs[1]}, []sdk.Coins{validSendAmount, validSendAmount})
require.NoError(t, err)
err = app.BankKeeper.SendCoinsFromModuleToManyAccounts(ctx, stakingtypes.BondedPoolName, []sdk.AccAddress{addrs[0], addrs[1]}, []sdk.Coins{validSendAmount, invalidSendAmount})
require.Error(t, err)

// make sure that DelegateCoins doesn't bypass the hook
err = app.BankKeeper.DelegateCoins(ctx, addrs[0], app.AccountKeeper.GetModuleAddress(stakingtypes.BondedPoolName), validSendAmount)
require.NoError(t, err)
err = app.BankKeeper.DelegateCoins(ctx, addrs[0], app.AccountKeeper.GetModuleAddress(stakingtypes.BondedPoolName), invalidSendAmount)
require.Error(t, err)

// make sure that UndelegateCoins doesn't bypass the hook
err = app.BankKeeper.UndelegateCoins(ctx, app.AccountKeeper.GetModuleAddress(stakingtypes.BondedPoolName), addrs[0], validSendAmount)
require.NoError(t, err)
err = app.BankKeeper.UndelegateCoins(ctx, app.AccountKeeper.GetModuleAddress(stakingtypes.BondedPoolName), addrs[0], invalidSendAmount)
require.Error(t, err)
}
11 changes: 11 additions & 0 deletions x/bank/keeper/internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package keeper

import "github.com/cosmos/cosmos-sdk/x/bank/types"

// UnsafeSetHooks updates the x/bank keeper's hooks, overriding any potential
// pre-existing hooks.
//
// WARNING: this function should only be used in tests.
func UnsafeSetHooks(k *BaseKeeper, h types.BankHooks) {
k.hooks = h
}
43 changes: 41 additions & 2 deletions x/bank/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ type Keeper interface {
IterateAllDenomMetaData(ctx sdk.Context, cb func(types.Metadata) bool)

SendCoinsFromModuleToAccount(ctx sdk.Context, senderModule string, recipientAddr sdk.AccAddress, amt sdk.Coins) error
SendCoinsFromModuleToManyAccounts(
ctx sdk.Context, senderModule string, recipientAddrs []sdk.AccAddress, amts []sdk.Coins,
) error
SendCoinsFromModuleToModule(ctx sdk.Context, senderModule, recipientModule string, amt sdk.Coins) error
SendCoinsFromAccountToModule(ctx sdk.Context, senderAddr sdk.AccAddress, recipientModule string, amt sdk.Coins) error
DelegateCoinsFromAccountToModule(ctx sdk.Context, senderAddr sdk.AccAddress, recipientModule string, amt sdk.Coins) error
Expand Down Expand Up @@ -161,6 +164,12 @@ func (k BaseKeeper) DelegateCoins(ctx sdk.Context, delegatorAddr, moduleAccAddr
return sdkerrors.Wrap(sdkerrors.ErrInvalidCoins, amt.String())
}

// call the BeforeSend hooks
err := k.BeforeSend(ctx, delegatorAddr, moduleAccAddr, amt)
if err != nil {
return err
}

balances := sdk.NewCoins()

for _, coin := range amt {
Expand All @@ -186,7 +195,7 @@ func (k BaseKeeper) DelegateCoins(ctx sdk.Context, delegatorAddr, moduleAccAddr
types.NewCoinSpentEvent(delegatorAddr, amt),
)

err := k.addCoins(ctx, moduleAccAddr, amt)
err = k.addCoins(ctx, moduleAccAddr, amt)
if err != nil {
return err
}
Expand All @@ -209,7 +218,13 @@ func (k BaseKeeper) UndelegateCoins(ctx sdk.Context, moduleAccAddr, delegatorAdd
return sdkerrors.Wrap(sdkerrors.ErrInvalidCoins, amt.String())
}

err := k.subUnlockedCoins(ctx, moduleAccAddr, amt)
// call the BeforeSend hooks
err := k.BeforeSend(ctx, moduleAccAddr, delegatorAddr, amt)
if err != nil {
return err
}

err = k.subUnlockedCoins(ctx, moduleAccAddr, amt)
if err != nil {
return err
}
Expand Down Expand Up @@ -340,6 +355,30 @@ func (k BaseKeeper) SendCoinsFromModuleToAccount(
return k.SendCoins(ctx, senderAddr, recipientAddr, amt)
}

// SendCoinsFromModuleToManyAccounts transfers coins from a ModuleAccount to multiple AccAddresses.
// It will panic if the module account does not exist. An error is returned if
// the recipient address is black-listed or if sending the tokens fails.
func (k BaseKeeper) SendCoinsFromModuleToManyAccounts(
ctx sdk.Context, senderModule string, recipientAddrs []sdk.AccAddress, amts []sdk.Coins,
) error {
if len(recipientAddrs) != len(amts) {
panic(fmt.Errorf("addresses and amounts numbers does not match"))
}

senderAddr := k.ak.GetModuleAddress(senderModule)
if senderAddr == nil {
panic(sdkerrors.Wrapf(sdkerrors.ErrUnknownAddress, "module account %s does not exist", senderModule))
}

for _, recipientAddr := range recipientAddrs {
if k.BlockedAddr(recipientAddr) {
return sdkerrors.Wrapf(sdkerrors.ErrUnauthorized, "%s is not allowed to receive funds", recipientAddr)
}
}

return k.SendManyCoins(ctx, senderAddr, recipientAddrs, amts)
}

// SendCoinsFromModuleToModule transfers coins from a ModuleAccount to another.
// It will panic if either module account does not exist.
func (k BaseKeeper) SendCoinsFromModuleToModule(
Expand Down
84 changes: 77 additions & 7 deletions x/bank/keeper/send.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ type SendKeeper interface {

InputOutputCoins(ctx sdk.Context, inputs []types.Input, outputs []types.Output) error
SendCoins(ctx sdk.Context, fromAddr sdk.AccAddress, toAddr sdk.AccAddress, amt sdk.Coins) error
SendManyCoins(ctx sdk.Context, fromAddr sdk.AccAddress, toAddrs []sdk.AccAddress, amts []sdk.Coins) error

GetParams(ctx sdk.Context) types.Params
SetParams(ctx sdk.Context, params types.Params) error
Expand Down Expand Up @@ -53,6 +54,7 @@ type BaseSendKeeper struct {
cdc codec.BinaryCodec
ak types.AccountKeeper
storeKey storetypes.StoreKey
hooks types.BankHooks

// list of addresses that are restricted from receiving transactions
blockedAddrs map[string]bool
Expand Down Expand Up @@ -88,6 +90,17 @@ func (k BaseSendKeeper) GetAuthority() string {
return k.authority
}

// Set the bank hooks
func (k *BaseSendKeeper) SetHooks(bh types.BankHooks) *BaseSendKeeper {
if k.hooks != nil {
panic("cannot set bank hooks twice")
}

k.hooks = bh

return k
}

// GetParams returns the total set of bank parameters.
func (k BaseSendKeeper) GetParams(ctx sdk.Context) (params types.Params) {
store := ctx.KVStore(k.storeKey)
Expand Down Expand Up @@ -190,7 +203,13 @@ func (k BaseSendKeeper) InputOutputCoins(ctx sdk.Context, inputs []types.Input,
// SendCoins transfers amt coins from a sending account to a receiving account.
// An error is returned upon failure.
func (k BaseSendKeeper) SendCoins(ctx sdk.Context, fromAddr sdk.AccAddress, toAddr sdk.AccAddress, amt sdk.Coins) error {
err := k.subUnlockedCoins(ctx, fromAddr, amt)
// call the BeforeSend hooks
err := k.BeforeSend(ctx, fromAddr, toAddr, amt)
if err != nil {
return err
}

err = k.subUnlockedCoins(ctx, fromAddr, amt)
if err != nil {
return err
}
Expand All @@ -212,18 +231,69 @@ func (k BaseSendKeeper) SendCoins(ctx sdk.Context, fromAddr sdk.AccAddress, toAd

// bech32 encoding is expensive! Only do it once for fromAddr
fromAddrString := fromAddr.String()
ctx.EventManager().EmitEvents(sdk.Events{
sdk.NewEvent(
ctx.EventManager().EmitEvent(sdk.NewEvent(
types.EventTypeTransfer,
sdk.NewAttribute(types.AttributeKeyRecipient, toAddr.String()),
sdk.NewAttribute(types.AttributeKeySender, fromAddrString),
sdk.NewAttribute(sdk.AttributeKeyAmount, amt.String()),
))
ctx.EventManager().EmitEvent(sdk.NewEvent(
sdk.EventTypeMessage,
sdk.NewAttribute(types.AttributeKeySender, fromAddrString),
))

return nil
}

// SendManyCoins transfer multiple amt coins from a sending account to multiple receiving accounts.
// An error is returned upon failure.
func (k BaseSendKeeper) SendManyCoins(ctx sdk.Context, fromAddr sdk.AccAddress, toAddrs []sdk.AccAddress, amts []sdk.Coins) error {
if len(toAddrs) != len(amts) {
return fmt.Errorf("addresses and amounts numbers does not match")
}

totalAmt := sdk.Coins{}
for i, amt := range amts {
// make sure to trigger the BeforeSend hooks for all the sends that are about to occur
err := k.BeforeSend(ctx, fromAddr, toAddrs[i], amts[i])
if err != nil {
return err
}
totalAmt = sdk.Coins.Add(totalAmt, amt...)
}

err := k.subUnlockedCoins(ctx, fromAddr, totalAmt)
if err != nil {
return err
}

fromAddrString := fromAddr.String()
for i, toAddr := range toAddrs {
amt := amts[i]

err := k.addCoins(ctx, toAddr, amt)
if err != nil {
return err
}

acc := k.ak.GetAccount(ctx, toAddr)
if acc == nil {
defer telemetry.IncrCounter(1, "new", "account")
k.ak.SetAccount(ctx, k.ak.NewAccountWithAddress(ctx, toAddr))
}

ctx.EventManager().EmitEvent(sdk.NewEvent(
types.EventTypeTransfer,
sdk.NewAttribute(types.AttributeKeyRecipient, toAddr.String()),
sdk.NewAttribute(types.AttributeKeySender, fromAddrString),
sdk.NewAttribute(sdk.AttributeKeyAmount, amt.String()),
),
sdk.NewEvent(
))
ctx.EventManager().EmitEvent(sdk.NewEvent(
sdk.EventTypeMessage,
sdk.NewAttribute(types.AttributeKeySender, fromAddr.String()),
),
})
))

}

return nil
}
Expand Down
Loading

0 comments on commit b1b5c4a

Please sign in to comment.