Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(client/tx): simulate with correct pk (backport #18472) #18503

Merged
merged 2 commits into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ Ref: https://keepachangelog.com/en/1.0.0/

## [Unreleased]

### Bug Fixes

* (client/tx) [#18472](https://github.com/cosmos/cosmos-sdk/pull/18472) Utilizes the correct Pubkey when simulating a transaction.

## [v0.47.6](https://github.com/cosmos/cosmos-sdk/releases/tag/v0.47.6) - 2023-11-14

### Features
Expand Down
40 changes: 29 additions & 11 deletions client/tx/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/cosmos/cosmos-sdk/client/flags"
codectypes "github.com/cosmos/cosmos-sdk/codec/types"
"github.com/cosmos/cosmos-sdk/crypto/keyring"
"github.com/cosmos/cosmos-sdk/crypto/keys/multisig"
"github.com/cosmos/cosmos-sdk/crypto/keys/secp256k1"
cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types"
sdk "github.com/cosmos/cosmos-sdk/types"
Expand All @@ -34,6 +35,7 @@ type Factory struct {
timeoutHeight uint64
gasAdjustment float64
chainID string
fromName string
offline bool
generateOnly bool
memo string
Expand Down Expand Up @@ -86,6 +88,7 @@ func NewFactoryCLI(clientCtx client.Context, flagSet *pflag.FlagSet) (Factory, e
accountRetriever: clientCtx.AccountRetriever,
keybase: clientCtx.Keyring,
chainID: clientCtx.ChainID,
fromName: clientCtx.FromName,
offline: clientCtx.Offline,
generateOnly: clientCtx.GenerateOnly,
gas: gasSetting.Gas,
Expand Down Expand Up @@ -414,10 +417,8 @@ func (f Factory) BuildSimTx(msgs ...sdk.Msg) ([]byte, error) {
// Create an empty signature literal as the ante handler will populate with a
// sentinel pubkey.
sig := signing.SignatureV2{
PubKey: pk,
Data: &signing.SingleSignatureData{
SignMode: f.signMode,
},
PubKey: pk,
Data: f.getSimSignatureData(pk),
Sequence: f.Sequence(),
}
if err := txb.SetSignatures(sig); err != nil {
Expand All @@ -438,16 +439,13 @@ func (f Factory) getSimPK() (cryptotypes.PubKey, error) {
pk cryptotypes.PubKey = &secp256k1.PubKey{} // use default public key type
)

// Use the first element from the list of keys in order to generate a valid
// pubkey that supports multiple algorithms.
if f.simulateAndExecute && f.keybase != nil {
records, _ := f.keybase.List()
if len(records) == 0 {
return nil, errors.New("cannot build signature for simulation, key records slice is empty")
record, err := f.keybase.Key(f.fromName)
if err != nil {
return nil, err
}

// take the first record just for simulation purposes
pk, ok = records[0].PubKey.GetCachedValue().(cryptotypes.PubKey)
pk, ok = record.PubKey.GetCachedValue().(cryptotypes.PubKey)
if !ok {
return nil, errors.New("cannot build signature for simulation, failed to convert proto Any to public key")
}
Expand All @@ -456,6 +454,26 @@ func (f Factory) getSimPK() (cryptotypes.PubKey, error) {
return pk, nil
}

// getSimSignatureData based on the pubKey type gets the correct SignatureData type
// to use for building a simulation tx.
func (f Factory) getSimSignatureData(pk cryptotypes.PubKey) signing.SignatureData {
multisigPubKey, ok := pk.(*multisig.LegacyAminoPubKey)
if !ok {
return &signing.SingleSignatureData{SignMode: f.signMode}
}

multiSignatureData := make([]signing.SignatureData, 0, multisigPubKey.Threshold)
for i := uint32(0); i < multisigPubKey.Threshold; i++ {
multiSignatureData = append(multiSignatureData, &signing.SingleSignatureData{
SignMode: f.SignMode(),
})
}

return &signing.MultiSignatureData{
Signatures: multiSignatureData,
}
}

// Prepare ensures the account defined by ctx.GetFromAddress() exists and
// if the account number and/or the account sequence number are zero (not set),
// they will be queried for and set on the provided Factory.
Expand Down
100 changes: 93 additions & 7 deletions client/tx/factory_test.go
Original file line number Diff line number Diff line change
@@ -1,34 +1,120 @@
package tx_test
package tx

import (
"testing"

"github.com/cosmos/cosmos-sdk/client"
"github.com/cosmos/cosmos-sdk/client/tx"
"github.com/stretchr/testify/require"

"github.com/cosmos/cosmos-sdk/client"
"github.com/cosmos/cosmos-sdk/codec"
"github.com/cosmos/cosmos-sdk/crypto/hd"
"github.com/cosmos/cosmos-sdk/crypto/keyring"
"github.com/cosmos/cosmos-sdk/crypto/keys/multisig"
"github.com/cosmos/cosmos-sdk/crypto/keys/secp256k1"
"github.com/cosmos/cosmos-sdk/crypto/types"
"github.com/cosmos/cosmos-sdk/testutil/testdata"
"github.com/cosmos/cosmos-sdk/types/tx/signing"

codectypes "github.com/cosmos/cosmos-sdk/codec/types"
cryptocodec "github.com/cosmos/cosmos-sdk/crypto/codec"
)

func TestFactoryPrepate(t *testing.T) {
func TestFactoryPrepare(t *testing.T) {
t.Parallel()

factory := tx.Factory{}
factory := Factory{}
clientCtx := client.Context{}

output, err := factory.Prepare(clientCtx.WithOffline(true))
require.NoError(t, err)
require.Equal(t, output, factory)

factory = tx.Factory{}.WithAccountRetriever(client.MockAccountRetriever{ReturnAccNum: 10, ReturnAccSeq: 1}).WithAccountNumber(5)
factory = Factory{}.WithAccountRetriever(client.MockAccountRetriever{ReturnAccNum: 10, ReturnAccSeq: 1}).WithAccountNumber(5)
output, err = factory.Prepare(clientCtx.WithFrom("foo"))
require.NoError(t, err)
require.NotEqual(t, output, factory)
require.Equal(t, output.AccountNumber(), uint64(5))
require.Equal(t, output.Sequence(), uint64(1))

factory = tx.Factory{}.WithAccountRetriever(client.MockAccountRetriever{ReturnAccNum: 10, ReturnAccSeq: 1})
factory = Factory{}.WithAccountRetriever(client.MockAccountRetriever{ReturnAccNum: 10, ReturnAccSeq: 1})
output, err = factory.Prepare(clientCtx.WithFrom("foo"))
require.NoError(t, err)
require.NotEqual(t, output, factory)
require.Equal(t, output.AccountNumber(), uint64(10))
require.Equal(t, output.Sequence(), uint64(1))
}

func TestFactory_getSimPKType(t *testing.T) {
// setup keyring
registry := codectypes.NewInterfaceRegistry()
cryptocodec.RegisterInterfaces(registry)
k := keyring.NewInMemory(codec.NewProtoCodec(registry))

tests := []struct {
name string
fromName string
genKey func(fromName string, k keyring.Keyring) error
wantType types.PubKey
}{
{
name: "simple key",
fromName: "testKey",
genKey: func(fromName string, k keyring.Keyring) error {
_, err := k.NewAccount(fromName, testdata.TestMnemonic, "", "", hd.Secp256k1)
return err
},
wantType: (*secp256k1.PubKey)(nil),
},
{
name: "multisig key",
fromName: "multiKey",
genKey: func(fromName string, k keyring.Keyring) error {
pk := multisig.NewLegacyAminoPubKey(1, []types.PubKey{&multisig.LegacyAminoPubKey{}})
_, err := k.SaveMultisig(fromName, pk)
return err
},
wantType: (*multisig.LegacyAminoPubKey)(nil),
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.genKey(tt.fromName, k)
require.NoError(t, err)
f := Factory{
keybase: k,
fromName: tt.fromName,
simulateAndExecute: true,
}
got, err := f.getSimPK()
require.NoError(t, err)
require.IsType(t, tt.wantType, got)
})
}
}

func TestFactory_getSimSignatureData(t *testing.T) {
tests := []struct {
name string
pk types.PubKey
wantType any
}{
{
name: "simple pubkey",
pk: &secp256k1.PubKey{},
wantType: (*signing.SingleSignatureData)(nil),
},
{
name: "multisig pubkey",
pk: &multisig.LegacyAminoPubKey{},
wantType: (*signing.MultiSignatureData)(nil),
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := Factory{}.getSimSignatureData(tt.pk)
require.IsType(t, tt.wantType, got)
})
}
}
21 changes: 11 additions & 10 deletions client/tx/tx_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package tx_test
package tx

import (
gocontext "context"
Expand All @@ -13,7 +13,6 @@ import (

"github.com/cosmos/cosmos-sdk/client"
clienttestutil "github.com/cosmos/cosmos-sdk/client/testutil"
"github.com/cosmos/cosmos-sdk/client/tx"
"github.com/cosmos/cosmos-sdk/codec"
codectypes "github.com/cosmos/cosmos-sdk/codec/types"
"github.com/cosmos/cosmos-sdk/crypto/hd"
Expand Down Expand Up @@ -86,7 +85,7 @@ func TestCalculateGas(t *testing.T) {
stc := tc
txCfg, _ := newTestTxConfig(t)

txf := tx.Factory{}.
txf := Factory{}.
WithChainID("test-chain").
WithTxConfig(txCfg).WithSignMode(txCfg.SignModeHandler().DefaultMode())

Expand All @@ -95,7 +94,7 @@ func TestCalculateGas(t *testing.T) {
gasUsed: tc.args.mockGasUsed,
wantErr: tc.args.mockWantErr,
}
simRes, gotAdjusted, err := tx.CalculateGas(mockClientCtx, txf.WithGasAdjustment(stc.args.adjustment))
simRes, gotAdjusted, err := CalculateGas(mockClientCtx, txf.WithGasAdjustment(stc.args.adjustment))
if stc.expPass {
require.NoError(t, err)
require.Equal(t, simRes.GasInfo.GasUsed, stc.wantEstimate)
Expand All @@ -109,8 +108,8 @@ func TestCalculateGas(t *testing.T) {
}
}

func mockTxFactory(txCfg client.TxConfig) tx.Factory {
return tx.Factory{}.
func mockTxFactory(txCfg client.TxConfig) Factory {
return Factory{}.
WithTxConfig(txCfg).
WithAccountNumber(50).
WithSequence(23).
Expand Down Expand Up @@ -198,7 +197,7 @@ func TestMnemonicInMemo(t *testing.T) {

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
txf := tx.Factory{}.
txf := Factory{}.
WithTxConfig(txConfig).
WithAccountNumber(50).
WithSequence(23).
Expand Down Expand Up @@ -269,7 +268,7 @@ func TestSign(t *testing.T) {

testCases := []struct {
name string
txf tx.Factory
txf Factory
txb client.TxBuilder
from string
overwrite bool
Expand Down Expand Up @@ -356,7 +355,7 @@ func TestSign(t *testing.T) {
var prevSigs []signingtypes.SignatureV2
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err = tx.Sign(tc.txf, tc.from, tc.txb, tc.overwrite)
err = Sign(tc.txf, tc.from, tc.txb, tc.overwrite)
if len(tc.expectedPKs) == 0 {
requireT.Error(err)
} else {
Expand All @@ -372,6 +371,8 @@ func TestSign(t *testing.T) {
}

func TestPreprocessHook(t *testing.T) {
_, _, addr2 := testdata.KeyTestPubAddr()

txConfig, cdc := newTestTxConfig(t)
requireT := require.New(t)
path := hd.CreateHDPath(118, 0, 0).String()
Expand Down Expand Up @@ -420,7 +421,7 @@ func TestPreprocessHook(t *testing.T) {
msg2 := banktypes.NewMsgSend(addr2, sdk.AccAddress("to"), nil)
txb, err := txfDirect.BuildUnsignedTx(msg1, msg2)

err = tx.Sign(txfDirect, from, txb, false)
err = Sign(txfDirect, from, txb, false)
requireT.NoError(err)

// Run preprocessing
Expand Down
Loading