Skip to content

Commit

Permalink
feat: Add ParseKeyFeeEnabled and rename FeeEnabledKey -> KeyFeeEnabled (
Browse files Browse the repository at this point in the history
#1023)

* chore: add ParseKeyFeesInEscrow helper function

* feat: add ParseKeyFeeEnabled function and rename FeeEnabledKey to KeyFeeEnabled
  • Loading branch information
colin-axner authored Mar 1, 2022
1 parent e51e2c9 commit 5f8fc9f
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 11 deletions.
16 changes: 9 additions & 7 deletions modules/apps/29-fee/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,20 +81,20 @@ func (k Keeper) GetFeeModuleAddress() sdk.AccAddress {
// identified by channel and port identifiers.
func (k Keeper) SetFeeEnabled(ctx sdk.Context, portID, channelID string) {
store := ctx.KVStore(k.storeKey)
store.Set(types.FeeEnabledKey(portID, channelID), []byte{1})
store.Set(types.KeyFeeEnabled(portID, channelID), []byte{1})
}

// DeleteFeeEnabled deletes the fee enabled flag for a given portID and channelID
func (k Keeper) DeleteFeeEnabled(ctx sdk.Context, portID, channelID string) {
store := ctx.KVStore(k.storeKey)
store.Delete(types.FeeEnabledKey(portID, channelID))
store.Delete(types.KeyFeeEnabled(portID, channelID))
}

// IsFeeEnabled returns whether fee handling logic should be run for the given port. It will check the
// fee enabled flag for the given port and channel identifiers
func (k Keeper) IsFeeEnabled(ctx sdk.Context, portID, channelID string) bool {
store := ctx.KVStore(k.storeKey)
return store.Get(types.FeeEnabledKey(portID, channelID)) != nil
return store.Get(types.KeyFeeEnabled(portID, channelID)) != nil
}

// GetAllFeeEnabledChannels returns a list of all ics29 enabled channels containing portID & channelID that are stored in state
Expand All @@ -105,11 +105,13 @@ func (k Keeper) GetAllFeeEnabledChannels(ctx sdk.Context) []types.FeeEnabledChan

var enabledChArr []types.FeeEnabledChannel
for ; iterator.Valid(); iterator.Next() {
keySplit := strings.Split(string(iterator.Key()), "/")

portID, channelID, err := types.ParseKeyFeeEnabled(string(iterator.Key()))
if err != nil {
panic(err)
}
ch := types.FeeEnabledChannel{
PortId: keySplit[1],
ChannelId: keySplit[2],
PortId: portID,
ChannelId: channelID,
}

enabledChArr = append(enabledChArr, ch)
Expand Down
24 changes: 22 additions & 2 deletions modules/apps/29-fee/types/keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,32 @@ const (
AttributeKeyTimeoutFee = "timeout_fee"
)

// FeeEnabledKey returns the key that stores a flag to determine if fee logic should
// KeyFeeEnabled returns the key that stores a flag to determine if fee logic should
// be enabled for the given port and channel identifiers.
func FeeEnabledKey(portID, channelID string) []byte {
func KeyFeeEnabled(portID, channelID string) []byte {
return []byte(fmt.Sprintf("%s/%s/%s", FeeEnabledKeyPrefix, portID, channelID))
}

// ParseKeyFeeEnabled parses the key used to indicate if the fee logic should be
// enabled for the given port and channel identifiers.
func ParseKeyFeeEnabled(key string) (portID, channelID string, err error) {
keySplit := strings.Split(key, "/")
if len(keySplit) != 3 {
return "", "", sdkerrors.Wrapf(
sdkerrors.ErrLogic, "key provided is incorrect: the key split has incorrect length, expected %d, got %d", 3, len(keySplit),
)
}

if keySplit[0] != FeeEnabledKeyPrefix {
return "", "", sdkerrors.Wrapf(sdkerrors.ErrLogic, "key prefix is incorrect: expected %s, got %s", FeeEnabledKeyPrefix, keySplit[0])
}

portID = keySplit[1]
channelID = keySplit[2]

return portID, channelID, nil
}

// KeyCounterpartyRelayer returns the key for relayer address -> counteryparty address mapping
func KeyCounterpartyRelayer(address, channelID string) []byte {
return []byte(fmt.Sprintf("%s/%s/%s", CounterpartyRelayerAddressKeyPrefix, address, channelID))
Expand Down
45 changes: 43 additions & 2 deletions modules/apps/29-fee/types/keys_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ import (
ibctesting "github.com/cosmos/ibc-go/v3/testing"
)

var (
validPacketID = channeltypes.NewPacketId(ibctesting.FirstChannelID, ibctesting.MockFeePort, 1)
)

func TestKeyCounterpartyRelayer(t *testing.T) {
var (
relayerAddress = "relayer_address"
Expand All @@ -21,8 +25,45 @@ func TestKeyCounterpartyRelayer(t *testing.T) {
require.Equal(t, string(key), fmt.Sprintf("%s/%s/%s", types.CounterpartyRelayerAddressKeyPrefix, relayerAddress, channelID))
}

func TestParseKeyFeeEnabled(t *testing.T) {
testCases := []struct {
name string
key string
expPass bool
}{
{
"success",
string(types.KeyFeeEnabled(ibctesting.MockPort, ibctesting.FirstChannelID)),
true,
},
{
"incorrect key - key split has incorrect length",
string(types.KeyFeesInEscrow(validPacketID)),
false,
},
{
"incorrect key - key split has incorrect length",
fmt.Sprintf("%s/%s/%s", "fee", ibctesting.MockPort, ibctesting.FirstChannelID),
false,
},
}

for _, tc := range testCases {
portID, channelID, err := types.ParseKeyFeeEnabled(tc.key)

if tc.expPass {
require.NoError(t, err)
require.Equal(t, ibctesting.MockPort, portID)
require.Equal(t, ibctesting.FirstChannelID, channelID)
} else {
require.Error(t, err)
require.Empty(t, portID)
require.Empty(t, channelID)
}
}
}

func TestParseKeyFeesInEscrow(t *testing.T) {
validPacketID := channeltypes.NewPacketId(ibctesting.FirstChannelID, ibctesting.MockFeePort, 1)

testCases := []struct {
name string
Expand All @@ -36,7 +77,7 @@ func TestParseKeyFeesInEscrow(t *testing.T) {
},
{
"incorrect key - key split has incorrect length",
string(types.FeeEnabledKey(validPacketID.PortId, validPacketID.ChannelId)),
string(types.KeyFeeEnabled(validPacketID.PortId, validPacketID.ChannelId)),
false,
},
{
Expand Down

0 comments on commit 5f8fc9f

Please sign in to comment.