Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add extension point for custom ibc port name #1710

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions x/wasm/alias.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,6 @@ var (
// Deprecated: Do not use.
NewQuerier = keeper.Querier
// Deprecated: Do not use.
ContractFromPortID = keeper.ContractFromPortID
// Deprecated: Do not use.
WithWasmEngine = keeper.WithWasmEngine
// Deprecated: Do not use.
NewCountTXDecorator = keeper.NewCountTXDecorator
Expand Down
19 changes: 9 additions & 10 deletions x/wasm/ibc.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (

sdk "github.com/cosmos/cosmos-sdk/types"

"github.com/CosmWasm/wasmd/x/wasm/keeper"
"github.com/CosmWasm/wasmd/x/wasm/types"
)

Expand Down Expand Up @@ -51,7 +50,7 @@ func (i IBCHandler) OnChanOpenInit(
if err := ValidateChannelParams(channelID); err != nil {
return "", err
}
contractAddr, err := keeper.ContractFromPortID(portID)
contractAddr, err := i.keeper.ContractFromPortID(ctx, portID)
if err != nil {
return "", errorsmod.Wrapf(err, "contract port id")
}
Expand Down Expand Up @@ -103,7 +102,7 @@ func (i IBCHandler) OnChanOpenTry(
return "", err
}

contractAddr, err := keeper.ContractFromPortID(portID)
contractAddr, err := i.keeper.ContractFromPortID(ctx, portID)
if err != nil {
return "", errorsmod.Wrapf(err, "contract port id")
}
Expand Down Expand Up @@ -151,7 +150,7 @@ func (i IBCHandler) OnChanOpenAck(
counterpartyChannelID string,
counterpartyVersion string,
) error {
contractAddr, err := keeper.ContractFromPortID(portID)
contractAddr, err := i.keeper.ContractFromPortID(ctx, portID)
if err != nil {
return errorsmod.Wrapf(err, "contract port id")
}
Expand All @@ -177,7 +176,7 @@ func (i IBCHandler) OnChanOpenAck(

// OnChanOpenConfirm implements the IBCModule interface
func (i IBCHandler) OnChanOpenConfirm(ctx sdk.Context, portID, channelID string) error {
contractAddr, err := keeper.ContractFromPortID(portID)
contractAddr, err := i.keeper.ContractFromPortID(ctx, portID)
if err != nil {
return errorsmod.Wrapf(err, "contract port id")
}
Expand All @@ -199,7 +198,7 @@ func (i IBCHandler) OnChanOpenConfirm(ctx sdk.Context, portID, channelID string)

// OnChanCloseInit implements the IBCModule interface
func (i IBCHandler) OnChanCloseInit(ctx sdk.Context, portID, channelID string) error {
contractAddr, err := keeper.ContractFromPortID(portID)
contractAddr, err := i.keeper.ContractFromPortID(ctx, portID)
if err != nil {
return errorsmod.Wrapf(err, "contract port id")
}
Expand Down Expand Up @@ -227,7 +226,7 @@ func (i IBCHandler) OnChanCloseInit(ctx sdk.Context, portID, channelID string) e
// OnChanCloseConfirm implements the IBCModule interface
func (i IBCHandler) OnChanCloseConfirm(ctx sdk.Context, portID, channelID string) error {
// counterparty has closed the channel
contractAddr, err := keeper.ContractFromPortID(portID)
contractAddr, err := i.keeper.ContractFromPortID(ctx, portID)
if err != nil {
return errorsmod.Wrapf(err, "contract port id")
}
Expand Down Expand Up @@ -268,7 +267,7 @@ func (i IBCHandler) OnRecvPacket(
packet channeltypes.Packet,
relayer sdk.AccAddress,
) ibcexported.Acknowledgement {
contractAddr, err := keeper.ContractFromPortID(packet.DestinationPort)
contractAddr, err := i.keeper.ContractFromPortID(ctx, packet.DestinationPort)
if err != nil {
// this must not happen as ports were registered before
panic(errorsmod.Wrapf(err, "contract port id"))
Expand Down Expand Up @@ -296,7 +295,7 @@ func (i IBCHandler) OnAcknowledgementPacket(
acknowledgement []byte,
relayer sdk.AccAddress,
) error {
contractAddr, err := keeper.ContractFromPortID(packet.SourcePort)
contractAddr, err := i.keeper.ContractFromPortID(ctx, packet.SourcePort)
if err != nil {
return errorsmod.Wrapf(err, "contract port id")
}
Expand All @@ -314,7 +313,7 @@ func (i IBCHandler) OnAcknowledgementPacket(

// OnTimeoutPacket implements the IBCModule interface
func (i IBCHandler) OnTimeoutPacket(ctx sdk.Context, packet channeltypes.Packet, relayer sdk.AccAddress) error {
contractAddr, err := keeper.ContractFromPortID(packet.SourcePort)
contractAddr, err := i.keeper.ContractFromPortID(ctx, packet.SourcePort)
if err != nil {
return errorsmod.Wrapf(err, "contract port id")
}
Expand Down
12 changes: 11 additions & 1 deletion x/wasm/ibc_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package wasm

import (
"context"
"testing"

wasmvmtypes "github.com/CosmWasm/wasmvm/types"
Expand Down Expand Up @@ -107,6 +108,7 @@ func TestOnRecvPacket(t *testing.T) {
ctx.EventManager().EmitEvent(myCustomEvent)
return spec.contractRsp, spec.contractOkMsgExecErr
},
ContractFromPortIDFn: keeper.DefaultIBCPortNameGenerator{}.ContractFromPortID,
}
h := NewIBCHandler(mock, nil, nil)
em := &sdk.EventManager{}
Expand Down Expand Up @@ -200,7 +202,15 @@ var _ types.IBCContractKeeper = &IBCContractKeeperMock{}

type IBCContractKeeperMock struct {
types.IBCContractKeeper
OnRecvPacketFn func(ctx sdk.Context, contractAddr sdk.AccAddress, msg wasmvmtypes.IBCPacketReceiveMsg) (ibcexported.Acknowledgement, error)
OnRecvPacketFn func(ctx sdk.Context, contractAddr sdk.AccAddress, msg wasmvmtypes.IBCPacketReceiveMsg) (ibcexported.Acknowledgement, error)
ContractFromPortIDFn func(ctx context.Context, portID string) (sdk.AccAddress, error)
}

func (m IBCContractKeeperMock) ContractFromPortID(ctx context.Context, portID string) (sdk.AccAddress, error) {
if m.ContractFromPortIDFn == nil {
panic("not expected to be called")
}
return m.ContractFromPortIDFn(ctx, portID)
}

func (m IBCContractKeeperMock) OnRecvPacket(ctx sdk.Context, contractAddr sdk.AccAddress, msg wasmvmtypes.IBCPacketReceiveMsg) (ibcexported.Acknowledgement, error) {
Expand Down
3 changes: 2 additions & 1 deletion x/wasm/keeper/handler_plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,10 @@ func NewDefaultMessageHandler(
bankKeeper types.Burner,
cdc codec.Codec,
portSource types.ICS20TransferPortSource,
ibcPortAllocator IBCPortNameGenerator,
customEncoders ...*MessageEncoders,
) Messenger {
encoders := DefaultEncoders(cdc, portSource)
encoders := DefaultEncoders(cdc, portSource, ibcPortAllocator)
for _, e := range customEncoders {
encoders = encoders.Merge(e)
}
Expand Down
16 changes: 12 additions & 4 deletions x/wasm/keeper/handler_plugin_encoders.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,17 @@ type MessageEncoders struct {
Gov func(sender sdk.AccAddress, msg *wasmvmtypes.GovMsg) ([]sdk.Msg, error)
}

func DefaultEncoders(unpacker codectypes.AnyUnpacker, portSource types.ICS20TransferPortSource) MessageEncoders {
// DefaultEncoders setup the wasm module default message encoders
func DefaultEncoders(
unpacker codectypes.AnyUnpacker,
portSource types.ICS20TransferPortSource,
ibcPortAllocator IBCPortNameGenerator,
) MessageEncoders {
return MessageEncoders{
Bank: EncodeBankMsg,
Custom: NoCustomMsg,
Distribution: EncodeDistributionMsg,
IBC: EncodeIBCMsg(portSource),
IBC: EncodeIBCMsg(portSource, ibcPortAllocator),
Staking: EncodeStakingMsg,
Stargate: EncodeStargateMsg(unpacker),
Wasm: EncodeWasmMsg,
Expand Down Expand Up @@ -295,12 +300,15 @@ func EncodeWasmMsg(sender sdk.AccAddress, msg *wasmvmtypes.WasmMsg) ([]sdk.Msg,
}
}

func EncodeIBCMsg(portSource types.ICS20TransferPortSource) func(ctx sdk.Context, sender sdk.AccAddress, contractIBCPortID string, msg *wasmvmtypes.IBCMsg) ([]sdk.Msg, error) {
func EncodeIBCMsg(
portSource types.ICS20TransferPortSource,
ibcPortAllocator IBCPortNameGenerator,
) func(ctx sdk.Context, sender sdk.AccAddress, contractIBCPortID string, msg *wasmvmtypes.IBCMsg) ([]sdk.Msg, error) {
return func(ctx sdk.Context, sender sdk.AccAddress, contractIBCPortID string, msg *wasmvmtypes.IBCMsg) ([]sdk.Msg, error) {
switch {
case msg.CloseChannel != nil:
return []sdk.Msg{&channeltypes.MsgChannelCloseInit{
PortId: PortIDForContract(sender),
PortId: ibcPortAllocator.PortIDForContract(ctx, sender),
ChannelId: msg.CloseChannel.ChannelID,
Signer: sender.String(),
}}, nil
Expand Down
6 changes: 3 additions & 3 deletions x/wasm/keeper/handler_plugin_encoders_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
wasmvmtypes "github.com/CosmWasm/wasmvm/types"
"github.com/cosmos/gogoproto/proto"
ibctransfertypes "github.com/cosmos/ibc-go/v8/modules/apps/transfer/types"
clienttypes "github.com/cosmos/ibc-go/v8/modules/core/02-client/types" //nolint:staticcheck
clienttypes "github.com/cosmos/ibc-go/v8/modules/core/02-client/types"
channeltypes "github.com/cosmos/ibc-go/v8/modules/core/04-channel/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -553,7 +553,7 @@ func TestEncoding(t *testing.T) {
for name, tc := range cases {
t.Run(name, func(t *testing.T) {
var ctx sdk.Context
encoder := DefaultEncoders(encodingConfig.Codec, tc.transferPortSource)
encoder := DefaultEncoders(encodingConfig.Codec, tc.transferPortSource, DefaultIBCPortNameGenerator{})
res, err := encoder.Encode(ctx, tc.sender, tc.srcContractIBCPort, tc.srcMsg)
if tc.expError {
assert.Error(t, err)
Expand Down Expand Up @@ -773,7 +773,7 @@ func TestEncodeGovMsg(t *testing.T) {
for name, tc := range cases {
t.Run(name, func(t *testing.T) {
var ctx sdk.Context
encoder := DefaultEncoders(encodingConfig.Codec, tc.transferPortSource)
encoder := DefaultEncoders(encodingConfig.Codec, tc.transferPortSource, DefaultIBCPortNameGenerator{})
res, gotEncErr := encoder.Encode(ctx, tc.sender, "myIBCPort", tc.srcMsg)
if tc.expError {
assert.Error(t, gotEncErr)
Expand Down
40 changes: 35 additions & 5 deletions x/wasm/keeper/ibc.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package keeper

import (
"context"
"encoding/hex"
"strings"

capabilitytypes "github.com/cosmos/ibc-go/modules/capability/types"
Expand All @@ -21,31 +23,59 @@
return k.ClaimCapability(ctx, portCap, host.PortPath(portID))
}

// ensureIbcPort is like registerIbcPort, but it checks if we already hold the port
// ensureIBCPort is like registerIbcPort, but it checks if we already hold the port
// before calling register, so this is safe to call multiple times.
// Returns success if we already registered or just registered and error if we cannot
// (lack of permissions or someone else has it)
func (k Keeper) ensureIbcPort(ctx sdk.Context, contractAddr sdk.AccAddress) (string, error) {
portID := PortIDForContract(contractAddr)
func (k Keeper) ensureIBCPort(ctx sdk.Context, contractAddr sdk.AccAddress) (string, error) {
portID := k.ibcPortNameGenerator.PortIDForContract(ctx, contractAddr)
if _, ok := k.capabilityKeeper.GetCapability(ctx, host.PortPath(portID)); ok {
return portID, nil
}
return portID, k.bindIbcPort(ctx, portID)
}

type IBCPortNameGenerator interface {
// PortIDForContract converts an address into an ibc port-id.
PortIDForContract(ctx context.Context, addr sdk.AccAddress) string
// ContractFromPortID returns the contract address for given port-id. The method does not check if the contract exists
ContractFromPortID(ctx context.Context, portID string) (sdk.AccAddress, error)
}

const portIDPrefix = "wasm."

func PortIDForContract(addr sdk.AccAddress) string {
// DefaultIBCPortNameGenerator uses Bech32 address string in port-id
type DefaultIBCPortNameGenerator struct{}

// PortIDForContract coverts contract into port-id in the format "wasm.<bech32-address>"
func (DefaultIBCPortNameGenerator) PortIDForContract(ctx context.Context, addr sdk.AccAddress) string {
return portIDPrefix + addr.String()
}

func ContractFromPortID(portID string) (sdk.AccAddress, error) {
// ContractFromPortID reads the contract address from bech32 address in the port-id.
func (DefaultIBCPortNameGenerator) ContractFromPortID(ctx context.Context, portID string) (sdk.AccAddress, error) {
if !strings.HasPrefix(portID, portIDPrefix) {
return nil, errorsmod.Wrapf(types.ErrInvalid, "without prefix")
}
return sdk.AccAddressFromBech32(portID[len(portIDPrefix):])
}

// HexIBCPortNameGenerator uses Hex address string
type HexIBCPortNameGenerator struct{}

// PortIDForContract coverts contract into port-id in the format "wasm.<hex-address>"
func (HexIBCPortNameGenerator) PortIDForContract(ctx context.Context, addr sdk.AccAddress) string {
return portIDPrefix + hex.EncodeToString(addr)

Check warning on line 68 in x/wasm/keeper/ibc.go

View check run for this annotation

Codecov / codecov/patch

x/wasm/keeper/ibc.go#L67-L68

Added lines #L67 - L68 were not covered by tests
}

// ContractFromPortID reads the contract address from hex address in the port-id.
func (HexIBCPortNameGenerator) ContractFromPortID(ctx context.Context, portID string) (sdk.AccAddress, error) {
if !strings.HasPrefix(portID, portIDPrefix) {
return nil, errorsmod.Wrapf(types.ErrInvalid, "without prefix")

Check warning on line 74 in x/wasm/keeper/ibc.go

View check run for this annotation

Codecov / codecov/patch

x/wasm/keeper/ibc.go#L72-L74

Added lines #L72 - L74 were not covered by tests
}
return sdk.AccAddressFromHexUnsafe(portID[len(portIDPrefix):])

Check warning on line 76 in x/wasm/keeper/ibc.go

View check run for this annotation

Codecov / codecov/patch

x/wasm/keeper/ibc.go#L76

Added line #L76 was not covered by tests
}

// AuthenticateCapability wraps the scopedKeeper's AuthenticateCapability function
func (k Keeper) AuthenticateCapability(ctx sdk.Context, cap *capabilitytypes.Capability, name string) bool {
return k.capabilityKeeper.AuthenticateCapability(ctx, cap, name)
Expand Down
4 changes: 2 additions & 2 deletions x/wasm/keeper/ibc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func TestBindingPortForIBCContractOnInstantiate(t *testing.T) {
require.NoError(t, err)
require.NotEqual(t, example.Contract, addr)

portID2 := PortIDForContract(addr)
portID2 := keepers.WasmKeeper.ibcPortNameGenerator.PortIDForContract(ctx, addr)
owner, _, err = keepers.IBCKeeper.PortKeeper.LookupModuleByPort(ctx, portID2)
require.NoError(t, err)
require.Equal(t, "wasm", owner)
Expand Down Expand Up @@ -72,7 +72,7 @@ func TestContractFromPortID(t *testing.T) {
}
for name, spec := range specs {
t.Run(name, func(t *testing.T) {
gotAddr, gotErr := ContractFromPortID(spec.srcPort)
gotAddr, gotErr := DefaultIBCPortNameGenerator{}.ContractFromPortID(nil, spec.srcPort)
if spec.expErr {
require.Error(t, gotErr)
return
Expand Down
12 changes: 9 additions & 3 deletions x/wasm/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@

// the address capable of executing a MsgUpdateParams message. Typically, this
// should be the x/gov module account.
authority string
authority string
ibcPortNameGenerator IBCPortNameGenerator
}

func (k Keeper) getUploadAccessConfig(ctx context.Context) types.AccessConfig {
Expand Down Expand Up @@ -334,7 +335,7 @@
}
if report.HasIBCEntryPoints {
// register IBC port
ibcPort, err := k.ensureIbcPort(sdkCtx, contractAddress)
ibcPort, err := k.ensureIBCPort(sdkCtx, contractAddress)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -457,7 +458,7 @@
return nil, errorsmod.Wrap(types.ErrMigrationFailed, "requires ibc callbacks")
case report.HasIBCEntryPoints && contractInfo.IBCPortID == "":
// add ibc port
ibcPort, err := k.ensureIbcPort(sdkCtx, contractAddress)
ibcPort, err := k.ensureIBCPort(sdkCtx, contractAddress)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1194,6 +1195,11 @@
return k.importContractState(ctx, contractAddr, state)
}

// ContractFromPortID returns the contract address for given port-id. The method does not check if the contract exists
func (k Keeper) ContractFromPortID(ctx context.Context, portID string) (sdk.AccAddress, error) {
return k.ibcPortNameGenerator.ContractFromPortID(ctx, portID)

Check warning on line 1200 in x/wasm/keeper/keeper.go

View check run for this annotation

Codecov / codecov/patch

x/wasm/keeper/keeper.go#L1199-L1200

Added lines #L1199 - L1200 were not covered by tests
}

func (k Keeper) newQueryHandler(ctx sdk.Context, contractAddress sdk.AccAddress) QueryHandler {
return NewQueryHandler(ctx, k.wasmVMQueryHandler, contractAddress, k.gasRegister)
}
Expand Down
6 changes: 4 additions & 2 deletions x/wasm/keeper/keeper_cgo.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ func NewKeeper(
accountPruner: NewVestingCoinBurner(bankKeeper),
portKeeper: portKeeper,
capabilityKeeper: capabilityKeeper,
messenger: NewDefaultMessageHandler(router, ics4Wrapper, channelKeeper, capabilityKeeper, bankKeeper, cdc, portSource),
queryGasLimit: wasmConfig.SmartQueryGasLimit,
gasRegister: types.NewDefaultWasmGasRegister(),
maxQueryStackSize: types.DefaultMaxQueryStackSize,
Expand All @@ -56,8 +55,11 @@ func NewKeeper(
propagateGovAuthorization: map[types.AuthorizationPolicyAction]struct{}{
types.AuthZActionInstantiate: {},
},
authority: authority,
authority: authority,
ibcPortNameGenerator: DefaultIBCPortNameGenerator{},
}
keeper.messenger = NewDefaultMessageHandler(router, ics4Wrapper, channelKeeper, capabilityKeeper, bankKeeper, cdc, portSource, keeper.ibcPortNameGenerator)

keeper.wasmVMQueryHandler = DefaultQueryPlugins(bankKeeper, stakingKeeper, distrKeeper, channelKeeper, keeper)
preOpts, postOpts := splitOpts(opts)
for _, o := range preOpts {
Expand Down
2 changes: 1 addition & 1 deletion x/wasm/keeper/keeper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,7 @@ func TestInstantiateWithContractFactoryChildQueriesParent(t *testing.T) {
router := baseapp.NewMsgServiceRouter()
router.SetInterfaceRegistry(keepers.EncodingConfig.InterfaceRegistry)
types.RegisterMsgServer(router, NewMsgServerImpl(keeper))
keeper.messenger = NewDefaultMessageHandler(router, nil, nil, nil, nil, keepers.EncodingConfig.Codec, nil)
keeper.messenger = NewDefaultMessageHandler(router, nil, nil, nil, nil, keepers.EncodingConfig.Codec, nil, keeper.ibcPortNameGenerator)
// overwrite wasmvm in response handler
keeper.wasmVMResponseHandler = NewDefaultWasmVMContractResponseHandler(NewMessageDispatcher(keeper.messenger, keeper))

Expand Down
8 changes: 8 additions & 0 deletions x/wasm/keeper/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,14 @@ func WithMaxQueryStackSize(m uint32) Option {
})
}

// WithCustomIBCPortNameGenerator overwrites the default ibc port name generator for wasm contracts with
// the custom one
func WithCustomIBCPortNameGenerator(c IBCPortNameGenerator) Option {
return optsFn(func(k *Keeper) {
k.ibcPortNameGenerator = c
})
}

// WithAcceptedAccountTypesOnContractInstantiation sets the accepted account types. Account types of this list won't be overwritten or cause a failure
// when they exist for an address on contract instantiation.
//
Expand Down
Loading