Skip to content

Commit

Permalink
Add extension point for custom ibc port name
Browse files Browse the repository at this point in the history
  • Loading branch information
alpe committed Nov 10, 2023
1 parent 7c8f1e8 commit 0b09ae1
Show file tree
Hide file tree
Showing 15 changed files with 83 additions and 35 deletions.
2 changes: 0 additions & 2 deletions x/wasm/alias.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,6 @@ var (
// Deprecated: Do not use.
NewQuerier = keeper.Querier
// Deprecated: Do not use.
ContractFromPortID = keeper.ContractFromPortID
// Deprecated: Do not use.
WithWasmEngine = keeper.WithWasmEngine
// Deprecated: Do not use.
NewCountTXDecorator = keeper.NewCountTXDecorator
Expand Down
19 changes: 9 additions & 10 deletions x/wasm/ibc.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (

sdk "github.com/cosmos/cosmos-sdk/types"

"github.com/CosmWasm/wasmd/x/wasm/keeper"
"github.com/CosmWasm/wasmd/x/wasm/types"
)

Expand Down Expand Up @@ -51,7 +50,7 @@ func (i IBCHandler) OnChanOpenInit(
if err := ValidateChannelParams(channelID); err != nil {
return "", err
}
contractAddr, err := keeper.ContractFromPortID(portID)
contractAddr, err := i.keeper.ContractFromPortID(portID)
if err != nil {
return "", errorsmod.Wrapf(err, "contract port id")
}
Expand Down Expand Up @@ -103,7 +102,7 @@ func (i IBCHandler) OnChanOpenTry(
return "", err
}

contractAddr, err := keeper.ContractFromPortID(portID)
contractAddr, err := i.keeper.ContractFromPortID(portID)
if err != nil {
return "", errorsmod.Wrapf(err, "contract port id")
}
Expand Down Expand Up @@ -151,7 +150,7 @@ func (i IBCHandler) OnChanOpenAck(
counterpartyChannelID string,
counterpartyVersion string,
) error {
contractAddr, err := keeper.ContractFromPortID(portID)
contractAddr, err := i.keeper.ContractFromPortID(portID)
if err != nil {
return errorsmod.Wrapf(err, "contract port id")
}
Expand All @@ -177,7 +176,7 @@ func (i IBCHandler) OnChanOpenAck(

// OnChanOpenConfirm implements the IBCModule interface
func (i IBCHandler) OnChanOpenConfirm(ctx sdk.Context, portID, channelID string) error {
contractAddr, err := keeper.ContractFromPortID(portID)
contractAddr, err := i.keeper.ContractFromPortID(portID)
if err != nil {
return errorsmod.Wrapf(err, "contract port id")
}
Expand All @@ -199,7 +198,7 @@ func (i IBCHandler) OnChanOpenConfirm(ctx sdk.Context, portID, channelID string)

// OnChanCloseInit implements the IBCModule interface
func (i IBCHandler) OnChanCloseInit(ctx sdk.Context, portID, channelID string) error {
contractAddr, err := keeper.ContractFromPortID(portID)
contractAddr, err := i.keeper.ContractFromPortID(portID)
if err != nil {
return errorsmod.Wrapf(err, "contract port id")
}
Expand Down Expand Up @@ -227,7 +226,7 @@ func (i IBCHandler) OnChanCloseInit(ctx sdk.Context, portID, channelID string) e
// OnChanCloseConfirm implements the IBCModule interface
func (i IBCHandler) OnChanCloseConfirm(ctx sdk.Context, portID, channelID string) error {
// counterparty has closed the channel
contractAddr, err := keeper.ContractFromPortID(portID)
contractAddr, err := i.keeper.ContractFromPortID(portID)
if err != nil {
return errorsmod.Wrapf(err, "contract port id")
}
Expand Down Expand Up @@ -268,7 +267,7 @@ func (i IBCHandler) OnRecvPacket(
packet channeltypes.Packet,
relayer sdk.AccAddress,
) ibcexported.Acknowledgement {
contractAddr, err := keeper.ContractFromPortID(packet.DestinationPort)
contractAddr, err := i.keeper.ContractFromPortID(packet.DestinationPort)
if err != nil {
// this must not happen as ports were registered before
panic(errorsmod.Wrapf(err, "contract port id"))
Expand Down Expand Up @@ -296,7 +295,7 @@ func (i IBCHandler) OnAcknowledgementPacket(
acknowledgement []byte,
relayer sdk.AccAddress,
) error {
contractAddr, err := keeper.ContractFromPortID(packet.SourcePort)
contractAddr, err := i.keeper.ContractFromPortID(packet.SourcePort)
if err != nil {
return errorsmod.Wrapf(err, "contract port id")
}
Expand All @@ -314,7 +313,7 @@ func (i IBCHandler) OnAcknowledgementPacket(

// OnTimeoutPacket implements the IBCModule interface
func (i IBCHandler) OnTimeoutPacket(ctx sdk.Context, packet channeltypes.Packet, relayer sdk.AccAddress) error {
contractAddr, err := keeper.ContractFromPortID(packet.SourcePort)
contractAddr, err := i.keeper.ContractFromPortID(packet.SourcePort)
if err != nil {
return errorsmod.Wrapf(err, "contract port id")
}
Expand Down
11 changes: 10 additions & 1 deletion x/wasm/ibc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ func TestOnRecvPacket(t *testing.T) {
ctx.EventManager().EmitEvent(myCustomEvent)
return spec.contractRsp, spec.contractOkMsgExecErr
},
ContractFromPortIDFn: keeper.DefaultIBCPortNameGenerator{}.ContractFromPortID,
}
h := NewIBCHandler(mock, nil, nil)
em := &sdk.EventManager{}
Expand Down Expand Up @@ -200,7 +201,15 @@ var _ types.IBCContractKeeper = &IBCContractKeeperMock{}

type IBCContractKeeperMock struct {
types.IBCContractKeeper
OnRecvPacketFn func(ctx sdk.Context, contractAddr sdk.AccAddress, msg wasmvmtypes.IBCPacketReceiveMsg) (ibcexported.Acknowledgement, error)
OnRecvPacketFn func(ctx sdk.Context, contractAddr sdk.AccAddress, msg wasmvmtypes.IBCPacketReceiveMsg) (ibcexported.Acknowledgement, error)
ContractFromPortIDFn func(portID string) (sdk.AccAddress, error)
}

func (m IBCContractKeeperMock) ContractFromPortID(portID string) (sdk.AccAddress, error) {
if m.ContractFromPortIDFn == nil {
panic("not expected to be called")
}
return m.ContractFromPortIDFn(portID)
}

func (m IBCContractKeeperMock) OnRecvPacket(ctx sdk.Context, contractAddr sdk.AccAddress, msg wasmvmtypes.IBCPacketReceiveMsg) (ibcexported.Acknowledgement, error) {
Expand Down
3 changes: 2 additions & 1 deletion x/wasm/keeper/handler_plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,10 @@ func NewDefaultMessageHandler(
bankKeeper types.Burner,
cdc codec.Codec,
portSource types.ICS20TransferPortSource,
ibcPortAllocator IBCPortNameGenerator,
customEncoders ...*MessageEncoders,
) Messenger {
encoders := DefaultEncoders(cdc, portSource)
encoders := DefaultEncoders(cdc, portSource, ibcPortAllocator)
for _, e := range customEncoders {
encoders = encoders.Merge(e)
}
Expand Down
16 changes: 12 additions & 4 deletions x/wasm/keeper/handler_plugin_encoders.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,17 @@ type MessageEncoders struct {
Gov func(sender sdk.AccAddress, msg *wasmvmtypes.GovMsg) ([]sdk.Msg, error)
}

func DefaultEncoders(unpacker codectypes.AnyUnpacker, portSource types.ICS20TransferPortSource) MessageEncoders {
// DefaultEncoders setup the wasm module default message encoders
func DefaultEncoders(
unpacker codectypes.AnyUnpacker,
portSource types.ICS20TransferPortSource,
ibcPortAllocator IBCPortNameGenerator,
) MessageEncoders {
return MessageEncoders{
Bank: EncodeBankMsg,
Custom: NoCustomMsg,
Distribution: EncodeDistributionMsg,
IBC: EncodeIBCMsg(portSource),
IBC: EncodeIBCMsg(portSource, ibcPortAllocator),
Staking: EncodeStakingMsg,
Stargate: EncodeStargateMsg(unpacker),
Wasm: EncodeWasmMsg,
Expand Down Expand Up @@ -295,12 +300,15 @@ func EncodeWasmMsg(sender sdk.AccAddress, msg *wasmvmtypes.WasmMsg) ([]sdk.Msg,
}
}

func EncodeIBCMsg(portSource types.ICS20TransferPortSource) func(ctx sdk.Context, sender sdk.AccAddress, contractIBCPortID string, msg *wasmvmtypes.IBCMsg) ([]sdk.Msg, error) {
func EncodeIBCMsg(
portSource types.ICS20TransferPortSource,
ibcPortAllocator IBCPortNameGenerator,
) func(ctx sdk.Context, sender sdk.AccAddress, contractIBCPortID string, msg *wasmvmtypes.IBCMsg) ([]sdk.Msg, error) {
return func(ctx sdk.Context, sender sdk.AccAddress, contractIBCPortID string, msg *wasmvmtypes.IBCMsg) ([]sdk.Msg, error) {
switch {
case msg.CloseChannel != nil:
return []sdk.Msg{&channeltypes.MsgChannelCloseInit{
PortId: PortIDForContract(sender),
PortId: ibcPortAllocator.PortIDForContract(sender),
ChannelId: msg.CloseChannel.ChannelID,
Signer: sender.String(),
}}, nil
Expand Down
4 changes: 2 additions & 2 deletions x/wasm/keeper/handler_plugin_encoders_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ func TestEncoding(t *testing.T) {
for name, tc := range cases {
t.Run(name, func(t *testing.T) {
var ctx sdk.Context
encoder := DefaultEncoders(encodingConfig.Codec, tc.transferPortSource)
encoder := DefaultEncoders(encodingConfig.Codec, tc.transferPortSource, DefaultIBCPortNameGenerator{})
res, err := encoder.Encode(ctx, tc.sender, tc.srcContractIBCPort, tc.srcMsg)
if tc.expError {
assert.Error(t, err)
Expand Down Expand Up @@ -773,7 +773,7 @@ func TestEncodeGovMsg(t *testing.T) {
for name, tc := range cases {
t.Run(name, func(t *testing.T) {
var ctx sdk.Context
encoder := DefaultEncoders(encodingConfig.Codec, tc.transferPortSource)
encoder := DefaultEncoders(encodingConfig.Codec, tc.transferPortSource, DefaultIBCPortNameGenerator{})
res, gotEncErr := encoder.Encode(ctx, tc.sender, "myIBCPort", tc.srcMsg)
if tc.expError {
assert.Error(t, gotEncErr)
Expand Down
17 changes: 12 additions & 5 deletions x/wasm/keeper/ibc.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,25 +21,32 @@ func (k Keeper) bindIbcPort(ctx sdk.Context, portID string) error {
return k.ClaimCapability(ctx, portCap, host.PortPath(portID))
}

// ensureIbcPort is like registerIbcPort, but it checks if we already hold the port
// ensureIBCPort is like registerIbcPort, but it checks if we already hold the port
// before calling register, so this is safe to call multiple times.
// Returns success if we already registered or just registered and error if we cannot
// (lack of permissions or someone else has it)
func (k Keeper) ensureIbcPort(ctx sdk.Context, contractAddr sdk.AccAddress) (string, error) {
portID := PortIDForContract(contractAddr)
func (k Keeper) ensureIBCPort(ctx sdk.Context, contractAddr sdk.AccAddress) (string, error) {
portID := k.ibcPortNameGenerator.PortIDForContract(contractAddr)
if _, ok := k.capabilityKeeper.GetCapability(ctx, host.PortPath(portID)); ok {
return portID, nil
}
return portID, k.bindIbcPort(ctx, portID)
}

type IBCPortNameGenerator interface {
PortIDForContract(addr sdk.AccAddress) string
ContractFromPortID(portID string) (sdk.AccAddress, error)
}

const portIDPrefix = "wasm."

func PortIDForContract(addr sdk.AccAddress) string {
type DefaultIBCPortNameGenerator struct{}

func (DefaultIBCPortNameGenerator) PortIDForContract(addr sdk.AccAddress) string {
return portIDPrefix + addr.String()
}

func ContractFromPortID(portID string) (sdk.AccAddress, error) {
func (DefaultIBCPortNameGenerator) ContractFromPortID(portID string) (sdk.AccAddress, error) {
if !strings.HasPrefix(portID, portIDPrefix) {
return nil, errorsmod.Wrapf(types.ErrInvalid, "without prefix")
}
Expand Down
4 changes: 2 additions & 2 deletions x/wasm/keeper/ibc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func TestBindingPortForIBCContractOnInstantiate(t *testing.T) {
require.NoError(t, err)
require.NotEqual(t, example.Contract, addr)

portID2 := PortIDForContract(addr)
portID2 := keepers.WasmKeeper.ibcPortNameGenerator.PortIDForContract(addr)
owner, _, err = keepers.IBCKeeper.PortKeeper.LookupModuleByPort(ctx, portID2)
require.NoError(t, err)
require.Equal(t, "wasm", owner)
Expand Down Expand Up @@ -72,7 +72,7 @@ func TestContractFromPortID(t *testing.T) {
}
for name, spec := range specs {
t.Run(name, func(t *testing.T) {
gotAddr, gotErr := ContractFromPortID(spec.srcPort)
gotAddr, gotErr := DefaultIBCPortNameGenerator{}.ContractFromPortID(spec.srcPort)
if spec.expErr {
require.Error(t, gotErr)
return
Expand Down
11 changes: 8 additions & 3 deletions x/wasm/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ type Keeper struct {

// the address capable of executing a MsgUpdateParams message. Typically, this
// should be the x/gov module account.
authority string
authority string
ibcPortNameGenerator IBCPortNameGenerator
}

func (k Keeper) getUploadAccessConfig(ctx context.Context) types.AccessConfig {
Expand Down Expand Up @@ -334,7 +335,7 @@ func (k Keeper) instantiate(
}
if report.HasIBCEntryPoints {
// register IBC port
ibcPort, err := k.ensureIbcPort(sdkCtx, contractAddress)
ibcPort, err := k.ensureIBCPort(sdkCtx, contractAddress)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -457,7 +458,7 @@ func (k Keeper) migrate(
return nil, errorsmod.Wrap(types.ErrMigrationFailed, "requires ibc callbacks")
case report.HasIBCEntryPoints && contractInfo.IBCPortID == "":
// add ibc port
ibcPort, err := k.ensureIbcPort(sdkCtx, contractAddress)
ibcPort, err := k.ensureIBCPort(sdkCtx, contractAddress)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1194,6 +1195,10 @@ func (k Keeper) importContract(ctx context.Context, contractAddr sdk.AccAddress,
return k.importContractState(ctx, contractAddr, state)
}

func (k Keeper) ContractFromPortID(portID string) (sdk.AccAddress, error) {
return k.ibcPortNameGenerator.ContractFromPortID(portID)
}

func (k Keeper) newQueryHandler(ctx sdk.Context, contractAddress sdk.AccAddress) QueryHandler {
return NewQueryHandler(ctx, k.wasmVMQueryHandler, contractAddress, k.gasRegister)
}
Expand Down
6 changes: 4 additions & 2 deletions x/wasm/keeper/keeper_cgo.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ func NewKeeper(
accountPruner: NewVestingCoinBurner(bankKeeper),
portKeeper: portKeeper,
capabilityKeeper: capabilityKeeper,
messenger: NewDefaultMessageHandler(router, ics4Wrapper, channelKeeper, capabilityKeeper, bankKeeper, cdc, portSource),
queryGasLimit: wasmConfig.SmartQueryGasLimit,
gasRegister: types.NewDefaultWasmGasRegister(),
maxQueryStackSize: types.DefaultMaxQueryStackSize,
Expand All @@ -56,8 +55,11 @@ func NewKeeper(
propagateGovAuthorization: map[types.AuthorizationPolicyAction]struct{}{
types.AuthZActionInstantiate: {},
},
authority: authority,
authority: authority,
ibcPortNameGenerator: DefaultIBCPortNameGenerator{},
}
keeper.messenger = NewDefaultMessageHandler(router, ics4Wrapper, channelKeeper, capabilityKeeper, bankKeeper, cdc, portSource, keeper.ibcPortNameGenerator)

keeper.wasmVMQueryHandler = DefaultQueryPlugins(bankKeeper, stakingKeeper, distrKeeper, channelKeeper, keeper)
preOpts, postOpts := splitOpts(opts)
for _, o := range preOpts {
Expand Down
2 changes: 1 addition & 1 deletion x/wasm/keeper/keeper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,7 @@ func TestInstantiateWithContractFactoryChildQueriesParent(t *testing.T) {
router := baseapp.NewMsgServiceRouter()
router.SetInterfaceRegistry(keepers.EncodingConfig.InterfaceRegistry)
types.RegisterMsgServer(router, NewMsgServerImpl(keeper))
keeper.messenger = NewDefaultMessageHandler(router, nil, nil, nil, nil, keepers.EncodingConfig.Codec, nil)
keeper.messenger = NewDefaultMessageHandler(router, nil, nil, nil, nil, keepers.EncodingConfig.Codec, nil, keeper.ibcPortNameGenerator)
// overwrite wasmvm in response handler
keeper.wasmVMResponseHandler = NewDefaultWasmVMContractResponseHandler(NewMessageDispatcher(keeper.messenger, keeper))

Expand Down
8 changes: 8 additions & 0 deletions x/wasm/keeper/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,14 @@ func WithMaxQueryStackSize(m uint32) Option {
})
}

// WithCustomIBCPortNameGenerator overwrites the default ibc port name generator for wasm contracts with
// the custom one
func WithCustomIBCPortNameGenerator(c IBCPortNameGenerator) Option {
return optsFn(func(k *Keeper) {
k.ibcPortNameGenerator = c
})
}

// WithAcceptedAccountTypesOnContractInstantiation sets the accepted account types. Account types of this list won't be overwritten or cause a failure
// when they exist for an address on contract instantiation.
//
Expand Down
8 changes: 8 additions & 0 deletions x/wasm/keeper/options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ func TestConstructorOptions(t *testing.T) {
storeKey := storetypes.NewKVStoreKey(types.StoreKey)
codec := MakeEncodingConfig(t).Codec

otherIBCPortNameGenerator := struct{ IBCPortNameGenerator }{}

specs := map[string]struct {
srcOpt Option
verify func(*testing.T, Keeper)
Expand Down Expand Up @@ -140,6 +142,12 @@ func TestConstructorOptions(t *testing.T) {
assert.Equal(t, exp, k.propagateGovAuthorization)
},
},
"ibc port name": {
srcOpt: WithCustomIBCPortNameGenerator(otherIBCPortNameGenerator),
verify: func(t *testing.T, k Keeper) {
assert.Equal(t, otherIBCPortNameGenerator, k.ibcPortNameGenerator)
},
},
}
for name, spec := range specs {
t.Run(name, func(t *testing.T) {
Expand Down
4 changes: 2 additions & 2 deletions x/wasm/relay_pingpong_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ func TestPinPong(t *testing.T) {
pongContract.contractAddr = pongContractAddr

var (
sourcePortID = wasmkeeper.PortIDForContract(pingContractAddr)
counterpartyPortID = wasmkeeper.PortIDForContract(pongContractAddr)
sourcePortID = wasmkeeper.DefaultIBCPortNameGenerator{}.PortIDForContract(pingContractAddr)
counterpartyPortID = wasmkeeper.DefaultIBCPortNameGenerator{}.PortIDForContract(pongContractAddr)
)

path := wasmibctesting.NewPath(chainA, chainB)
Expand Down
3 changes: 3 additions & 0 deletions x/wasm/types/exported_keepers.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,4 +120,7 @@ type IBCContractKeeper interface {
ClaimCapability(ctx sdk.Context, cap *capabilitytypes.Capability, name string) error
// AuthenticateCapability wraps the scopedKeeper's AuthenticateCapability function
AuthenticateCapability(ctx sdk.Context, cap *capabilitytypes.Capability, name string) bool

// ContractFromPortID resolves contract address from the ibc port-di
ContractFromPortID(portID string) (sdk.AccAddress, error)
}

0 comments on commit 0b09ae1

Please sign in to comment.