Skip to content

Commit

Permalink
fix: avoid race conditions in ics27 handshakes (#2682) (#2698)
Browse files Browse the repository at this point in the history
* wip adding conditional to msg server and go apis, adding tests

* cleanup

* cleanup middleware enabled code

* adding additional test case for reopening channel via msg server

* Update modules/apps/27-interchain-accounts/controller/keeper/keeper.go

Co-authored-by: Cian Hatton <[email protected]>

* updating error msgs and test case assertion

* updating InitGenesis to set middleware disabled

Co-authored-by: Cian Hatton <[email protected]>
(cherry picked from commit c9b8064)

Co-authored-by: Damian Nolan <[email protected]>
Co-authored-by: colin axnér <[email protected]>
  • Loading branch information
3 people authored Nov 8, 2022
1 parent 2e0e66d commit 2eb5217
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/suite"

"github.com/cosmos/ibc-go/v6/modules/apps/27-interchain-accounts/controller"
controllerkeeper "github.com/cosmos/ibc-go/v6/modules/apps/27-interchain-accounts/controller/keeper"
"github.com/cosmos/ibc-go/v6/modules/apps/27-interchain-accounts/controller/types"
icatypes "github.com/cosmos/ibc-go/v6/modules/apps/27-interchain-accounts/types"
fee "github.com/cosmos/ibc-go/v6/modules/apps/29-fee"
Expand Down Expand Up @@ -840,3 +841,80 @@ func (suite *InterchainAccountsTestSuite) TestGetAppVersion() {
suite.Require().True(found)
suite.Require().Equal(path.EndpointA.ChannelConfig.Version, appVersion)
}

func (suite *InterchainAccountsTestSuite) TestInFlightHandshakeRespectsGoAPICaller() {
path := NewICAPath(suite.chainA, suite.chainB)
suite.coordinator.SetupConnections(path)

// initiate a channel handshake such that channel.State == INIT
err := RegisterInterchainAccount(path.EndpointA, suite.chainA.SenderAccount.GetAddress().String())
suite.Require().NoError(err)

// attempt to start a second handshake via the controller msg server
msgServer := controllerkeeper.NewMsgServerImpl(&suite.chainA.GetSimApp().ICAControllerKeeper)
msgRegisterInterchainAccount := types.NewMsgRegisterInterchainAccount(path.EndpointA.ConnectionID, suite.chainA.SenderAccount.GetAddress().String(), TestVersion)

res, err := msgServer.RegisterInterchainAccount(suite.chainA.GetContext(), msgRegisterInterchainAccount)
suite.Require().Error(err)
suite.Require().Nil(res)
}

func (suite *InterchainAccountsTestSuite) TestInFlightHandshakeRespectsMsgServerCaller() {
path := NewICAPath(suite.chainA, suite.chainB)
suite.coordinator.SetupConnections(path)

// initiate a channel handshake such that channel.State == INIT
msgServer := controllerkeeper.NewMsgServerImpl(&suite.chainA.GetSimApp().ICAControllerKeeper)
msgRegisterInterchainAccount := types.NewMsgRegisterInterchainAccount(path.EndpointA.ConnectionID, suite.chainA.SenderAccount.GetAddress().String(), TestVersion)

res, err := msgServer.RegisterInterchainAccount(suite.chainA.GetContext(), msgRegisterInterchainAccount)
suite.Require().NotNil(res)
suite.Require().NoError(err)

// attempt to start a second handshake via the legacy Go API
err = RegisterInterchainAccount(path.EndpointA, suite.chainA.SenderAccount.GetAddress().String())
suite.Require().Error(err)
}

func (suite *InterchainAccountsTestSuite) TestClosedChannelReopensWithMsgServer() {
path := NewICAPath(suite.chainA, suite.chainB)
suite.coordinator.SetupConnections(path)

err := SetupICAPath(path, suite.chainA.SenderAccount.GetAddress().String())
suite.Require().NoError(err)

// set the channel state to closed
err = path.EndpointA.SetChannelClosed()
suite.Require().NoError(err)
err = path.EndpointB.SetChannelClosed()
suite.Require().NoError(err)

// reset endpoint channel ids
path.EndpointA.ChannelID = ""
path.EndpointB.ChannelID = ""

// fetch the next channel sequence before reinitiating the channel handshake
channelSeq := suite.chainA.GetSimApp().GetIBCKeeper().ChannelKeeper.GetNextChannelSequence(suite.chainA.GetContext())

// route a new MsgRegisterInterchainAccount in order to reopen the
msgServer := controllerkeeper.NewMsgServerImpl(&suite.chainA.GetSimApp().ICAControllerKeeper)
msgRegisterInterchainAccount := types.NewMsgRegisterInterchainAccount(path.EndpointA.ConnectionID, suite.chainA.SenderAccount.GetAddress().String(), path.EndpointA.ChannelConfig.Version)

res, err := msgServer.RegisterInterchainAccount(suite.chainA.GetContext(), msgRegisterInterchainAccount)
suite.Require().NoError(err)
suite.Require().Equal(channeltypes.FormatChannelIdentifier(channelSeq), res.ChannelId)

// assign the channel sequence to endpointA before generating proofs and initiating the TRY step
path.EndpointA.ChannelID = channeltypes.FormatChannelIdentifier(channelSeq)

path.EndpointA.Chain.NextBlock()

err = path.EndpointB.ChanOpenTry()
suite.Require().NoError(err)

err = path.EndpointA.ChanOpenAck()
suite.Require().NoError(err)

err = path.EndpointB.ChanOpenConfirm()
suite.Require().NoError(err)
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ func (k Keeper) RegisterInterchainAccount(ctx sdk.Context, connectionID, owner,
return err
}

if k.IsMiddlewareDisabled(ctx, portID, connectionID) && !k.IsActiveChannelClosed(ctx, connectionID, portID) {
return sdkerrors.Wrap(icatypes.ErrInvalidChannelFlow, "channel is already active or a handshake is in flight")
}

k.SetMiddlewareEnabled(ctx, portID, connectionID)

_, err = k.registerInterchainAccount(ctx, connectionID, portID, version)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ func InitGenesis(ctx sdk.Context, keeper Keeper, state genesistypes.ControllerGe

if ch.IsMiddlewareEnabled {
keeper.SetMiddlewareEnabled(ctx, ch.PortId, ch.ConnectionId)
} else {
keeper.SetMiddlewareDisabled(ctx, ch.PortId, ch.ConnectionId)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ func (suite *KeeperTestSuite) TestInitGenesis() {
ChannelId: ibctesting.FirstChannelID,
IsMiddlewareEnabled: true,
},
{
ConnectionId: "connection-1",
PortId: "test-port-1",
ChannelId: "channel-1",
IsMiddlewareEnabled: false,
},
},
InterchainAccounts: []genesistypes.RegisteredInterchainAccount{
{
Expand All @@ -40,6 +46,9 @@ func (suite *KeeperTestSuite) TestInitGenesis() {
isMiddlewareEnabled := suite.chainA.GetSimApp().ICAControllerKeeper.IsMiddlewareEnabled(suite.chainA.GetContext(), TestPortID, ibctesting.FirstConnectionID)
suite.Require().True(isMiddlewareEnabled)

isMiddlewareDisabled := suite.chainA.GetSimApp().ICAControllerKeeper.IsMiddlewareDisabled(suite.chainA.GetContext(), "test-port-1", "connection-1")
suite.Require().True(isMiddlewareDisabled)

accountAdrr, found := suite.chainA.GetSimApp().ICAControllerKeeper.GetInterchainAccountAddress(suite.chainA.GetContext(), ibctesting.FirstConnectionID, TestPortID)
suite.Require().True(found)
suite.Require().Equal(interchainAccAddr.String(), accountAdrr)
Expand Down
28 changes: 26 additions & 2 deletions modules/apps/27-interchain-accounts/controller/keeper/keeper.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package keeper

import (
"bytes"
"fmt"
"strings"

Expand Down Expand Up @@ -146,6 +147,17 @@ func (k Keeper) GetOpenActiveChannel(ctx sdk.Context, connectionID, portID strin
return "", false
}

// IsActiveChannelClosed retrieves the active channel from the store and returns true if the channel state is CLOSED, otherwise false
func (k Keeper) IsActiveChannelClosed(ctx sdk.Context, connectionID, portID string) bool {
channelID, found := k.GetActiveChannelID(ctx, connectionID, portID)
if !found {
return false
}

channel, found := k.channelKeeper.GetChannel(ctx, portID, channelID)
return found && channel.State == channeltypes.CLOSED
}

// GetAllActiveChannels returns a list of all active interchain accounts controller channels and their associated connection and port identifiers
func (k Keeper) GetAllActiveChannels(ctx sdk.Context) []genesistypes.ActiveChannel {
store := ctx.KVStore(k.storeKey)
Expand Down Expand Up @@ -227,13 +239,25 @@ func (k Keeper) SetInterchainAccountAddress(ctx sdk.Context, connectionID, portI
// IsMiddlewareEnabled returns true if the underlying application callbacks are enabled for given port and connection identifier pair, otherwise false
func (k Keeper) IsMiddlewareEnabled(ctx sdk.Context, portID, connectionID string) bool {
store := ctx.KVStore(k.storeKey)
return store.Has(icatypes.KeyIsMiddlewareEnabled(portID, connectionID))
return bytes.Equal(icatypes.MiddlewareEnabled, store.Get(icatypes.KeyIsMiddlewareEnabled(portID, connectionID)))
}

// IsMiddlewareDisabled returns true if the underlying application callbacks are disabled for the given port and connection identifier pair, otherwise false
func (k Keeper) IsMiddlewareDisabled(ctx sdk.Context, portID, connectionID string) bool {
store := ctx.KVStore(k.storeKey)
return bytes.Equal(icatypes.MiddlewareDisabled, store.Get(icatypes.KeyIsMiddlewareEnabled(portID, connectionID)))
}

// SetMiddlewareEnabled stores a flag to indicate that the underlying application callbacks should be enabled for the given port and connection identifier pair
func (k Keeper) SetMiddlewareEnabled(ctx sdk.Context, portID, connectionID string) {
store := ctx.KVStore(k.storeKey)
store.Set(icatypes.KeyIsMiddlewareEnabled(portID, connectionID), []byte{byte(1)})
store.Set(icatypes.KeyIsMiddlewareEnabled(portID, connectionID), icatypes.MiddlewareEnabled)
}

// SetMiddlewareDisabled stores a flag to indicate that the underlying application callbacks should be disabled for the given port and connection identifier pair
func (k Keeper) SetMiddlewareDisabled(ctx sdk.Context, portID, connectionID string) {
store := ctx.KVStore(k.storeKey)
store.Set(icatypes.KeyIsMiddlewareEnabled(portID, connectionID), icatypes.MiddlewareDisabled)
}

// DeleteMiddlewareEnabled deletes the middleware enabled flag stored in state
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"

sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"

"github.com/cosmos/ibc-go/v6/modules/apps/27-interchain-accounts/controller/types"
icatypes "github.com/cosmos/ibc-go/v6/modules/apps/27-interchain-accounts/types"
Expand All @@ -30,6 +31,12 @@ func (s msgServer) RegisterInterchainAccount(goCtx context.Context, msg *types.M
return nil, err
}

if s.IsMiddlewareEnabled(ctx, portID, msg.ConnectionId) && !s.IsActiveChannelClosed(ctx, msg.ConnectionId, portID) {
return nil, sdkerrors.Wrap(icatypes.ErrInvalidChannelFlow, "channel is already active or a handshake is in flight")
}

s.SetMiddlewareDisabled(ctx, portID, msg.ConnectionId)

channelID, err := s.registerInterchainAccount(ctx, msg.ConnectionId, portID, msg.Version)
if err != nil {
return nil, err
Expand Down
6 changes: 6 additions & 0 deletions modules/apps/27-interchain-accounts/types/keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ var (

// IsMiddlewareEnabledPrefix defines the key prefix used to store a flag for legacy API callback routing via ibc middleware
IsMiddlewareEnabledPrefix = "isMiddlewareEnabled"

// MiddlewareEnabled is the value used to signal that controller middleware is enabled
MiddlewareEnabled = []byte{0x01}

// MiddlewareDisabled is the value used to signal that controller midleware is disabled
MiddlewareDisabled = []byte{0x02}
)

// KeyActiveChannel creates and returns a new key used for active channels store operations
Expand Down

0 comments on commit 2eb5217

Please sign in to comment.