Skip to content

Commit

Permalink
chore: fix side-effect from original bank.send
Browse files Browse the repository at this point in the history
  • Loading branch information
jaeseung-bae committed Feb 14, 2024
1 parent 2b363a2 commit 840340f
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 66 deletions.
24 changes: 12 additions & 12 deletions x/bankplus/keeper/inactive.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,30 +20,30 @@ func inactiveAddrKey(addr sdk.AccAddress) []byte {
}

// isStoredInactiveAddr checks if the address is stored or not as blocked address
func (keeper BaseKeeper) isStoredInactiveAddr(ctx context.Context, address sdk.AccAddress) bool {
store := keeper.storeService.OpenKVStore(ctx)
func (k BaseKeeper) isStoredInactiveAddr(ctx context.Context, address sdk.AccAddress) bool {
store := k.storeService.OpenKVStore(ctx)
bz, _ := store.Get(inactiveAddrKey(address))
return bz != nil
}

// addToInactiveAddr adds a blocked address to the store.
func (keeper BaseKeeper) addToInactiveAddr(ctx context.Context, address sdk.AccAddress) {
store := keeper.storeService.OpenKVStore(ctx)
addrString, err := keeper.addrCdc.BytesToString(address)
func (k BaseKeeper) addToInactiveAddr(ctx context.Context, address sdk.AccAddress) {
store := k.storeService.OpenKVStore(ctx)
addrString, err := k.addrCdc.BytesToString(address)
if err != nil {
panic(err)
}

blockedCAddr := types.InactiveAddr{Address: addrString}
bz := keeper.cdc.MustMarshal(&blockedCAddr)
bz := k.cdc.MustMarshal(&blockedCAddr)
if err := store.Set(inactiveAddrKey(address), bz); err != nil {
panic(err)
}
}

// deleteFromInactiveAddr deletes blocked address from store
func (keeper BaseKeeper) deleteFromInactiveAddr(ctx context.Context, address sdk.AccAddress) {
store := keeper.storeService.OpenKVStore(ctx)
func (k BaseKeeper) deleteFromInactiveAddr(ctx context.Context, address sdk.AccAddress) {
store := k.storeService.OpenKVStore(ctx)
err := store.Delete(inactiveAddrKey(address))
if err != nil {
panic(err)
Expand All @@ -53,16 +53,16 @@ func (keeper BaseKeeper) deleteFromInactiveAddr(ctx context.Context, address sdk
// loadAllInactiveAddrs loads all blocked address and set to `inactiveAddr`.
// This function is executed when the app is initiated and save all inactive address in caches
// in order to prevent to query to store in every time to send
func (keeper BaseKeeper) loadAllInactiveAddrs(ctx context.Context) {
store := keeper.storeService.OpenKVStore(ctx)
func (k BaseKeeper) loadAllInactiveAddrs(ctx context.Context) {
store := k.storeService.OpenKVStore(ctx)
adapter := runtime.KVStoreAdapter(store)
iterator := storetypes.KVStorePrefixIterator(adapter, inactiveAddrsKeyPrefix)

defer iterator.Close()
for ; iterator.Valid(); iterator.Next() {
var bAddr types.InactiveAddr
keeper.cdc.MustUnmarshal(iterator.Value(), &bAddr)
k.cdc.MustUnmarshal(iterator.Value(), &bAddr)

keeper.inactiveAddrs[bAddr.Address] = true
k.inactiveAddrs[bAddr.Address] = true
}
}
32 changes: 16 additions & 16 deletions x/bankplus/keeper/inactive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func TestBankPlus(t *testing.T) {
type BankPlusTestSuite struct {
suite.Suite
mockCtrl *gomock.Controller
keeper BaseKeeper
cut BaseKeeper
ctx context.Context
}

Expand All @@ -55,7 +55,7 @@ func (s *BankPlusTestSuite) SetupTest() {
BytesToString(authtypes.NewModuleAddress(govtypes.ModuleName))
s.Require().NoError(err)

s.keeper = NewBaseKeeper(
s.cut = NewBaseKeeper(
codec,
kvStoreService,
mockAccKeeper,
Expand All @@ -70,7 +70,7 @@ func (s *BankPlusTestSuite) TearDownTest() {
}

func (s *BankPlusTestSuite) TestInactiveAddress() {
require.Equal(s.T(), 0, len(s.keeper.inactiveAddrs))
require.Equal(s.T(), 0, len(s.cut.inactiveAddrs))
addr := genAddr()
anotherAddr := genAddr()
s.addAddrOk(addr)
Expand All @@ -82,36 +82,36 @@ func (s *BankPlusTestSuite) TestInactiveAddress() {
}

func (s *BankPlusTestSuite) addAddrOk(addr sdk.AccAddress) {
s.keeper.addToInactiveAddr(s.ctx, addr)
require.True(s.T(), s.keeper.isStoredInactiveAddr(s.ctx, addr))
s.cut.addToInactiveAddr(s.ctx, addr)
require.True(s.T(), s.cut.isStoredInactiveAddr(s.ctx, addr))
}

func (s *BankPlusTestSuite) duplicateAddOk(addr sdk.AccAddress) {
s.keeper.addToInactiveAddr(s.ctx, addr)
require.True(s.T(), s.keeper.isStoredInactiveAddr(s.ctx, addr))
s.cut.addToInactiveAddr(s.ctx, addr)
require.True(s.T(), s.cut.isStoredInactiveAddr(s.ctx, addr))
}

func (s *BankPlusTestSuite) deleteAddrOk(addr sdk.AccAddress) {
s.keeper.deleteFromInactiveAddr(s.ctx, addr)
require.False(s.T(), s.keeper.isStoredInactiveAddr(s.ctx, addr))
s.cut.deleteFromInactiveAddr(s.ctx, addr)
require.False(s.T(), s.cut.isStoredInactiveAddr(s.ctx, addr))
}

func (s *BankPlusTestSuite) falseForUnknownAddr(anotherAddr sdk.AccAddress) {
require.False(s.T(), s.keeper.isStoredInactiveAddr(s.ctx, anotherAddr))
require.False(s.T(), s.cut.isStoredInactiveAddr(s.ctx, anotherAddr))
}

func (s *BankPlusTestSuite) noErrorWhenDeletionOfUnknownAddr(anotherAddr sdk.AccAddress) {
require.NotPanicsf(s.T(), func() {
s.keeper.deleteFromInactiveAddr(s.ctx, anotherAddr)
s.cut.deleteFromInactiveAddr(s.ctx, anotherAddr)
}, "no panic expected")
}

func (s *BankPlusTestSuite) testLoadAllInactiveAddrs(addr, anotherAddr sdk.AccAddress) {
s.keeper.addToInactiveAddr(s.ctx, addr)
s.keeper.addToInactiveAddr(s.ctx, anotherAddr)
require.Equal(s.T(), 0, len(s.keeper.inactiveAddrs))
s.keeper.loadAllInactiveAddrs(s.ctx)
require.Equal(s.T(), 2, len(s.keeper.inactiveAddrs))
s.cut.addToInactiveAddr(s.ctx, addr)
s.cut.addToInactiveAddr(s.ctx, anotherAddr)
require.Equal(s.T(), 0, len(s.cut.inactiveAddrs))
s.cut.loadAllInactiveAddrs(s.ctx)
require.Equal(s.T(), 2, len(s.cut.inactiveAddrs))
}

func genAddr() sdk.AccAddress {
Expand Down
72 changes: 36 additions & 36 deletions x/bankplus/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,123 +54,123 @@ func NewBaseKeeper(
}
}

func (keeper BaseKeeper) InitializeBankPlus(ctx context.Context) {
keeper.loadAllInactiveAddrs(ctx)
func (k BaseKeeper) InitializeBankPlus(ctx context.Context) {
k.loadAllInactiveAddrs(ctx)
}

// SendCoinsFromModuleToAccount transfers coins from a ModuleAccount to an AccAddress.
// It will panic if the module account does not exist.
func (keeper BaseKeeper) SendCoinsFromModuleToAccount(
func (k BaseKeeper) SendCoinsFromModuleToAccount(
ctx context.Context, senderModule string, recipientAddr sdk.AccAddress, amt sdk.Coins,
) error {
senderAddr := keeper.ak.GetModuleAddress(senderModule)
senderAddr := k.ak.GetModuleAddress(senderModule)
if senderAddr.Empty() {
panic(errorsmod.Wrapf(sdkerrors.ErrUnknownAddress, "module account %s does not exist", senderModule))
}

if keeper.BlockedAddr(recipientAddr) {
if k.BlockedAddr(recipientAddr) {
return errorsmod.Wrapf(sdkerrors.ErrUnauthorized, "%s is not allowed to receive funds", recipientAddr)
}

return keeper.SendCoins(ctx, senderAddr, recipientAddr, amt)
return k.SendCoins(ctx, senderAddr, recipientAddr, amt)
}

// SendCoinsFromModuleToModule transfers coins from a ModuleAccount to another.
// It will panic if either module account does not exist.
func (keeper BaseKeeper) SendCoinsFromModuleToModule(
func (k BaseKeeper) SendCoinsFromModuleToModule(
ctx context.Context, senderModule, recipientModule string, amt sdk.Coins,
) error {
senderAddr := keeper.ak.GetModuleAddress(senderModule)
senderAddr := k.ak.GetModuleAddress(senderModule)
if senderAddr.Empty() {
panic(errorsmod.Wrapf(sdkerrors.ErrUnknownAddress, "module account %s does not exist", senderModule))
}

recipientAcc := keeper.ak.GetModuleAccount(ctx, recipientModule)
recipientAcc := k.ak.GetModuleAccount(ctx, recipientModule)
if recipientAcc == nil {
panic(errorsmod.Wrapf(sdkerrors.ErrUnknownAddress, "module account %s does not exist", recipientModule))
}

return keeper.SendCoins(ctx, senderAddr, recipientAcc.GetAddress(), amt)
return k.SendCoins(ctx, senderAddr, recipientAcc.GetAddress(), amt)
}

// SendCoinsFromAccountToModule transfers coins from an AccAddress to a ModuleAccount.
// It will panic if the module account does not exist.
func (keeper BaseKeeper) SendCoinsFromAccountToModule(
func (k BaseKeeper) SendCoinsFromAccountToModule(
ctx context.Context, senderAddr sdk.AccAddress, recipientModule string, amt sdk.Coins,
) error {
recipientAcc := keeper.ak.GetModuleAccount(ctx, recipientModule)
recipientAcc := k.ak.GetModuleAccount(ctx, recipientModule)
if recipientAcc == nil {
panic(errorsmod.Wrapf(sdkerrors.ErrUnknownAddress, "module account %s does not exist", recipientModule))
}

return keeper.SendCoins(ctx, senderAddr, recipientAcc.GetAddress(), amt)
return k.SendCoins(ctx, senderAddr, recipientAcc.GetAddress(), amt)
}

func (keeper BaseKeeper) isInactiveAddr(addr sdk.AccAddress) bool {
addrString, err := keeper.addrCdc.BytesToString(addr)
func (k BaseKeeper) isInactiveAddr(addr sdk.AccAddress) bool {
addrString, err := k.addrCdc.BytesToString(addr)
if err != nil {
panic(err)
}
return keeper.inactiveAddrs[addrString]
return k.inactiveAddrs[addrString]
}

// SendCoins transfers amt coins from a sending account to a receiving account.
// This is wrapped bank the `SendKeeper` interface of `bank` module,
// and checks if `toAddr` is a inactiveAddr managed by the module.
func (keeper BaseKeeper) SendCoins(ctx context.Context, fromAddr, toAddr sdk.AccAddress, amt sdk.Coins) error {
func (k BaseKeeper) SendCoins(ctx context.Context, fromAddr, toAddr sdk.AccAddress, amt sdk.Coins) error {
// if toAddr is smart contract, check the status of contract.
if keeper.isInactiveAddr(toAddr) {
if k.isInactiveAddr(toAddr) {
return errorsmod.Wrapf(sdkerrors.ErrUnauthorized, "%s is not allowed to receive funds", toAddr)
}

return keeper.BaseSendKeeper.SendCoins(ctx, fromAddr, toAddr, amt)
return k.BaseSendKeeper.SendCoins(ctx, fromAddr, toAddr, amt)
}

// AddToInactiveAddr adds the address to `inactiveAddr`.
func (keeper BaseKeeper) AddToInactiveAddr(ctx context.Context, addr sdk.AccAddress) {
addrString, err := keeper.addrCdc.BytesToString(addr)
func (k BaseKeeper) AddToInactiveAddr(ctx context.Context, addr sdk.AccAddress) {
addrString, err := k.addrCdc.BytesToString(addr)
if err != nil {
panic(err)
}
if !keeper.inactiveAddrs[addrString] {
keeper.inactiveAddrs[addrString] = true
if !k.inactiveAddrs[addrString] {
k.inactiveAddrs[addrString] = true

keeper.addToInactiveAddr(ctx, addr)
k.addToInactiveAddr(ctx, addr)
}
}

// DeleteFromInactiveAddr removes the address from `inactiveAddr`.
func (keeper BaseKeeper) DeleteFromInactiveAddr(ctx context.Context, addr sdk.AccAddress) {
addrString, err := keeper.addrCdc.BytesToString(addr)
func (k BaseKeeper) DeleteFromInactiveAddr(ctx context.Context, addr sdk.AccAddress) {
addrString, err := k.addrCdc.BytesToString(addr)
if err != nil {
panic(err)
}
if keeper.inactiveAddrs[addrString] {
delete(keeper.inactiveAddrs, addrString)
if k.inactiveAddrs[addrString] {
delete(k.inactiveAddrs, addrString)

keeper.deleteFromInactiveAddr(ctx, addr)
k.deleteFromInactiveAddr(ctx, addr)
}
}

// IsInactiveAddr returns if the address is added in inactiveAddr.
func (keeper BaseKeeper) IsInactiveAddr(addr sdk.AccAddress) bool {
addrString, err := keeper.addrCdc.BytesToString(addr)
func (k BaseKeeper) IsInactiveAddr(addr sdk.AccAddress) bool {
addrString, err := k.addrCdc.BytesToString(addr)
if err != nil {
panic(err)
}
return keeper.inactiveAddrs[addrString]
return k.inactiveAddrs[addrString]
}

func (keeper BaseKeeper) InputOutputCoins(ctx context.Context, input types.Input, outputs []types.Output) error {
if keeper.deactMultiSend {
func (k BaseKeeper) InputOutputCoins(ctx context.Context, input types.Input, outputs []types.Output) error {
if k.deactMultiSend {
return sdkerrors.ErrNotSupported.Wrap("MultiSend was deactivated")
}

for _, out := range outputs {
if keeper.inactiveAddrs[out.Address] {
if k.inactiveAddrs[out.Address] {
return errorsmod.Wrapf(sdkerrors.ErrUnauthorized, "%s is not allowed to receive funds", out.Address)
}
}

return keeper.BaseSendKeeper.InputOutputCoins(ctx, input, outputs)
return k.BaseSendKeeper.InputOutputCoins(ctx, input, outputs)
}
6 changes: 4 additions & 2 deletions x/bankplus/module/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func (a AppModule) RegisterInvariants(ir sdk.InvariantRegistry) {
func (a AppModule) RegisterServices(cfg module.Configurator) {
banktypes.RegisterMsgServer(cfg.MsgServer(), bankkeeper.NewMsgServerImpl(a.bankKeeper))
banktypes.RegisterQueryServer(cfg.QueryServer(), a.bankKeeper)
m := bankkeeper.NewMigrator(a.bankKeeper.(keeper.BaseKeeper).BaseKeeper, a.legacySubspace)
m := bankkeeper.NewMigrator(a.bankKeeper.(bankkeeper.BaseKeeper), a.legacySubspace)
if err := cfg.RegisterMigration(banktypes.ModuleName, 1, m.Migrate1to2); err != nil {
panic(fmt.Sprintf("failed to migrate x/bank from version 1 to 2: %v", err))
}
Expand Down Expand Up @@ -163,7 +163,9 @@ func ProvideModule(in ModuleInputs) ModuleOutputs {
in.Logger,
)

m := NewAppModule(in.Cdc, bankKeeper, in.AccountKeeper, in.LegacySubspace)
originalBankKeeper := bankkeeper.NewBaseKeeper(in.Cdc, in.StoreService, in.AccountKeeper, blockedAddresses, authorityString, in.Logger)
m := NewAppModule(in.Cdc, originalBankKeeper, in.AccountKeeper, in.LegacySubspace)

return ModuleOutputs{
BankKeeper: bankKeeper,
Module: m,
Expand Down

0 comments on commit 840340f

Please sign in to comment.