diff --git a/baseapp/abci_test.go b/baseapp/abci_test.go index 8a61a0aebfc2..6f954f5aaa69 100644 --- a/baseapp/abci_test.go +++ b/baseapp/abci_test.go @@ -1,4 +1,4 @@ -package baseapp +package baseapp_test import ( "fmt" @@ -9,6 +9,7 @@ import ( tmprototypes "github.com/tendermint/tendermint/proto/tendermint/types" dbm "github.com/tendermint/tm-db" + "github.com/cosmos/cosmos-sdk/baseapp" sdk "github.com/cosmos/cosmos-sdk/types" ) @@ -18,81 +19,81 @@ func TestGetBlockRentionHeight(t *testing.T) { name := t.Name() testCases := map[string]struct { - bapp *BaseApp + bapp *baseapp.BaseApp maxAgeBlocks int64 commitHeight int64 expected int64 }{ "defaults": { - bapp: NewBaseApp(name, logger, db, nil), + bapp: baseapp.NewBaseApp(name, logger, db, nil), maxAgeBlocks: 0, commitHeight: 499000, expected: 0, }, "pruning unbonding time only": { - bapp: NewBaseApp(name, logger, db, nil, SetMinRetainBlocks(1)), + bapp: baseapp.NewBaseApp(name, logger, db, nil, baseapp.SetMinRetainBlocks(1)), maxAgeBlocks: 362880, commitHeight: 499000, expected: 136120, }, "pruning iavl snapshot only": { - bapp: NewBaseApp( + bapp: baseapp.NewBaseApp( name, logger, db, nil, - SetPruning(sdk.PruningOptions{KeepEvery: 10000}), - SetMinRetainBlocks(1), + baseapp.SetPruning(sdk.PruningOptions{KeepEvery: 10000}), + baseapp.SetMinRetainBlocks(1), ), maxAgeBlocks: 0, commitHeight: 499000, expected: 490000, }, "pruning state sync snapshot only": { - bapp: NewBaseApp( + bapp: baseapp.NewBaseApp( name, logger, db, nil, - SetSnapshotInterval(50000), - SetSnapshotKeepRecent(3), - SetMinRetainBlocks(1), + baseapp.SetSnapshotInterval(50000), + baseapp.SetSnapshotKeepRecent(3), + baseapp.SetMinRetainBlocks(1), ), maxAgeBlocks: 0, commitHeight: 499000, expected: 349000, }, "pruning min retention only": { - bapp: NewBaseApp( + bapp: baseapp.NewBaseApp( name, logger, db, nil, - SetMinRetainBlocks(400000), + baseapp.SetMinRetainBlocks(400000), ), maxAgeBlocks: 0, commitHeight: 499000, expected: 99000, }, "pruning all conditions": { - bapp: NewBaseApp( + bapp: baseapp.NewBaseApp( name, logger, db, nil, - SetPruning(sdk.PruningOptions{KeepEvery: 10000}), - SetMinRetainBlocks(400000), - SetSnapshotInterval(50000), SetSnapshotKeepRecent(3), + baseapp.SetPruning(sdk.PruningOptions{KeepEvery: 10000}), + baseapp.SetMinRetainBlocks(400000), + baseapp.SetSnapshotInterval(50000), baseapp.SetSnapshotKeepRecent(3), ), maxAgeBlocks: 362880, commitHeight: 499000, expected: 99000, }, "no pruning due to no persisted state": { - bapp: NewBaseApp( + bapp: baseapp.NewBaseApp( name, logger, db, nil, - SetPruning(sdk.PruningOptions{KeepEvery: 10000}), - SetMinRetainBlocks(400000), - SetSnapshotInterval(50000), SetSnapshotKeepRecent(3), + baseapp.SetPruning(sdk.PruningOptions{KeepEvery: 10000}), + baseapp.SetMinRetainBlocks(400000), + baseapp.SetSnapshotInterval(50000), baseapp.SetSnapshotKeepRecent(3), ), maxAgeBlocks: 362880, commitHeight: 10000, expected: 0, }, "disable pruning": { - bapp: NewBaseApp( + bapp: baseapp.NewBaseApp( name, logger, db, nil, - SetPruning(sdk.PruningOptions{KeepEvery: 10000}), - SetMinRetainBlocks(0), - SetSnapshotInterval(50000), SetSnapshotKeepRecent(3), + baseapp.SetPruning(sdk.PruningOptions{KeepEvery: 10000}), + baseapp.SetMinRetainBlocks(0), + baseapp.SetSnapshotInterval(50000), baseapp.SetSnapshotKeepRecent(3), ), maxAgeBlocks: 362880, commitHeight: 499000, @@ -126,14 +127,14 @@ func TestBaseAppCreateQueryContextRejectsNegativeHeights(t *testing.T) { logger := defaultLogger() db := dbm.NewMemDB() name := t.Name() - app := NewBaseApp(name, logger, db, nil) + app := baseapp.NewBaseApp(name, logger, db, nil) proves := []bool{ false, true, } for _, prove := range proves { t.Run(fmt.Sprintf("prove=%t", prove), func(t *testing.T) { - sctx, err := app.createQueryContext(-10, true) + sctx, err := app.CreateQueryContext(-10, true) require.Error(t, err) require.Equal(t, sctx, sdk.Context{}) }) diff --git a/baseapp/baseapp_test.go b/baseapp/baseapp_test.go index e8a371ccd92e..63abd71b5633 100644 --- a/baseapp/baseapp_test.go +++ b/baseapp/baseapp_test.go @@ -1,4 +1,4 @@ -package baseapp +package baseapp_test import ( "bytes" @@ -22,6 +22,7 @@ import ( tmproto "github.com/tendermint/tendermint/proto/tendermint/types" dbm "github.com/tendermint/tm-db" + "github.com/cosmos/cosmos-sdk/baseapp" "github.com/cosmos/cosmos-sdk/codec" "github.com/cosmos/cosmos-sdk/snapshots" snapshottypes "github.com/cosmos/cosmos-sdk/snapshots/types" @@ -30,6 +31,7 @@ import ( "github.com/cosmos/cosmos-sdk/testutil/testdata" sdk "github.com/cosmos/cosmos-sdk/types" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" + "github.com/cosmos/cosmos-sdk/types/tx" "github.com/cosmos/cosmos-sdk/x/auth/middleware" "github.com/cosmos/cosmos-sdk/x/auth/migrations/legacytx" ) @@ -82,12 +84,12 @@ func defaultLogger() log.Logger { return log.NewTMLogger(log.NewSyncWriter(os.Stdout)).With("module", "sdk/app") } -func newBaseApp(name string, options ...func(*BaseApp)) *BaseApp { +func newBaseApp(name string, options ...func(*baseapp.BaseApp)) *baseapp.BaseApp { logger := defaultLogger() db := dbm.NewMemDB() codec := codec.NewLegacyAmino() registerTestCodec(codec) - return NewBaseApp(name, logger, db, testTxDecoder(codec), options...) + return baseapp.NewBaseApp(name, logger, db, testTxDecoder(codec), options...) } func registerTestCodec(cdc *codec.LegacyAmino) { @@ -111,7 +113,7 @@ func aminoTxEncoder() sdk.TxEncoder { } // simple one store baseapp -func setupBaseApp(t *testing.T, options ...func(*BaseApp)) *BaseApp { +func setupBaseApp(t *testing.T, options ...func(*baseapp.BaseApp)) *baseapp.BaseApp { app := newBaseApp(t.Name(), options...) require.Equal(t, t.Name(), app.Name()) @@ -124,23 +126,37 @@ func setupBaseApp(t *testing.T, options ...func(*BaseApp)) *BaseApp { return app } +// testTxHandler is a tx.Handler used for the mock app, it does not +// contain any signature verification logic. +func testTxHandler(options middleware.TxHandlerOptions, customTxHandlerMiddleware handlerFun) tx.Handler { + return middleware.ComposeMiddlewares( + middleware.NewRunMsgsTxHandler(options.MsgServiceRouter, options.LegacyRouter), + middleware.GasTxMiddleware, + middleware.RecoveryTxMiddleware, + middleware.NewIndexEventsTxMiddleware(options.IndexEvents), + middleware.ValidateBasicMiddleware, + CustomTxHandlerMiddleware(customTxHandlerMiddleware), + ) +} + // simple one store baseapp with data and snapshots. Each tx is 1 MB in size (uncompressed). -func setupBaseAppWithSnapshots(t *testing.T, blocks uint, blockTxs int, options ...func(*BaseApp)) (*BaseApp, func()) { +func setupBaseAppWithSnapshots(t *testing.T, blocks uint, blockTxs int, options ...func(*baseapp.BaseApp)) (*baseapp.BaseApp, func()) { codec := codec.NewLegacyAmino() registerTestCodec(codec) - routerOpt := func(bapp *BaseApp) { + routerOpt := func(bapp *baseapp.BaseApp) { legacyRouter := middleware.NewLegacyRouter() legacyRouter.AddRoute(sdk.NewRoute(routeMsgKeyValue, func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) { kv := msg.(*msgKeyValue) - bapp.cms.GetCommitKVStore(capKey2).Set(kv.Key, kv.Value) + bapp.CMS().GetCommitKVStore(capKey2).Set(kv.Key, kv.Value) return &sdk.Result{}, nil })) - txHandler, err := middleware.NewDefaultTxHandler(middleware.TxHandlerOptions{ - LegacyAnteHandler: func(ctx sdk.Context, tx sdk.Tx, simulate bool) (sdk.Context, error) { return ctx, nil }, - LegacyRouter: legacyRouter, - MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry), - }) - require.NoError(t, err) + txHandler := testTxHandler( + middleware.TxHandlerOptions{ + LegacyRouter: legacyRouter, + MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry), + }, + func(ctx sdk.Context, tx sdk.Tx, simulate bool) (sdk.Context, error) { return ctx, nil }, + ) bapp.SetTxHandler(txHandler) } @@ -155,9 +171,9 @@ func setupBaseAppWithSnapshots(t *testing.T, blocks uint, blockTxs int, options } app := setupBaseApp(t, append(options, - SetSnapshotStore(snapshotStore), - SetSnapshotInterval(snapshotInterval), - SetPruning(sdk.PruningOptions{KeepEvery: 1}), + baseapp.SetSnapshotStore(snapshotStore), + baseapp.SetSnapshotInterval(snapshotInterval), + baseapp.SetPruning(sdk.PruningOptions{KeepEvery: 1}), routerOpt)...) app.InitChain(abci.RequestInitChain{}) @@ -208,9 +224,9 @@ func TestMountStores(t *testing.T) { app := setupBaseApp(t) // check both stores - store1 := app.cms.GetCommitKVStore(capKey1) + store1 := app.CMS().GetCommitKVStore(capKey1) require.NotNil(t, store1) - store2 := app.cms.GetCommitKVStore(capKey2) + store2 := app.CMS().GetCommitKVStore(capKey2) require.NotNil(t, store2) } @@ -218,10 +234,10 @@ func TestMountStores(t *testing.T) { // Test that LoadLatestVersion actually does. func TestLoadVersion(t *testing.T) { logger := defaultLogger() - pruningOpt := SetPruning(store.PruneNothing) + pruningOpt := baseapp.SetPruning(store.PruneNothing) db := dbm.NewMemDB() name := t.Name() - app := NewBaseApp(name, logger, db, nil, pruningOpt) + app := baseapp.NewBaseApp(name, logger, db, nil, pruningOpt) // make a cap key and mount the store err := app.LoadLatestVersion() // needed to make stores non-nil @@ -248,7 +264,7 @@ func TestLoadVersion(t *testing.T) { commitID2 := sdk.CommitID{Version: 2, Hash: res.Data} // reload with LoadLatestVersion - app = NewBaseApp(name, logger, db, nil, pruningOpt) + app = baseapp.NewBaseApp(name, logger, db, nil, pruningOpt) app.MountStores() err = app.LoadLatestVersion() require.Nil(t, err) @@ -256,7 +272,7 @@ func TestLoadVersion(t *testing.T) { // reload with LoadVersion, see if you can commit the same block and get // the same result - app = NewBaseApp(name, logger, db, nil, pruningOpt) + app = baseapp.NewBaseApp(name, logger, db, nil, pruningOpt) err = app.LoadVersion(1) require.Nil(t, err) testLoadVersionHelper(t, app, int64(1), commitID1) @@ -265,8 +281,8 @@ func TestLoadVersion(t *testing.T) { testLoadVersionHelper(t, app, int64(2), commitID2) } -func useDefaultLoader(app *BaseApp) { - app.SetStoreLoader(DefaultStoreLoader) +func useDefaultLoader(app *baseapp.BaseApp) { + app.SetStoreLoader(baseapp.DefaultStoreLoader) } func initStore(t *testing.T, db dbm.DB, storeKey string, k, v []byte) { @@ -305,7 +321,7 @@ func checkStore(t *testing.T, db dbm.DB, ver int64, storeKey string, k, v []byte // Test that LoadLatestVersion actually does. func TestSetLoader(t *testing.T) { cases := map[string]struct { - setLoader func(*BaseApp) + setLoader func(*baseapp.BaseApp) origStoreKey string loadStoreKey string }{ @@ -331,11 +347,11 @@ func TestSetLoader(t *testing.T) { initStore(t, db, tc.origStoreKey, k, v) // load the app with the existing db - opts := []func(*BaseApp){SetPruning(store.PruneNothing)} + opts := []func(*baseapp.BaseApp){baseapp.SetPruning(store.PruneNothing)} if tc.setLoader != nil { opts = append(opts, tc.setLoader) } - app := NewBaseApp(t.Name(), defaultLogger(), db, nil, opts...) + app := baseapp.NewBaseApp(t.Name(), defaultLogger(), db, nil, opts...) app.MountStores(sdk.NewKVStoreKey(tc.loadStoreKey)) err := app.LoadLatestVersion() require.Nil(t, err) @@ -354,10 +370,10 @@ func TestSetLoader(t *testing.T) { func TestVersionSetterGetter(t *testing.T) { logger := defaultLogger() - pruningOpt := SetPruning(store.PruneDefault) + pruningOpt := baseapp.SetPruning(store.PruneDefault) db := dbm.NewMemDB() name := t.Name() - app := NewBaseApp(name, logger, db, nil, pruningOpt) + app := baseapp.NewBaseApp(name, logger, db, nil, pruningOpt) require.Equal(t, "", app.Version()) res := app.Query(abci.RequestQuery{Path: "app/version"}) @@ -374,10 +390,10 @@ func TestVersionSetterGetter(t *testing.T) { func TestLoadVersionInvalid(t *testing.T) { logger := log.NewNopLogger() - pruningOpt := SetPruning(store.PruneNothing) + pruningOpt := baseapp.SetPruning(store.PruneNothing) db := dbm.NewMemDB() name := t.Name() - app := NewBaseApp(name, logger, db, nil, pruningOpt) + app := baseapp.NewBaseApp(name, logger, db, nil, pruningOpt) err := app.LoadLatestVersion() require.Nil(t, err) @@ -392,7 +408,7 @@ func TestLoadVersionInvalid(t *testing.T) { commitID1 := sdk.CommitID{Version: 1, Hash: res.Data} // create a new app with the stores mounted under the same cap key - app = NewBaseApp(name, logger, db, nil, pruningOpt) + app = baseapp.NewBaseApp(name, logger, db, nil, pruningOpt) // require we can load the latest version err = app.LoadVersion(1) @@ -411,10 +427,10 @@ func TestLoadVersionPruning(t *testing.T) { KeepEvery: 3, Interval: 1, } - pruningOpt := SetPruning(pruningOptions) + pruningOpt := baseapp.SetPruning(pruningOptions) db := dbm.NewMemDB() name := t.Name() - app := NewBaseApp(name, logger, db, nil, pruningOpt) + app := baseapp.NewBaseApp(name, logger, db, nil, pruningOpt) // make a cap key and mount the store capKey := sdk.NewKVStoreKey("key1") @@ -442,17 +458,17 @@ func TestLoadVersionPruning(t *testing.T) { } for _, v := range []int64{1, 2, 4} { - _, err = app.cms.CacheMultiStoreWithVersion(v) + _, err = app.CMS().CacheMultiStoreWithVersion(v) require.NoError(t, err) } for _, v := range []int64{3, 5, 6, 7} { - _, err = app.cms.CacheMultiStoreWithVersion(v) + _, err = app.CMS().CacheMultiStoreWithVersion(v) require.NoError(t, err) } // reload with LoadLatestVersion, check it loads last version - app = NewBaseApp(name, logger, db, nil, pruningOpt) + app = baseapp.NewBaseApp(name, logger, db, nil, pruningOpt) app.MountStores(capKey) err = app.LoadLatestVersion() @@ -460,7 +476,7 @@ func TestLoadVersionPruning(t *testing.T) { testLoadVersionHelper(t, app, int64(7), lastCommitID) } -func testLoadVersionHelper(t *testing.T, app *BaseApp, expectedHeight int64, expectedID sdk.CommitID) { +func testLoadVersionHelper(t *testing.T, app *baseapp.BaseApp, expectedHeight int64, expectedID sdk.CommitID) { lastHeight := app.LastBlockHeight() lastID := app.LastCommitID() require.Equal(t, expectedHeight, lastHeight) @@ -470,13 +486,13 @@ func testLoadVersionHelper(t *testing.T, app *BaseApp, expectedHeight int64, exp func TestOptionFunction(t *testing.T) { logger := defaultLogger() db := dbm.NewMemDB() - bap := NewBaseApp("starting name", logger, db, nil, testChangeNameHelper("new name")) - require.Equal(t, bap.name, "new name", "BaseApp should have had name changed via option function") + bap := baseapp.NewBaseApp("starting name", logger, db, nil, testChangeNameHelper("new name")) + require.Equal(t, bap.GetName(), "new name", "BaseApp should have had name changed via option function") } -func testChangeNameHelper(name string) func(*BaseApp) { - return func(bap *BaseApp) { - bap.name = name +func testChangeNameHelper(name string) func(*baseapp.BaseApp) { + return func(bap *baseapp.BaseApp) { + bap.SetName(name) } } @@ -490,7 +506,7 @@ func TestTxDecoder(t *testing.T) { tx := newTxCounter(1, 0) txBytes := codec.MustMarshal(tx) - dTx, err := app.txDecoder(txBytes) + dTx, err := app.TxDecoder(txBytes) require.NoError(t, err) cTx := dTx.(txTest) @@ -555,8 +571,8 @@ func TestBaseAppOptionSeal(t *testing.T) { func TestSetMinGasPrices(t *testing.T) { minGasPrices := sdk.DecCoins{sdk.NewInt64DecCoin("stake", 5000)} - app := newBaseApp(t.Name(), SetMinGasPrices(minGasPrices.String())) - require.Equal(t, minGasPrices, app.minGasPrices) + app := newBaseApp(t.Name(), baseapp.SetMinGasPrices(minGasPrices.String())) + require.Equal(t, minGasPrices, app.MinGasPrices()) } func TestInitChainer(t *testing.T) { @@ -565,7 +581,7 @@ func TestInitChainer(t *testing.T) { // we can reload the same app later db := dbm.NewMemDB() logger := defaultLogger() - app := NewBaseApp(name, logger, db, nil) + app := baseapp.NewBaseApp(name, logger, db, nil) capKey := sdk.NewKVStoreKey("main") capKey2 := sdk.NewKVStoreKey("key2") app.MountStores(capKey, capKey2) @@ -608,10 +624,10 @@ func TestInitChainer(t *testing.T) { ) // assert that chainID is set correctly in InitChain - chainID := app.deliverState.ctx.ChainID() + chainID := app.DeliverState().Context().ChainID() require.Equal(t, "test-chain-id", chainID, "ChainID in deliverState not set correctly in InitChain") - chainID = app.checkState.ctx.ChainID() + chainID = app.CheckState().Context().ChainID() require.Equal(t, "test-chain-id", chainID, "ChainID in checkState not set correctly in InitChain") app.Commit() @@ -620,7 +636,7 @@ func TestInitChainer(t *testing.T) { require.Equal(t, value, res.Value) // reload app - app = NewBaseApp(name, logger, db, nil) + app = baseapp.NewBaseApp(name, logger, db, nil) app.SetInitChainer(initChainer) app.MountStores(capKey, capKey2) err = app.LoadLatestVersion() // needed to make stores non-nil @@ -644,7 +660,7 @@ func TestInitChain_WithInitialHeight(t *testing.T) { name := t.Name() db := dbm.NewMemDB() logger := defaultLogger() - app := NewBaseApp(name, logger, db, nil) + app := baseapp.NewBaseApp(name, logger, db, nil) app.InitChain( abci.RequestInitChain{ @@ -660,7 +676,7 @@ func TestBeginBlock_WithInitialHeight(t *testing.T) { name := t.Name() db := dbm.NewMemDB() logger := defaultLogger() - app := NewBaseApp(name, logger, db, nil) + app := baseapp.NewBaseApp(name, logger, db, nil) app.InitChain( abci.RequestInitChain{ @@ -711,6 +727,9 @@ func (tx txTest) ValidateBasic() error { return nil } // Implements GasTx func (tx txTest) GetGas() uint64 { return tx.GasLimit } +// Implements TxWithTimeoutHeight +func (tx txTest) GetTimeoutHeight() uint64 { return 0 } + const ( routeMsgCounter = "msgCounter" routeMsgCounter2 = "msgCounter2" @@ -826,7 +845,7 @@ func testTxDecoder(cdc *codec.LegacyAmino) sdk.TxDecoder { } } -func anteHandlerTxTest(t *testing.T, capKey sdk.StoreKey, storeKey []byte) sdk.AnteHandler { +func customHandlerTxTest(t *testing.T, capKey sdk.StoreKey, storeKey []byte) handlerFun { return func(ctx sdk.Context, tx sdk.Tx, simulate bool) (sdk.Context, error) { store := ctx.KVStore(capKey) txTest := tx.(txTest) @@ -841,7 +860,7 @@ func anteHandlerTxTest(t *testing.T, capKey sdk.StoreKey, storeKey []byte) sdk.A } ctx.EventManager().EmitEvents( - counterEvent("ante_handler", txTest.Counter), + counterEvent("post_handlers", txTest.Counter), ) return ctx, nil @@ -929,18 +948,19 @@ func TestCheckTx(t *testing.T) { // This ensures changes to the kvstore persist across successive CheckTx. counterKey := []byte("counter-key") - txHandlerOpt := func(bapp *BaseApp) { + txHandlerOpt := func(bapp *baseapp.BaseApp) { legacyRouter := middleware.NewLegacyRouter() // TODO: can remove this once CheckTx doesnt process msgs. legacyRouter.AddRoute(sdk.NewRoute(routeMsgCounter, func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) { return &sdk.Result{}, nil })) - txHandler, err := middleware.NewDefaultTxHandler(middleware.TxHandlerOptions{ - LegacyRouter: legacyRouter, - LegacyAnteHandler: anteHandlerTxTest(t, capKey1, counterKey), - MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry), - }) - require.NoError(t, err) + txHandler := testTxHandler( + middleware.TxHandlerOptions{ + LegacyRouter: legacyRouter, + MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry), + }, + customHandlerTxTest(t, capKey1, counterKey), + ) bapp.SetTxHandler(txHandler) } @@ -962,23 +982,23 @@ func TestCheckTx(t *testing.T) { require.True(t, r.IsOK(), fmt.Sprintf("%v", r)) } - checkStateStore := app.checkState.ctx.KVStore(capKey1) + checkStateStore := app.CheckState().Context().KVStore(capKey1) storedCounter := getIntFromStore(checkStateStore, counterKey) - // Ensure AnteHandler ran + // Ensure storedCounter require.Equal(t, nTxs, storedCounter) // If a block is committed, CheckTx state should be reset. header := tmproto.Header{Height: 1} app.BeginBlock(abci.RequestBeginBlock{Header: header, Hash: []byte("hash")}) - require.NotNil(t, app.checkState.ctx.BlockGasMeter(), "block gas meter should have been set to checkState") - require.NotEmpty(t, app.checkState.ctx.HeaderHash()) + require.NotNil(t, app.CheckState().Context().BlockGasMeter(), "block gas meter should have been set to checkState") + require.NotEmpty(t, app.CheckState().Context().HeaderHash()) app.EndBlock(abci.RequestEndBlock{}) app.Commit() - checkStateStore = app.checkState.ctx.KVStore(capKey1) + checkStateStore = app.CheckState().Context().KVStore(capKey1) storedBytes := checkStateStore.Get(counterKey) require.Nil(t, storedBytes) } @@ -986,20 +1006,21 @@ func TestCheckTx(t *testing.T) { // Test that successive DeliverTx can see each others' effects // on the store, both within and across blocks. func TestDeliverTx(t *testing.T) { - // test increments in the ante + // test increments in the post txHandler anteKey := []byte("ante-key") // test increments in the handler deliverKey := []byte("deliver-key") - txHandlerOpt := func(bapp *BaseApp) { + txHandlerOpt := func(bapp *baseapp.BaseApp) { legacyRouter := middleware.NewLegacyRouter() r := sdk.NewRoute(routeMsgCounter, handlerMsgCounter(t, capKey1, deliverKey)) legacyRouter.AddRoute(r) - txHandler, err := middleware.NewDefaultTxHandler(middleware.TxHandlerOptions{ - LegacyRouter: legacyRouter, - LegacyAnteHandler: anteHandlerTxTest(t, capKey1, anteKey), - MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry), - }) - require.NoError(t, err) + txHandler := testTxHandler( + middleware.TxHandlerOptions{ + LegacyRouter: legacyRouter, + MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry), + }, + customHandlerTxTest(t, capKey1, anteKey), + ) bapp.SetTxHandler(txHandler) } app := setupBaseApp(t, txHandlerOpt) @@ -1027,7 +1048,7 @@ func TestDeliverTx(t *testing.T) { require.True(t, res.IsOK(), fmt.Sprintf("%v", res)) events := res.GetEvents() require.Len(t, events, 3, "should contain ante handler, message type and counter events respectively") - require.Equal(t, sdk.MarkEventsToIndex(counterEvent("ante_handler", counter).ToABCIEvents(), map[string]struct{}{})[0], events[0], "ante handler event") + require.Equal(t, sdk.MarkEventsToIndex(counterEvent("post_handlers", counter).ToABCIEvents(), map[string]struct{}{})[0], events[0], "ante handler event") require.Equal(t, sdk.MarkEventsToIndex(counterEvent(sdk.EventTypeMessage, counter).ToABCIEvents(), map[string]struct{}{})[0], events[2], "msg handler update counter event") } @@ -1049,18 +1070,19 @@ func TestMultiMsgDeliverTx(t *testing.T) { // increment the msg counter deliverKey := []byte("deliver-key") deliverKey2 := []byte("deliver-key2") - txHandlerOpt := func(bapp *BaseApp) { + txHandlerOpt := func(bapp *baseapp.BaseApp) { legacyRouter := middleware.NewLegacyRouter() r1 := sdk.NewRoute(routeMsgCounter, handlerMsgCounter(t, capKey1, deliverKey)) r2 := sdk.NewRoute(routeMsgCounter2, handlerMsgCounter(t, capKey1, deliverKey2)) legacyRouter.AddRoute(r1) legacyRouter.AddRoute(r2) - txHandler, err := middleware.NewDefaultTxHandler(middleware.TxHandlerOptions{ - LegacyRouter: legacyRouter, - LegacyAnteHandler: anteHandlerTxTest(t, capKey1, anteKey), - MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry), - }) - require.NoError(t, err) + txHandler := testTxHandler( + middleware.TxHandlerOptions{ + LegacyRouter: legacyRouter, + MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry), + }, + customHandlerTxTest(t, capKey1, anteKey), + ) bapp.SetTxHandler(txHandler) } app := setupBaseApp(t, txHandlerOpt) @@ -1080,7 +1102,7 @@ func TestMultiMsgDeliverTx(t *testing.T) { res := app.DeliverTx(abci.RequestDeliverTx{Tx: txBytes}) require.True(t, res.IsOK(), fmt.Sprintf("%v", res)) - store := app.deliverState.ctx.KVStore(capKey1) + store := app.DeliverState().Context().KVStore(capKey1) // tx counter only incremented once txCounter := getIntFromStore(store, anteKey) @@ -1100,7 +1122,7 @@ func TestMultiMsgDeliverTx(t *testing.T) { res = app.DeliverTx(abci.RequestDeliverTx{Tx: txBytes}) require.True(t, res.IsOK(), fmt.Sprintf("%v", res)) - store = app.deliverState.ctx.KVStore(capKey1) + store = app.DeliverState().Context().KVStore(capKey1) // tx counter only incremented once txCounter = getIntFromStore(store, anteKey) @@ -1127,19 +1149,20 @@ func TestConcurrentCheckDeliver(t *testing.T) { func TestSimulateTx(t *testing.T) { gasConsumed := uint64(5) - txHandlerOpt := func(bapp *BaseApp) { + txHandlerOpt := func(bapp *baseapp.BaseApp) { legacyRouter := middleware.NewLegacyRouter() r := sdk.NewRoute(routeMsgCounter, func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) { ctx.GasMeter().ConsumeGas(gasConsumed, "test") return &sdk.Result{}, nil }) legacyRouter.AddRoute(r) - txHandler, err := middleware.NewDefaultTxHandler(middleware.TxHandlerOptions{ - LegacyRouter: legacyRouter, - LegacyAnteHandler: func(ctx sdk.Context, tx sdk.Tx, simulate bool) (sdk.Context, error) { return ctx, nil }, - MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry), - }) - require.NoError(t, err) + txHandler := testTxHandler( + middleware.TxHandlerOptions{ + LegacyRouter: legacyRouter, + MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry), + }, + func(ctx sdk.Context, tx sdk.Tx, simulate bool) (sdk.Context, error) { return ctx, nil }, + ) bapp.SetTxHandler(txHandler) } app := setupBaseApp(t, txHandlerOpt) @@ -1195,20 +1218,21 @@ func TestSimulateTx(t *testing.T) { } func TestRunInvalidTransaction(t *testing.T) { - txHandlerOpt := func(bapp *BaseApp) { + txHandlerOpt := func(bapp *baseapp.BaseApp) { legacyRouter := middleware.NewLegacyRouter() r := sdk.NewRoute(routeMsgCounter, func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) { return &sdk.Result{}, nil }) legacyRouter.AddRoute(r) - txHandler, err := middleware.NewDefaultTxHandler(middleware.TxHandlerOptions{ - LegacyRouter: legacyRouter, - LegacyAnteHandler: func(ctx sdk.Context, tx sdk.Tx, simulate bool) (newCtx sdk.Context, err error) { + txHandler := testTxHandler( + middleware.TxHandlerOptions{ + LegacyRouter: legacyRouter, + MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry), + }, + func(ctx sdk.Context, tx sdk.Tx, simulate bool) (newCtx sdk.Context, err error) { return }, - MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry), - }) - require.NoError(t, err) + ) bapp.SetTxHandler(txHandler) } app := setupBaseApp(t, txHandlerOpt) @@ -1309,7 +1333,7 @@ func TestTxGasLimits(t *testing.T) { return ctx, nil } - txHandlerOpt := func(bapp *BaseApp) { + txHandlerOpt := func(bapp *baseapp.BaseApp) { legacyRouter := middleware.NewLegacyRouter() r := sdk.NewRoute(routeMsgCounter, func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) { count := msg.(msgCounter).Counter @@ -1317,12 +1341,14 @@ func TestTxGasLimits(t *testing.T) { return &sdk.Result{}, nil }) legacyRouter.AddRoute(r) - txHandler, err := middleware.NewDefaultTxHandler(middleware.TxHandlerOptions{ - LegacyRouter: legacyRouter, - LegacyAnteHandler: ante, - MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry), - }) - require.NoError(t, err) + txHandler := testTxHandler( + middleware.TxHandlerOptions{ + LegacyRouter: legacyRouter, + MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry), + }, + ante, + ) + bapp.SetTxHandler(txHandler) } app := setupBaseApp(t, txHandlerOpt) @@ -1386,7 +1412,7 @@ func TestMaxBlockGasLimits(t *testing.T) { return ctx, nil } - txHandlerOpt := func(bapp *BaseApp) { + txHandlerOpt := func(bapp *baseapp.BaseApp) { legacyRouter := middleware.NewLegacyRouter() r := sdk.NewRoute(routeMsgCounter, func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) { count := msg.(msgCounter).Counter @@ -1394,12 +1420,13 @@ func TestMaxBlockGasLimits(t *testing.T) { return &sdk.Result{}, nil }) legacyRouter.AddRoute(r) - txHandler, err := middleware.NewDefaultTxHandler(middleware.TxHandlerOptions{ - LegacyRouter: legacyRouter, - LegacyAnteHandler: ante, - MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry), - }) - require.NoError(t, err) + txHandler := testTxHandler( + middleware.TxHandlerOptions{ + LegacyRouter: legacyRouter, + MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry), + }, + ante, + ) bapp.SetTxHandler(txHandler) } app := setupBaseApp(t, txHandlerOpt) @@ -1443,7 +1470,7 @@ func TestMaxBlockGasLimits(t *testing.T) { for j := 0; j < tc.numDelivers; j++ { _, result, err := app.SimDeliver(aminoTxEncoder(), tx) - ctx := app.getState(runTxModeDeliver).ctx + ctx := app.DeliverState().Context() // check for failed transactions if tc.fail && (j+1) > tc.failAfterDeliver { @@ -1470,21 +1497,22 @@ func TestMaxBlockGasLimits(t *testing.T) { } } -func TestBaseAppAnteHandler(t *testing.T) { +func TestBaseAppMiddleware(t *testing.T) { anteKey := []byte("ante-key") deliverKey := []byte("deliver-key") cdc := codec.NewLegacyAmino() - txHandlerOpt := func(bapp *BaseApp) { + txHandlerOpt := func(bapp *baseapp.BaseApp) { legacyRouter := middleware.NewLegacyRouter() r := sdk.NewRoute(routeMsgCounter, handlerMsgCounter(t, capKey1, deliverKey)) legacyRouter.AddRoute(r) - txHandler, err := middleware.NewDefaultTxHandler(middleware.TxHandlerOptions{ - LegacyRouter: legacyRouter, - LegacyAnteHandler: anteHandlerTxTest(t, capKey1, anteKey), - MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry), - }) - require.NoError(t, err) + txHandler := testTxHandler( + middleware.TxHandlerOptions{ + LegacyRouter: legacyRouter, + MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry), + }, + customHandlerTxTest(t, capKey1, anteKey), + ) bapp.SetTxHandler(txHandler) } app := setupBaseApp(t, txHandlerOpt) @@ -1498,7 +1526,7 @@ func TestBaseAppAnteHandler(t *testing.T) { // execute a tx that will fail ante handler execution // // NOTE: State should not be mutated here. This will be implicitly checked by - // the next txs ante handler execution (anteHandlerTxTest). + // the next txs ante handler execution (customHandlerTxTest). tx := newTxCounter(0, 0) tx.setFailOnAnte(true) txBytes, err := cdc.Marshal(tx) @@ -1507,7 +1535,7 @@ func TestBaseAppAnteHandler(t *testing.T) { require.Empty(t, res.Events) require.False(t, res.IsOK(), fmt.Sprintf("%v", res)) - ctx := app.getState(runTxModeDeliver).ctx + ctx := app.DeliverState().Context() store := ctx.KVStore(capKey1) require.Equal(t, int64(0), getIntFromStore(store, anteKey)) @@ -1523,7 +1551,7 @@ func TestBaseAppAnteHandler(t *testing.T) { require.Empty(t, res.Events) require.False(t, res.IsOK(), fmt.Sprintf("%v", res)) - ctx = app.getState(runTxModeDeliver).ctx + ctx = app.DeliverState().Context() store = ctx.KVStore(capKey1) require.Equal(t, int64(1), getIntFromStore(store, anteKey)) require.Equal(t, int64(0), getIntFromStore(store, deliverKey)) @@ -1539,7 +1567,7 @@ func TestBaseAppAnteHandler(t *testing.T) { require.NotEmpty(t, res.Events) require.True(t, res.IsOK(), fmt.Sprintf("%v", res)) - ctx = app.getState(runTxModeDeliver).ctx + ctx = app.DeliverState().Context() store = ctx.KVStore(capKey1) require.Equal(t, int64(2), getIntFromStore(store, anteKey)) require.Equal(t, int64(1), getIntFromStore(store, deliverKey)) @@ -1564,7 +1592,7 @@ func TestGasConsumptionBadTx(t *testing.T) { cdc := codec.NewLegacyAmino() registerTestCodec(cdc) - txHandlerOpt := func(bapp *BaseApp) { + txHandlerOpt := func(bapp *baseapp.BaseApp) { legacyRouter := middleware.NewLegacyRouter() r := sdk.NewRoute(routeMsgCounter, func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) { count := msg.(msgCounter).Counter @@ -1572,12 +1600,13 @@ func TestGasConsumptionBadTx(t *testing.T) { return &sdk.Result{}, nil }) legacyRouter.AddRoute(r) - txHandler, err := middleware.NewDefaultTxHandler(middleware.TxHandlerOptions{ - LegacyRouter: legacyRouter, - LegacyAnteHandler: ante, - MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry), - }) - require.NoError(t, err) + txHandler := testTxHandler( + middleware.TxHandlerOptions{ + LegacyRouter: legacyRouter, + MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry), + }, + ante, + ) bapp.SetTxHandler(txHandler) } app := setupBaseApp(t, txHandlerOpt) @@ -1617,7 +1646,7 @@ func TestGasConsumptionBadTx(t *testing.T) { func TestQuery(t *testing.T) { key, value := []byte("hello"), []byte("goodbye") - txHandlerOpt := func(bapp *BaseApp) { + txHandlerOpt := func(bapp *baseapp.BaseApp) { legacyRouter := middleware.NewLegacyRouter() r := sdk.NewRoute(routeMsgCounter, func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) { store := ctx.KVStore(capKey1) @@ -1625,16 +1654,17 @@ func TestQuery(t *testing.T) { return &sdk.Result{}, nil }) legacyRouter.AddRoute(r) - txHandler, err := middleware.NewDefaultTxHandler(middleware.TxHandlerOptions{ - LegacyRouter: legacyRouter, - LegacyAnteHandler: func(ctx sdk.Context, tx sdk.Tx, simulate bool) (newCtx sdk.Context, err error) { + txHandler := testTxHandler( + middleware.TxHandlerOptions{ + LegacyRouter: legacyRouter, + MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry), + }, + func(ctx sdk.Context, tx sdk.Tx, simulate bool) (newCtx sdk.Context, err error) { store := ctx.KVStore(capKey1) store.Set(key, value) return }, - MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry), - }) - require.NoError(t, err) + ) bapp.SetTxHandler(txHandler) } app := setupBaseApp(t, txHandlerOpt) @@ -1678,7 +1708,7 @@ func TestQuery(t *testing.T) { } func TestGRPCQuery(t *testing.T) { - grpcQueryOpt := func(bapp *BaseApp) { + grpcQueryOpt := func(bapp *baseapp.BaseApp) { testdata.RegisterQueryServer( bapp.GRPCQueryRouter(), testdata.QueryImpl{}, @@ -1713,14 +1743,14 @@ func TestGRPCQuery(t *testing.T) { // Test p2p filter queries func TestP2PQuery(t *testing.T) { - addrPeerFilterOpt := func(bapp *BaseApp) { + addrPeerFilterOpt := func(bapp *baseapp.BaseApp) { bapp.SetAddrPeerFilter(func(addrport string) abci.ResponseQuery { require.Equal(t, "1.1.1.1:8000", addrport) return abci.ResponseQuery{Code: uint32(3)} }) } - idPeerFilterOpt := func(bapp *BaseApp) { + idPeerFilterOpt := func(bapp *baseapp.BaseApp) { bapp.SetIDPeerFilter(func(id string) abci.ResponseQuery { require.Equal(t, "testid", id) return abci.ResponseQuery{Code: uint32(4)} @@ -1748,16 +1778,16 @@ func TestGetMaximumBlockGas(t *testing.T) { ctx := app.NewContext(true, tmproto.Header{}) app.StoreConsensusParams(ctx, &abci.ConsensusParams{Block: &abci.BlockParams{MaxGas: 0}}) - require.Equal(t, uint64(0), app.getMaximumBlockGas(ctx)) + require.Equal(t, uint64(0), app.GetMaximumBlockGas(ctx)) app.StoreConsensusParams(ctx, &abci.ConsensusParams{Block: &abci.BlockParams{MaxGas: -1}}) - require.Equal(t, uint64(0), app.getMaximumBlockGas(ctx)) + require.Equal(t, uint64(0), app.GetMaximumBlockGas(ctx)) app.StoreConsensusParams(ctx, &abci.ConsensusParams{Block: &abci.BlockParams{MaxGas: 5000000}}) - require.Equal(t, uint64(5000000), app.getMaximumBlockGas(ctx)) + require.Equal(t, uint64(5000000), app.GetMaximumBlockGas(ctx)) app.StoreConsensusParams(ctx, &abci.ConsensusParams{Block: &abci.BlockParams{MaxGas: -5000000}}) - require.Panics(t, func() { app.getMaximumBlockGas(ctx) }) + require.Panics(t, func() { app.GetMaximumBlockGas(ctx) }) } func TestListSnapshots(t *testing.T) { @@ -1940,21 +1970,14 @@ func (rtr *testCustomRouter) Route(ctx sdk.Context, path string) sdk.Handler { } func TestWithRouter(t *testing.T) { - // test increments in the ante - anteKey := []byte("ante-key") // test increments in the handler deliverKey := []byte("deliver-key") - txHandlerOpt := func(bapp *BaseApp) { + txHandlerOpt := func(bapp *baseapp.BaseApp) { customRouter := &testCustomRouter{routes: sync.Map{}} r := sdk.NewRoute(routeMsgCounter, handlerMsgCounter(t, capKey1, deliverKey)) customRouter.AddRoute(r) - txHandler, err := middleware.NewDefaultTxHandler(middleware.TxHandlerOptions{ - LegacyRouter: customRouter, - LegacyAnteHandler: anteHandlerTxTest(t, capKey1, anteKey), - MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry), - }) - require.NoError(t, err) + txHandler := middleware.NewRunMsgsTxHandler(middleware.NewMsgServiceRouter(interfaceRegistry), customRouter) bapp.SetTxHandler(txHandler) } app := setupBaseApp(t, txHandlerOpt) @@ -1998,7 +2021,7 @@ func TestBaseApp_EndBlock(t *testing.T) { }, } - app := NewBaseApp(name, logger, db, nil) + app := baseapp.NewBaseApp(name, logger, db, nil) app.SetParamStore(¶mStore{db: dbm.NewMemDB()}) app.InitChain(abci.RequestInitChain{ ConsensusParams: cp, diff --git a/baseapp/custom_txhandler_test.go b/baseapp/custom_txhandler_test.go new file mode 100644 index 000000000000..6582dda66184 --- /dev/null +++ b/baseapp/custom_txhandler_test.go @@ -0,0 +1,117 @@ +package baseapp_test + +import ( + "context" + "fmt" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/types/tx" + abci "github.com/tendermint/tendermint/abci/types" + "github.com/tendermint/tendermint/crypto/tmhash" +) + +type handlerFun func(ctx sdk.Context, tx sdk.Tx, simulate bool) (newCtx sdk.Context, err error) + +type customTxHandler struct { + handler handlerFun + next tx.Handler +} + +var _ tx.Handler = customTxHandler{} + +// CustomTxMiddleware is being used in tests for testing +// custom pre-`runMsgs` logic (also called antehandlers before). +func CustomTxHandlerMiddleware(handler handlerFun) tx.Middleware { + return func(txHandler tx.Handler) tx.Handler { + return customTxHandler{ + handler: handler, + next: txHandler, + } + } +} + +// CheckTx implements tx.Handler.CheckTx method. +func (txh customTxHandler) CheckTx(ctx context.Context, tx sdk.Tx, req abci.RequestCheckTx) (abci.ResponseCheckTx, error) { + sdkCtx, err := txh.runHandler(ctx, tx, req.Tx, false) + if err != nil { + return abci.ResponseCheckTx{}, err + } + + return txh.next.CheckTx(sdk.WrapSDKContext(sdkCtx), tx, req) +} + +// DeliverTx implements tx.Handler.DeliverTx method. +func (txh customTxHandler) DeliverTx(ctx context.Context, tx sdk.Tx, req abci.RequestDeliverTx) (abci.ResponseDeliverTx, error) { + sdkCtx, err := txh.runHandler(ctx, tx, req.Tx, false) + if err != nil { + return abci.ResponseDeliverTx{}, err + } + + return txh.next.DeliverTx(sdk.WrapSDKContext(sdkCtx), tx, req) +} + +// SimulateTx implements tx.Handler.SimulateTx method. +func (txh customTxHandler) SimulateTx(ctx context.Context, sdkTx sdk.Tx, req tx.RequestSimulateTx) (tx.ResponseSimulateTx, error) { + sdkCtx, err := txh.runHandler(ctx, sdkTx, req.TxBytes, true) + if err != nil { + return tx.ResponseSimulateTx{}, err + } + + return txh.next.SimulateTx(sdk.WrapSDKContext(sdkCtx), sdkTx, req) +} + +func (txh customTxHandler) runHandler(ctx context.Context, tx sdk.Tx, txBytes []byte, isSimulate bool) (sdk.Context, error) { + sdkCtx := sdk.UnwrapSDKContext(ctx) + if txh.handler == nil { + return sdkCtx, nil + } + + ms := sdkCtx.MultiStore() + + // Branch context before Handler call in case it aborts. + // This is required for both CheckTx and DeliverTx. + // Ref: https://github.com/cosmos/cosmos-sdk/issues/2772 + // + // NOTE: Alternatively, we could require that Handler ensures that + // writes do not happen if aborted/failed. This may have some + // performance benefits, but it'll be more difficult to get right. + cacheCtx, msCache := cacheTxContext(sdkCtx, txBytes) + cacheCtx = cacheCtx.WithEventManager(sdk.NewEventManager()) + newCtx, err := txh.handler(cacheCtx, tx, isSimulate) + if err != nil { + return sdk.Context{}, err + } + + if !newCtx.IsZero() { + // At this point, newCtx.MultiStore() is a store branch, or something else + // replaced by the Handler. We want the original multistore. + // + // Also, in the case of the tx aborting, we need to track gas consumed via + // the instantiated gas meter in the Handler, so we update the context + // prior to returning. + sdkCtx = newCtx.WithMultiStore(ms) + } + + msCache.Write() + + return sdkCtx, nil +} + +// cacheTxContext returns a new context based off of the provided context with +// a branched multi-store. +func cacheTxContext(sdkCtx sdk.Context, txBytes []byte) (sdk.Context, sdk.CacheMultiStore) { + ms := sdkCtx.MultiStore() + // TODO: https://github.com/cosmos/cosmos-sdk/issues/2824 + msCache := ms.CacheMultiStore() + if msCache.TracingEnabled() { + msCache = msCache.SetTracingContext( + sdk.TraceContext( + map[string]interface{}{ + "txHash": fmt.Sprintf("%X", tmhash.Sum(txBytes)), + }, + ), + ).(sdk.CacheMultiStore) + } + + return sdkCtx.WithMultiStore(msCache), msCache +} diff --git a/baseapp/queryrouter_test.go b/baseapp/queryrouter_test.go index c7637f17000e..4b38f6458641 100644 --- a/baseapp/queryrouter_test.go +++ b/baseapp/queryrouter_test.go @@ -1,4 +1,4 @@ -package baseapp +package baseapp_test import ( "testing" @@ -7,6 +7,7 @@ import ( abci "github.com/tendermint/tendermint/abci/types" + "github.com/cosmos/cosmos-sdk/baseapp" sdk "github.com/cosmos/cosmos-sdk/types" ) @@ -15,7 +16,7 @@ var testQuerier = func(_ sdk.Context, _ []string, _ abci.RequestQuery) ([]byte, } func TestQueryRouter(t *testing.T) { - qr := NewQueryRouter() + qr := baseapp.NewQueryRouter() // require panic on invalid route require.Panics(t, func() { diff --git a/baseapp/util_test.go b/baseapp/util_test.go new file mode 100644 index 000000000000..5f7504af85ec --- /dev/null +++ b/baseapp/util_test.go @@ -0,0 +1,67 @@ +package baseapp + +import ( + "github.com/cosmos/cosmos-sdk/types" + sdk "github.com/cosmos/cosmos-sdk/types" +) + +// TODO: Can be removed once we move all middleware tests into x/auth/middleware +// ref: #https://github.com/cosmos/cosmos-sdk/issues/10282 + +// CheckState is an exported method to be able to access baseapp's +// checkState in tests. +// +// This method is only accessible in baseapp tests. +func (app *BaseApp) CheckState() *state { + return app.checkState +} + +// DeliverState is an exported method to be able to access baseapp's +// deliverState in tests. +// +// This method is only accessible in baseapp tests. +func (app *BaseApp) DeliverState() *state { + return app.deliverState +} + +// CMS is an exported method to be able to access baseapp's cms in tests. +// +// This method is only accessible in baseapp tests. +func (app *BaseApp) CMS() types.CommitMultiStore { + return app.cms +} + +// GetMaximumBlockGas return maximum blocks gas. +// +// This method is only accessible in baseapp tests. +func (app *BaseApp) GetMaximumBlockGas(ctx sdk.Context) uint64 { + return app.getMaximumBlockGas(ctx) +} + +// GetName return name. +// +// This method is only accessible in baseapp tests. +func (app *BaseApp) GetName() string { + return app.name +} + +// GetName return name. +// +// This method is only accessible in baseapp tests. +func (app *BaseApp) TxDecoder(txBytes []byte) (sdk.Tx, error) { + return app.txDecoder(txBytes) +} + +// CreateQueryContext calls app's createQueryContext. +// +// This method is only accessible in baseapp tests. +func (app *BaseApp) CreateQueryContext(height int64, prove bool) (sdk.Context, error) { + return app.createQueryContext(height, prove) +} + +// MinGasPrices returns minGasPrices. +// +// This method is only accessible in baseapp tests. +func (app *BaseApp) MinGasPrices() sdk.DecCoins { + return app.minGasPrices +} diff --git a/contrib/rosetta/configuration/bootstrap.json b/contrib/rosetta/configuration/bootstrap.json index 6fbfac1a509a..ad30b1611e2a 100644 --- a/contrib/rosetta/configuration/bootstrap.json +++ b/contrib/rosetta/configuration/bootstrap.json @@ -1,7 +1,7 @@ [ { "account_identifier": { - "address":"cosmos1y3awd3vl7g29q44uvz0yrevcduf2exvkwxk3uq" + "address":"cosmos1wy36cv7hveh7xt4ushy2twp5czqxnz5v6rn3xw" }, "currency":{ "symbol":"stake", diff --git a/contrib/rosetta/rosetta-ci/data.tar.gz b/contrib/rosetta/rosetta-ci/data.tar.gz index b3b890e11538..7b9a99a6db0f 100644 Binary files a/contrib/rosetta/rosetta-ci/data.tar.gz and b/contrib/rosetta/rosetta-ci/data.tar.gz differ diff --git a/server/mock/app.go b/server/mock/app.go index b1fe740a1936..d5f7c911c252 100644 --- a/server/mock/app.go +++ b/server/mock/app.go @@ -15,9 +15,19 @@ import ( "github.com/cosmos/cosmos-sdk/codec" "github.com/cosmos/cosmos-sdk/simapp" sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/types/tx" "github.com/cosmos/cosmos-sdk/x/auth/middleware" ) +func testTxHandler(options middleware.TxHandlerOptions) tx.Handler { + return middleware.ComposeMiddlewares( + middleware.NewRunMsgsTxHandler(options.MsgServiceRouter, options.LegacyRouter), + middleware.GasTxMiddleware, + middleware.RecoveryTxMiddleware, + middleware.NewIndexEventsTxMiddleware(options.IndexEvents), + ) +} + // NewApp creates a simple mock kvstore app for testing. It should work // similar to a real app. Make sure rootDir is empty before running the test, // in order to guarantee consistent results @@ -44,13 +54,12 @@ func NewApp(rootDir string, logger log.Logger) (abci.Application, error) { // We're adding a test legacy route here, which accesses the kvstore // and simply sets the Msg's key/value pair in the kvstore. legacyRouter.AddRoute(sdk.NewRoute("kvstore", KVStoreHandler(capKeyMainStore))) - txHandler, err := middleware.NewDefaultTxHandler(middleware.TxHandlerOptions{ - LegacyRouter: legacyRouter, - MsgServiceRouter: middleware.NewMsgServiceRouter(encCfg.InterfaceRegistry), - }) - if err != nil { - return nil, err - } + txHandler := testTxHandler( + middleware.TxHandlerOptions{ + LegacyRouter: legacyRouter, + MsgServiceRouter: middleware.NewMsgServiceRouter(encCfg.InterfaceRegistry), + }, + ) baseApp.SetTxHandler(txHandler) // Load latest version. diff --git a/simapp/app.go b/simapp/app.go index f53df16bf462..b3323ec076f4 100644 --- a/simapp/app.go +++ b/simapp/app.go @@ -30,7 +30,6 @@ import ( "github.com/cosmos/cosmos-sdk/types/module" "github.com/cosmos/cosmos-sdk/version" "github.com/cosmos/cosmos-sdk/x/auth" - "github.com/cosmos/cosmos-sdk/x/auth/ante" authkeeper "github.com/cosmos/cosmos-sdk/x/auth/keeper" authsims "github.com/cosmos/cosmos-sdk/x/auth/simulation" authtx "github.com/cosmos/cosmos-sdk/x/auth/tx" @@ -401,29 +400,20 @@ func NewSimApp( } func (app *SimApp) setTxHandler(txConfig client.TxConfig, indexEventsStr []string) { - anteHandler, err := ante.NewAnteHandler( - ante.HandlerOptions{ - AccountKeeper: app.AccountKeeper, - BankKeeper: app.BankKeeper, - SignModeHandler: txConfig.SignModeHandler(), - FeegrantKeeper: app.FeeGrantKeeper, - SigGasConsumer: ante.DefaultSigVerificationGasConsumer, - }, - ) - if err != nil { - panic(err) - } - indexEvents := map[string]struct{}{} for _, e := range indexEventsStr { indexEvents[e] = struct{}{} } txHandler, err := authmiddleware.NewDefaultTxHandler(authmiddleware.TxHandlerOptions{ - Debug: app.Trace(), - IndexEvents: indexEvents, - LegacyRouter: app.legacyRouter, - MsgServiceRouter: app.msgSvcRouter, - LegacyAnteHandler: anteHandler, + Debug: app.Trace(), + IndexEvents: indexEvents, + LegacyRouter: app.legacyRouter, + MsgServiceRouter: app.msgSvcRouter, + AccountKeeper: app.AccountKeeper, + BankKeeper: app.BankKeeper, + FeegrantKeeper: app.FeeGrantKeeper, + SignModeHandler: txConfig.SignModeHandler(), + SigGasConsumer: authmiddleware.DefaultSigVerificationGasConsumer, }) if err != nil { panic(err) diff --git a/x/auth/ante/ante.go b/x/auth/ante/ante.go deleted file mode 100644 index dbb40aeb13ce..000000000000 --- a/x/auth/ante/ante.go +++ /dev/null @@ -1,57 +0,0 @@ -package ante - -import ( - sdk "github.com/cosmos/cosmos-sdk/types" - sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" - "github.com/cosmos/cosmos-sdk/types/tx/signing" - authsigning "github.com/cosmos/cosmos-sdk/x/auth/signing" - "github.com/cosmos/cosmos-sdk/x/auth/types" -) - -// HandlerOptions are the options required for constructing a default SDK AnteHandler. -type HandlerOptions struct { - AccountKeeper AccountKeeper - BankKeeper types.BankKeeper - FeegrantKeeper FeegrantKeeper - SignModeHandler authsigning.SignModeHandler - SigGasConsumer func(meter sdk.GasMeter, sig signing.SignatureV2, params types.Params) error -} - -// NewAnteHandler returns an AnteHandler that checks and increments sequence -// numbers, checks signatures & account numbers, and deducts fees from the first -// signer. -func NewAnteHandler(options HandlerOptions) (sdk.AnteHandler, error) { - if options.AccountKeeper == nil { - return nil, sdkerrors.Wrap(sdkerrors.ErrLogic, "account keeper is required for ante builder") - } - - if options.BankKeeper == nil { - return nil, sdkerrors.Wrap(sdkerrors.ErrLogic, "bank keeper is required for ante builder") - } - - if options.SignModeHandler == nil { - return nil, sdkerrors.Wrap(sdkerrors.ErrLogic, "sign mode handler is required for ante builder") - } - - var sigGasConsumer = options.SigGasConsumer - if sigGasConsumer == nil { - sigGasConsumer = DefaultSigVerificationGasConsumer - } - - anteDecorators := []sdk.AnteDecorator{ - NewRejectExtensionOptionsDecorator(), - NewMempoolFeeDecorator(), - NewValidateBasicDecorator(), - NewTxTimeoutHeightDecorator(), - NewValidateMemoDecorator(options.AccountKeeper), - NewConsumeGasForTxSizeDecorator(options.AccountKeeper), - NewDeductFeeDecorator(options.AccountKeeper, options.BankKeeper, options.FeegrantKeeper), - NewSetPubKeyDecorator(options.AccountKeeper), // SetPubKeyDecorator must be called before all signature verification decorators - NewValidateSigCountDecorator(options.AccountKeeper), - NewSigGasConsumeDecorator(options.AccountKeeper, sigGasConsumer), - NewSigVerificationDecorator(options.AccountKeeper, options.SignModeHandler), - NewIncrementSequenceDecorator(options.AccountKeeper), - } - - return sdk.ChainAnteDecorators(anteDecorators...), nil -} diff --git a/x/auth/ante/basic.go b/x/auth/ante/basic.go deleted file mode 100644 index d42aed214444..000000000000 --- a/x/auth/ante/basic.go +++ /dev/null @@ -1,207 +0,0 @@ -package ante - -import ( - "github.com/cosmos/cosmos-sdk/codec/legacy" - "github.com/cosmos/cosmos-sdk/crypto/keys/multisig" - cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types" - sdk "github.com/cosmos/cosmos-sdk/types" - sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" - "github.com/cosmos/cosmos-sdk/types/tx/signing" - "github.com/cosmos/cosmos-sdk/x/auth/migrations/legacytx" - authsigning "github.com/cosmos/cosmos-sdk/x/auth/signing" -) - -// ValidateBasicDecorator will call tx.ValidateBasic, msg.ValidateBasic(for each msg inside tx) -// and return any non-nil error. -// If ValidateBasic passes, decorator calls next AnteHandler in chain. Note, -// ValidateBasicDecorator decorator will not get executed on ReCheckTx since it -// is not dependent on application state. -type ValidateBasicDecorator struct{} - -func NewValidateBasicDecorator() ValidateBasicDecorator { - return ValidateBasicDecorator{} -} - -func (vbd ValidateBasicDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bool, next sdk.AnteHandler) (sdk.Context, error) { - // no need to validate basic on recheck tx, call next antehandler - if ctx.IsReCheckTx() { - return next(ctx, tx, simulate) - } - - if err := tx.ValidateBasic(); err != nil { - return ctx, err - } - - return next(ctx, tx, simulate) -} - -// ValidateMemoDecorator will validate memo given the parameters passed in -// If memo is too large decorator returns with error, otherwise call next AnteHandler -// CONTRACT: Tx must implement TxWithMemo interface -type ValidateMemoDecorator struct { - ak AccountKeeper -} - -func NewValidateMemoDecorator(ak AccountKeeper) ValidateMemoDecorator { - return ValidateMemoDecorator{ - ak: ak, - } -} - -func (vmd ValidateMemoDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bool, next sdk.AnteHandler) (sdk.Context, error) { - memoTx, ok := tx.(sdk.TxWithMemo) - if !ok { - return ctx, sdkerrors.Wrap(sdkerrors.ErrTxDecode, "invalid transaction type") - } - - params := vmd.ak.GetParams(ctx) - - memoLength := len(memoTx.GetMemo()) - if uint64(memoLength) > params.MaxMemoCharacters { - return ctx, sdkerrors.Wrapf(sdkerrors.ErrMemoTooLarge, - "maximum number of characters is %d but received %d characters", - params.MaxMemoCharacters, memoLength, - ) - } - - return next(ctx, tx, simulate) -} - -// ConsumeTxSizeGasDecorator will take in parameters and consume gas proportional -// to the size of tx before calling next AnteHandler. Note, the gas costs will be -// slightly over estimated due to the fact that any given signing account may need -// to be retrieved from state. -// -// CONTRACT: If simulate=true, then signatures must either be completely filled -// in or empty. -// CONTRACT: To use this decorator, signatures of transaction must be represented -// as legacytx.StdSignature otherwise simulate mode will incorrectly estimate gas cost. -type ConsumeTxSizeGasDecorator struct { - ak AccountKeeper -} - -func NewConsumeGasForTxSizeDecorator(ak AccountKeeper) ConsumeTxSizeGasDecorator { - return ConsumeTxSizeGasDecorator{ - ak: ak, - } -} - -func (cgts ConsumeTxSizeGasDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bool, next sdk.AnteHandler) (sdk.Context, error) { - sigTx, ok := tx.(authsigning.SigVerifiableTx) - if !ok { - return ctx, sdkerrors.Wrap(sdkerrors.ErrTxDecode, "invalid tx type") - } - params := cgts.ak.GetParams(ctx) - - ctx.GasMeter().ConsumeGas(params.TxSizeCostPerByte*sdk.Gas(len(ctx.TxBytes())), "txSize") - - // simulate gas cost for signatures in simulate mode - if simulate { - // in simulate mode, each element should be a nil signature - sigs, err := sigTx.GetSignaturesV2() - if err != nil { - return ctx, err - } - n := len(sigs) - - for i, signer := range sigTx.GetSigners() { - // if signature is already filled in, no need to simulate gas cost - if i < n && !isIncompleteSignature(sigs[i].Data) { - continue - } - - var pubkey cryptotypes.PubKey - - acc := cgts.ak.GetAccount(ctx, signer) - - // use placeholder simSecp256k1Pubkey if sig is nil - if acc == nil || acc.GetPubKey() == nil { - pubkey = simSecp256k1Pubkey - } else { - pubkey = acc.GetPubKey() - } - - // use stdsignature to mock the size of a full signature - simSig := legacytx.StdSignature{ //nolint:staticcheck // this will be removed when proto is ready - Signature: simSecp256k1Sig[:], - PubKey: pubkey, - } - - sigBz := legacy.Cdc.MustMarshal(simSig) - cost := sdk.Gas(len(sigBz) + 6) - - // If the pubkey is a multi-signature pubkey, then we estimate for the maximum - // number of signers. - if _, ok := pubkey.(*multisig.LegacyAminoPubKey); ok { - cost *= params.TxSigLimit - } - - ctx.GasMeter().ConsumeGas(params.TxSizeCostPerByte*cost, "txSize") - } - } - - return next(ctx, tx, simulate) -} - -// isIncompleteSignature tests whether SignatureData is fully filled in for simulation purposes -func isIncompleteSignature(data signing.SignatureData) bool { - if data == nil { - return true - } - - switch data := data.(type) { - case *signing.SingleSignatureData: - return len(data.Signature) == 0 - case *signing.MultiSignatureData: - if len(data.Signatures) == 0 { - return true - } - for _, s := range data.Signatures { - if isIncompleteSignature(s) { - return true - } - } - } - - return false -} - -type ( - // TxTimeoutHeightDecorator defines an AnteHandler decorator that checks for a - // tx height timeout. - TxTimeoutHeightDecorator struct{} - - // TxWithTimeoutHeight defines the interface a tx must implement in order for - // TxHeightTimeoutDecorator to process the tx. - TxWithTimeoutHeight interface { - sdk.Tx - - GetTimeoutHeight() uint64 - } -) - -// TxTimeoutHeightDecorator defines an AnteHandler decorator that checks for a -// tx height timeout. -func NewTxTimeoutHeightDecorator() TxTimeoutHeightDecorator { - return TxTimeoutHeightDecorator{} -} - -// AnteHandle implements an AnteHandler decorator for the TxHeightTimeoutDecorator -// type where the current block height is checked against the tx's height timeout. -// If a height timeout is provided (non-zero) and is less than the current block -// height, then an error is returned. -func (txh TxTimeoutHeightDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bool, next sdk.AnteHandler) (sdk.Context, error) { - timeoutTx, ok := tx.(TxWithTimeoutHeight) - if !ok { - return ctx, sdkerrors.Wrap(sdkerrors.ErrTxDecode, "expected tx to implement TxWithTimeoutHeight") - } - - timeoutHeight := timeoutTx.GetTimeoutHeight() - if timeoutHeight > 0 && uint64(ctx.BlockHeight()) > timeoutHeight { - return ctx, sdkerrors.Wrapf( - sdkerrors.ErrTxTimeoutHeight, "block height: %d, timeout height: %d", ctx.BlockHeight(), timeoutHeight, - ) - } - - return next(ctx, tx, simulate) -} diff --git a/x/auth/ante/basic_test.go b/x/auth/ante/basic_test.go deleted file mode 100644 index 4a8cb830fdf6..000000000000 --- a/x/auth/ante/basic_test.go +++ /dev/null @@ -1,224 +0,0 @@ -package ante_test - -import ( - "strings" - - cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types" - "github.com/cosmos/cosmos-sdk/crypto/types/multisig" - "github.com/cosmos/cosmos-sdk/testutil/testdata" - sdk "github.com/cosmos/cosmos-sdk/types" - "github.com/cosmos/cosmos-sdk/types/tx/signing" - "github.com/cosmos/cosmos-sdk/x/auth/ante" -) - -func (suite *AnteTestSuite) TestValidateBasic() { - suite.SetupTest(true) // setup - suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder() - - // keys and addresses - priv1, _, addr1 := testdata.KeyTestPubAddr() - - // msg and signatures - msg := testdata.NewTestMsg(addr1) - feeAmount := testdata.NewTestFeeAmount() - gasLimit := testdata.NewTestGasLimit() - suite.Require().NoError(suite.txBuilder.SetMsgs(msg)) - suite.txBuilder.SetFeeAmount(feeAmount) - suite.txBuilder.SetGasLimit(gasLimit) - - privs, accNums, accSeqs := []cryptotypes.PrivKey{}, []uint64{}, []uint64{} - invalidTx, err := suite.CreateTestTx(privs, accNums, accSeqs, suite.ctx.ChainID()) - suite.Require().NoError(err) - - vbd := ante.NewValidateBasicDecorator() - antehandler := sdk.ChainAnteDecorators(vbd) - _, err = antehandler(suite.ctx, invalidTx, false) - - suite.Require().NotNil(err, "Did not error on invalid tx") - - privs, accNums, accSeqs = []cryptotypes.PrivKey{priv1}, []uint64{0}, []uint64{0} - validTx, err := suite.CreateTestTx(privs, accNums, accSeqs, suite.ctx.ChainID()) - suite.Require().NoError(err) - - _, err = antehandler(suite.ctx, validTx, false) - suite.Require().Nil(err, "ValidateBasicDecorator returned error on valid tx. err: %v", err) - - // test decorator skips on recheck - suite.ctx = suite.ctx.WithIsReCheckTx(true) - - // decorator should skip processing invalidTx on recheck and thus return nil-error - _, err = antehandler(suite.ctx, invalidTx, false) - - suite.Require().Nil(err, "ValidateBasicDecorator ran on ReCheck") -} - -func (suite *AnteTestSuite) TestValidateMemo() { - suite.SetupTest(true) // setup - suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder() - - // keys and addresses - priv1, _, addr1 := testdata.KeyTestPubAddr() - - // msg and signatures - msg := testdata.NewTestMsg(addr1) - feeAmount := testdata.NewTestFeeAmount() - gasLimit := testdata.NewTestGasLimit() - suite.Require().NoError(suite.txBuilder.SetMsgs(msg)) - suite.txBuilder.SetFeeAmount(feeAmount) - suite.txBuilder.SetGasLimit(gasLimit) - - privs, accNums, accSeqs := []cryptotypes.PrivKey{priv1}, []uint64{0}, []uint64{0} - suite.txBuilder.SetMemo(strings.Repeat("01234567890", 500)) - invalidTx, err := suite.CreateTestTx(privs, accNums, accSeqs, suite.ctx.ChainID()) - suite.Require().NoError(err) - - // require that long memos get rejected - vmd := ante.NewValidateMemoDecorator(suite.app.AccountKeeper) - antehandler := sdk.ChainAnteDecorators(vmd) - _, err = antehandler(suite.ctx, invalidTx, false) - - suite.Require().NotNil(err, "Did not error on tx with high memo") - - suite.txBuilder.SetMemo(strings.Repeat("01234567890", 10)) - validTx, err := suite.CreateTestTx(privs, accNums, accSeqs, suite.ctx.ChainID()) - suite.Require().NoError(err) - - // require small memos pass ValidateMemo Decorator - _, err = antehandler(suite.ctx, validTx, false) - suite.Require().Nil(err, "ValidateBasicDecorator returned error on valid tx. err: %v", err) -} - -func (suite *AnteTestSuite) TestConsumeGasForTxSize() { - suite.SetupTest(true) // setup - - // keys and addresses - priv1, _, addr1 := testdata.KeyTestPubAddr() - - // msg and signatures - msg := testdata.NewTestMsg(addr1) - feeAmount := testdata.NewTestFeeAmount() - gasLimit := testdata.NewTestGasLimit() - - cgtsd := ante.NewConsumeGasForTxSizeDecorator(suite.app.AccountKeeper) - antehandler := sdk.ChainAnteDecorators(cgtsd) - - testCases := []struct { - name string - sigV2 signing.SignatureV2 - }{ - {"SingleSignatureData", signing.SignatureV2{PubKey: priv1.PubKey()}}, - {"MultiSignatureData", signing.SignatureV2{PubKey: priv1.PubKey(), Data: multisig.NewMultisig(2)}}, - } - - for _, tc := range testCases { - suite.Run(tc.name, func() { - suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder() - suite.Require().NoError(suite.txBuilder.SetMsgs(msg)) - suite.txBuilder.SetFeeAmount(feeAmount) - suite.txBuilder.SetGasLimit(gasLimit) - suite.txBuilder.SetMemo(strings.Repeat("01234567890", 10)) - - privs, accNums, accSeqs := []cryptotypes.PrivKey{priv1}, []uint64{0}, []uint64{0} - tx, err := suite.CreateTestTx(privs, accNums, accSeqs, suite.ctx.ChainID()) - suite.Require().NoError(err) - - txBytes, err := suite.clientCtx.TxConfig.TxJSONEncoder()(tx) - suite.Require().Nil(err, "Cannot marshal tx: %v", err) - - params := suite.app.AccountKeeper.GetParams(suite.ctx) - expectedGas := sdk.Gas(len(txBytes)) * params.TxSizeCostPerByte - - // Set suite.ctx with TxBytes manually - suite.ctx = suite.ctx.WithTxBytes(txBytes) - - // track how much gas is necessary to retrieve parameters - beforeGas := suite.ctx.GasMeter().GasConsumed() - suite.app.AccountKeeper.GetParams(suite.ctx) - afterGas := suite.ctx.GasMeter().GasConsumed() - expectedGas += afterGas - beforeGas - - beforeGas = suite.ctx.GasMeter().GasConsumed() - suite.ctx, err = antehandler(suite.ctx, tx, false) - suite.Require().Nil(err, "ConsumeTxSizeGasDecorator returned error: %v", err) - - // require that decorator consumes expected amount of gas - consumedGas := suite.ctx.GasMeter().GasConsumed() - beforeGas - suite.Require().Equal(expectedGas, consumedGas, "Decorator did not consume the correct amount of gas") - - // simulation must not underestimate gas of this decorator even with nil signatures - txBuilder, err := suite.clientCtx.TxConfig.WrapTxBuilder(tx) - suite.Require().NoError(err) - suite.Require().NoError(txBuilder.SetSignatures(tc.sigV2)) - tx = txBuilder.GetTx() - - simTxBytes, err := suite.clientCtx.TxConfig.TxJSONEncoder()(tx) - suite.Require().Nil(err, "Cannot marshal tx: %v", err) - // require that simulated tx is smaller than tx with signatures - suite.Require().True(len(simTxBytes) < len(txBytes), "simulated tx still has signatures") - - // Set suite.ctx with smaller simulated TxBytes manually - suite.ctx = suite.ctx.WithTxBytes(simTxBytes) - - beforeSimGas := suite.ctx.GasMeter().GasConsumed() - - // run antehandler with simulate=true - suite.ctx, err = antehandler(suite.ctx, tx, true) - consumedSimGas := suite.ctx.GasMeter().GasConsumed() - beforeSimGas - - // require that antehandler passes and does not underestimate decorator cost - suite.Require().Nil(err, "ConsumeTxSizeGasDecorator returned error: %v", err) - suite.Require().True(consumedSimGas >= expectedGas, "Simulate mode underestimates gas on AnteDecorator. Simulated cost: %d, expected cost: %d", consumedSimGas, expectedGas) - - }) - } - -} - -func (suite *AnteTestSuite) TestTxHeightTimeoutDecorator() { - suite.SetupTest(true) - - antehandler := sdk.ChainAnteDecorators(ante.NewTxTimeoutHeightDecorator()) - - // keys and addresses - priv1, _, addr1 := testdata.KeyTestPubAddr() - - // msg and signatures - msg := testdata.NewTestMsg(addr1) - feeAmount := testdata.NewTestFeeAmount() - gasLimit := testdata.NewTestGasLimit() - - testCases := []struct { - name string - timeout uint64 - height int64 - expectErr bool - }{ - {"default value", 0, 10, false}, - {"no timeout (greater height)", 15, 10, false}, - {"no timeout (same height)", 10, 10, false}, - {"timeout (smaller height)", 9, 10, true}, - } - - for _, tc := range testCases { - tc := tc - - suite.Run(tc.name, func() { - suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder() - - suite.Require().NoError(suite.txBuilder.SetMsgs(msg)) - - suite.txBuilder.SetFeeAmount(feeAmount) - suite.txBuilder.SetGasLimit(gasLimit) - suite.txBuilder.SetMemo(strings.Repeat("01234567890", 10)) - suite.txBuilder.SetTimeoutHeight(tc.timeout) - - privs, accNums, accSeqs := []cryptotypes.PrivKey{priv1}, []uint64{0}, []uint64{0} - tx, err := suite.CreateTestTx(privs, accNums, accSeqs, suite.ctx.ChainID()) - suite.Require().NoError(err) - - ctx := suite.ctx.WithBlockHeight(tc.height) - _, err = antehandler(ctx, tx, true) - suite.Require().Equal(tc.expectErr, err != nil, err) - }) - } -} diff --git a/x/auth/ante/ext.go b/x/auth/ante/ext.go deleted file mode 100644 index 362b8d32a971..000000000000 --- a/x/auth/ante/ext.go +++ /dev/null @@ -1,36 +0,0 @@ -package ante - -import ( - codectypes "github.com/cosmos/cosmos-sdk/codec/types" - "github.com/cosmos/cosmos-sdk/types" - sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" -) - -type HasExtensionOptionsTx interface { - GetExtensionOptions() []*codectypes.Any - GetNonCriticalExtensionOptions() []*codectypes.Any -} - -// RejectExtensionOptionsDecorator is an AnteDecorator that rejects all extension -// options which can optionally be included in protobuf transactions. Users that -// need extension options should create a custom AnteHandler chain that handles -// needed extension options properly and rejects unknown ones. -type RejectExtensionOptionsDecorator struct{} - -// NewRejectExtensionOptionsDecorator creates a new RejectExtensionOptionsDecorator -func NewRejectExtensionOptionsDecorator() RejectExtensionOptionsDecorator { - return RejectExtensionOptionsDecorator{} -} - -var _ types.AnteDecorator = RejectExtensionOptionsDecorator{} - -// AnteHandle implements the AnteDecorator.AnteHandle method -func (r RejectExtensionOptionsDecorator) AnteHandle(ctx types.Context, tx types.Tx, simulate bool, next types.AnteHandler) (newCtx types.Context, err error) { - if hasExtOptsTx, ok := tx.(HasExtensionOptionsTx); ok { - if len(hasExtOptsTx.GetExtensionOptions()) != 0 { - return ctx, sdkerrors.ErrUnknownExtensionOptions - } - } - - return next(ctx, tx, simulate) -} diff --git a/x/auth/ante/ext_test.go b/x/auth/ante/ext_test.go deleted file mode 100644 index 89ce6a7d649f..000000000000 --- a/x/auth/ante/ext_test.go +++ /dev/null @@ -1,36 +0,0 @@ -package ante_test - -import ( - "github.com/cosmos/cosmos-sdk/codec/types" - "github.com/cosmos/cosmos-sdk/testutil/testdata" - sdk "github.com/cosmos/cosmos-sdk/types" - "github.com/cosmos/cosmos-sdk/x/auth/ante" - "github.com/cosmos/cosmos-sdk/x/auth/tx" -) - -func (suite *AnteTestSuite) TestRejectExtensionOptionsDecorator() { - suite.SetupTest(true) // setup - suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder() - - reod := ante.NewRejectExtensionOptionsDecorator() - antehandler := sdk.ChainAnteDecorators(reod) - - // no extension options should not trigger an error - theTx := suite.txBuilder.GetTx() - _, err := antehandler(suite.ctx, theTx, false) - suite.Require().NoError(err) - - extOptsTxBldr, ok := suite.txBuilder.(tx.ExtensionOptionsTxBuilder) - if !ok { - // if we can't set extension options, this decorator doesn't apply and we're done - return - } - - // setting any extension option should cause an error - any, err := types.NewAnyWithValue(testdata.NewTestMsg()) - suite.Require().NoError(err) - extOptsTxBldr.SetExtensionOptions(any) - theTx = suite.txBuilder.GetTx() - _, err = antehandler(suite.ctx, theTx, false) - suite.Require().EqualError(err, "unknown extension options") -} diff --git a/x/auth/ante/fee.go b/x/auth/ante/fee.go deleted file mode 100644 index b1d1d72a770e..000000000000 --- a/x/auth/ante/fee.go +++ /dev/null @@ -1,140 +0,0 @@ -package ante - -import ( - "fmt" - - sdk "github.com/cosmos/cosmos-sdk/types" - sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" - "github.com/cosmos/cosmos-sdk/x/auth/types" -) - -// MempoolFeeDecorator will check if the transaction's fee is at least as large -// as the local validator's minimum gasFee (defined in validator config). -// If fee is too low, decorator returns error and tx is rejected from mempool. -// Note this only applies when ctx.CheckTx = true -// If fee is high enough or not CheckTx, then call next AnteHandler -// CONTRACT: Tx must implement FeeTx to use MempoolFeeDecorator -type MempoolFeeDecorator struct{} - -func NewMempoolFeeDecorator() MempoolFeeDecorator { - return MempoolFeeDecorator{} -} - -func (mfd MempoolFeeDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bool, next sdk.AnteHandler) (newCtx sdk.Context, err error) { - feeTx, ok := tx.(sdk.FeeTx) - if !ok { - return ctx, sdkerrors.Wrap(sdkerrors.ErrTxDecode, "Tx must be a FeeTx") - } - - feeCoins := feeTx.GetFee() - gas := feeTx.GetGas() - - // Ensure that the provided fees meet a minimum threshold for the validator, - // if this is a CheckTx. This is only for local mempool purposes, and thus - // is only ran on check tx. - if ctx.IsCheckTx() && !simulate { - minGasPrices := ctx.MinGasPrices() - if !minGasPrices.IsZero() { - requiredFees := make(sdk.Coins, len(minGasPrices)) - - // Determine the required fees by multiplying each required minimum gas - // price by the gas limit, where fee = ceil(minGasPrice * gasLimit). - glDec := sdk.NewDec(int64(gas)) - for i, gp := range minGasPrices { - fee := gp.Amount.Mul(glDec) - requiredFees[i] = sdk.NewCoin(gp.Denom, fee.Ceil().RoundInt()) - } - - if !feeCoins.IsAnyGTE(requiredFees) { - return ctx, sdkerrors.Wrapf(sdkerrors.ErrInsufficientFee, "insufficient fees; got: %s required: %s", feeCoins, requiredFees) - } - } - } - - return next(ctx, tx, simulate) -} - -// DeductFeeDecorator deducts fees from the first signer of the tx -// If the first signer does not have the funds to pay for the fees, return with InsufficientFunds error -// Call next AnteHandler if fees successfully deducted -// CONTRACT: Tx must implement FeeTx interface to use DeductFeeDecorator -type DeductFeeDecorator struct { - ak AccountKeeper - bankKeeper types.BankKeeper - feegrantKeeper FeegrantKeeper -} - -func NewDeductFeeDecorator(ak AccountKeeper, bk types.BankKeeper, fk FeegrantKeeper) DeductFeeDecorator { - return DeductFeeDecorator{ - ak: ak, - bankKeeper: bk, - feegrantKeeper: fk, - } -} - -func (dfd DeductFeeDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bool, next sdk.AnteHandler) (newCtx sdk.Context, err error) { - feeTx, ok := tx.(sdk.FeeTx) - if !ok { - return ctx, sdkerrors.Wrap(sdkerrors.ErrTxDecode, "Tx must be a FeeTx") - } - - if addr := dfd.ak.GetModuleAddress(types.FeeCollectorName); addr == nil { - panic(fmt.Sprintf("%s module account has not been set", types.FeeCollectorName)) - } - - fee := feeTx.GetFee() - feePayer := feeTx.FeePayer() - feeGranter := feeTx.FeeGranter() - - deductFeesFrom := feePayer - - // if feegranter set deduct fee from feegranter account. - // this works with only when feegrant enabled. - if feeGranter != nil { - if dfd.feegrantKeeper == nil { - return ctx, sdkerrors.Wrap(sdkerrors.ErrInvalidRequest, "fee grants are not enabled") - } else if !feeGranter.Equals(feePayer) { - err := dfd.feegrantKeeper.UseGrantedFees(ctx, feeGranter, feePayer, fee, tx.GetMsgs()) - - if err != nil { - return ctx, sdkerrors.Wrapf(err, "%s not allowed to pay fees from %s", feeGranter, feePayer) - } - } - - deductFeesFrom = feeGranter - } - - deductFeesFromAcc := dfd.ak.GetAccount(ctx, deductFeesFrom) - if deductFeesFromAcc == nil { - return ctx, sdkerrors.Wrapf(sdkerrors.ErrUnknownAddress, "fee payer address: %s does not exist", deductFeesFrom) - } - - // deduct the fees - if !feeTx.GetFee().IsZero() { - err = DeductFees(dfd.bankKeeper, ctx, deductFeesFromAcc, feeTx.GetFee()) - if err != nil { - return ctx, err - } - } - - events := sdk.Events{sdk.NewEvent(sdk.EventTypeTx, - sdk.NewAttribute(sdk.AttributeKeyFee, feeTx.GetFee().String()), - )} - ctx.EventManager().EmitEvents(events) - - return next(ctx, tx, simulate) -} - -// DeductFees deducts fees from the given account. -func DeductFees(bankKeeper types.BankKeeper, ctx sdk.Context, acc types.AccountI, fees sdk.Coins) error { - if !fees.IsValid() { - return sdkerrors.Wrapf(sdkerrors.ErrInsufficientFee, "invalid fee amount: %s", fees) - } - - err := bankKeeper.SendCoinsFromAccountToModule(ctx, acc.GetAddress(), types.FeeCollectorName, fees) - if err != nil { - return sdkerrors.Wrapf(sdkerrors.ErrInsufficientFunds, err.Error()) - } - - return nil -} diff --git a/x/auth/ante/fee_test.go b/x/auth/ante/fee_test.go deleted file mode 100644 index 06ccb4d3948f..000000000000 --- a/x/auth/ante/fee_test.go +++ /dev/null @@ -1,104 +0,0 @@ -package ante_test - -import ( - cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types" - "github.com/cosmos/cosmos-sdk/testutil/testdata" - sdk "github.com/cosmos/cosmos-sdk/types" - "github.com/cosmos/cosmos-sdk/x/auth/ante" - "github.com/cosmos/cosmos-sdk/x/bank/testutil" -) - -func (suite *AnteTestSuite) TestEnsureMempoolFees() { - suite.SetupTest(true) // setup - suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder() - - mfd := ante.NewMempoolFeeDecorator() - antehandler := sdk.ChainAnteDecorators(mfd) - - // keys and addresses - priv1, _, addr1 := testdata.KeyTestPubAddr() - - // msg and signatures - msg := testdata.NewTestMsg(addr1) - feeAmount := testdata.NewTestFeeAmount() - gasLimit := testdata.NewTestGasLimit() - suite.Require().NoError(suite.txBuilder.SetMsgs(msg)) - suite.txBuilder.SetFeeAmount(feeAmount) - suite.txBuilder.SetGasLimit(gasLimit) - - privs, accNums, accSeqs := []cryptotypes.PrivKey{priv1}, []uint64{0}, []uint64{0} - tx, err := suite.CreateTestTx(privs, accNums, accSeqs, suite.ctx.ChainID()) - suite.Require().NoError(err) - - // Set high gas price so standard test fee fails - atomPrice := sdk.NewDecCoinFromDec("atom", sdk.NewDec(200).Quo(sdk.NewDec(100000))) - highGasPrice := []sdk.DecCoin{atomPrice} - suite.ctx = suite.ctx.WithMinGasPrices(highGasPrice) - - // Set IsCheckTx to true - suite.ctx = suite.ctx.WithIsCheckTx(true) - - // antehandler errors with insufficient fees - _, err = antehandler(suite.ctx, tx, false) - suite.Require().NotNil(err, "Decorator should have errored on too low fee for local gasPrice") - - // Set IsCheckTx to false - suite.ctx = suite.ctx.WithIsCheckTx(false) - - // antehandler should not error since we do not check minGasPrice in DeliverTx - _, err = antehandler(suite.ctx, tx, false) - suite.Require().Nil(err, "MempoolFeeDecorator returned error in DeliverTx") - - // Set IsCheckTx back to true for testing sufficient mempool fee - suite.ctx = suite.ctx.WithIsCheckTx(true) - - atomPrice = sdk.NewDecCoinFromDec("atom", sdk.NewDec(0).Quo(sdk.NewDec(100000))) - lowGasPrice := []sdk.DecCoin{atomPrice} - suite.ctx = suite.ctx.WithMinGasPrices(lowGasPrice) - - _, err = antehandler(suite.ctx, tx, false) - suite.Require().Nil(err, "Decorator should not have errored on fee higher than local gasPrice") -} - -func (suite *AnteTestSuite) TestDeductFees() { - suite.SetupTest(false) // setup - suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder() - - // keys and addresses - priv1, _, addr1 := testdata.KeyTestPubAddr() - - // msg and signatures - msg := testdata.NewTestMsg(addr1) - feeAmount := testdata.NewTestFeeAmount() - gasLimit := testdata.NewTestGasLimit() - suite.Require().NoError(suite.txBuilder.SetMsgs(msg)) - suite.txBuilder.SetFeeAmount(feeAmount) - suite.txBuilder.SetGasLimit(gasLimit) - - privs, accNums, accSeqs := []cryptotypes.PrivKey{priv1}, []uint64{0}, []uint64{0} - tx, err := suite.CreateTestTx(privs, accNums, accSeqs, suite.ctx.ChainID()) - suite.Require().NoError(err) - - // Set account with insufficient funds - acc := suite.app.AccountKeeper.NewAccountWithAddress(suite.ctx, addr1) - suite.app.AccountKeeper.SetAccount(suite.ctx, acc) - coins := sdk.NewCoins(sdk.NewCoin("atom", sdk.NewInt(10))) - err = testutil.FundAccount(suite.app.BankKeeper, suite.ctx, addr1, coins) - suite.Require().NoError(err) - - dfd := ante.NewDeductFeeDecorator(suite.app.AccountKeeper, suite.app.BankKeeper, nil) - antehandler := sdk.ChainAnteDecorators(dfd) - - _, err = antehandler(suite.ctx, tx, false) - - suite.Require().NotNil(err, "Tx did not error when fee payer had insufficient funds") - - // Set account with sufficient funds - suite.app.AccountKeeper.SetAccount(suite.ctx, acc) - err = testutil.FundAccount(suite.app.BankKeeper, suite.ctx, addr1, sdk.NewCoins(sdk.NewCoin("atom", sdk.NewInt(200)))) - suite.Require().NoError(err) - - _, err = antehandler(suite.ctx, tx, false) - - suite.Require().Nil(err, "Tx errored after account has been set with sufficient funds") -} diff --git a/x/auth/ante/setup.go b/x/auth/ante/setup.go deleted file mode 100644 index 737cc295b980..000000000000 --- a/x/auth/ante/setup.go +++ /dev/null @@ -1,76 +0,0 @@ -package ante - -import ( - "fmt" - - sdk "github.com/cosmos/cosmos-sdk/types" - sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" - "github.com/cosmos/cosmos-sdk/x/auth/migrations/legacytx" -) - -var ( - _ GasTx = (*legacytx.StdTx)(nil) // assert StdTx implements GasTx -) - -// GasTx defines a Tx with a GetGas() method which is needed to use SetUpContextDecorator -type GasTx interface { - sdk.Tx - GetGas() uint64 -} - -// SetUpContextDecorator sets the GasMeter in the Context and wraps the next AnteHandler with a defer clause -// to recover from any downstream OutOfGas panics in the AnteHandler chain to return an error with information -// on gas provided and gas used. -// CONTRACT: Must be first decorator in the chain -// CONTRACT: Tx must implement GasTx interface -type SetUpContextDecorator struct{} - -func NewSetUpContextDecorator() SetUpContextDecorator { - return SetUpContextDecorator{} -} - -func (sud SetUpContextDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bool, next sdk.AnteHandler) (newCtx sdk.Context, err error) { - // all transactions must implement GasTx - gasTx, ok := tx.(GasTx) - if !ok { - // Set a gas meter with limit 0 as to prevent an infinite gas meter attack - // during runTx. - newCtx = SetGasMeter(simulate, ctx, 0) - return newCtx, sdkerrors.Wrap(sdkerrors.ErrTxDecode, "Tx must be GasTx") - } - - newCtx = SetGasMeter(simulate, ctx, gasTx.GetGas()) - - // Decorator will catch an OutOfGasPanic caused in the next antehandler - // AnteHandlers must have their own defer/recover in order for the BaseApp - // to know how much gas was used! This is because the GasMeter is created in - // the AnteHandler, but if it panics the context won't be set properly in - // runTx's recover call. - defer func() { - if r := recover(); r != nil { - switch rType := r.(type) { - case sdk.ErrorOutOfGas: - log := fmt.Sprintf( - "insufficient gas, gasOffered: %d, gasRequired: %d, code location: %v", - gasTx.GetGas(), newCtx.GasMeter().GasConsumed(), rType.Descriptor) - - err = sdkerrors.Wrap(sdkerrors.ErrOutOfGas, log) - default: - panic(r) - } - } - }() - - return next(newCtx, tx, simulate) -} - -// SetGasMeter returns a new context with a gas meter set from a given context. -func SetGasMeter(simulate bool, ctx sdk.Context, gasLimit uint64) sdk.Context { - // In various cases such as simulation and during the genesis block, we do not - // meter any gas utilization. - if simulate || ctx.BlockHeight() == 0 { - return ctx.WithGasMeter(sdk.NewInfiniteGasMeter()) - } - - return ctx.WithGasMeter(sdk.NewGasMeter(gasLimit)) -} diff --git a/x/auth/ante/testutil_test.go b/x/auth/ante/testutil_test.go deleted file mode 100644 index faf2e7cdf658..000000000000 --- a/x/auth/ante/testutil_test.go +++ /dev/null @@ -1,214 +0,0 @@ -package ante_test - -import ( - "errors" - "fmt" - "testing" - - minttypes "github.com/cosmos/cosmos-sdk/x/mint/types" - - "github.com/stretchr/testify/suite" - tmproto "github.com/tendermint/tendermint/proto/tendermint/types" - - "github.com/cosmos/cosmos-sdk/client" - "github.com/cosmos/cosmos-sdk/client/tx" - cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types" - "github.com/cosmos/cosmos-sdk/simapp" - "github.com/cosmos/cosmos-sdk/testutil/testdata" - sdk "github.com/cosmos/cosmos-sdk/types" - "github.com/cosmos/cosmos-sdk/types/tx/signing" - "github.com/cosmos/cosmos-sdk/x/auth/ante" - xauthsigning "github.com/cosmos/cosmos-sdk/x/auth/signing" - "github.com/cosmos/cosmos-sdk/x/auth/types" - authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" -) - -// TestAccount represents an account used in the tests in x/auth/ante. -type TestAccount struct { - acc types.AccountI - priv cryptotypes.PrivKey -} - -// AnteTestSuite is a test suite to be used with ante handler tests. -type AnteTestSuite struct { - suite.Suite - - app *simapp.SimApp - anteHandler sdk.AnteHandler - ctx sdk.Context - clientCtx client.Context - txBuilder client.TxBuilder -} - -// returns context and app with params set on account keeper -func createTestApp(t *testing.T, isCheckTx bool) (*simapp.SimApp, sdk.Context) { - app := simapp.Setup(t, isCheckTx) - ctx := app.BaseApp.NewContext(isCheckTx, tmproto.Header{}) - app.AccountKeeper.SetParams(ctx, authtypes.DefaultParams()) - - return app, ctx -} - -// SetupTest setups a new test, with new app, context, and anteHandler. -func (suite *AnteTestSuite) SetupTest(isCheckTx bool) { - suite.app, suite.ctx = createTestApp(suite.T(), isCheckTx) - suite.ctx = suite.ctx.WithBlockHeight(1) - - // Set up TxConfig. - encodingConfig := simapp.MakeTestEncodingConfig() - // We're using TestMsg encoding in some tests, so register it here. - encodingConfig.Amino.RegisterConcrete(&testdata.TestMsg{}, "testdata.TestMsg", nil) - testdata.RegisterInterfaces(encodingConfig.InterfaceRegistry) - - suite.clientCtx = client.Context{}. - WithTxConfig(encodingConfig.TxConfig) - - // We're not using ante.NewAnteHandler here because: - // - ante.NewAnteHandler doesn't have SetUpContextDecorator, as it has been - // moved to the gas TxMiddleware - // - whereas these tests have not been migrated to middlewares yet, so - // still need the SetUpContextDecorator. - // - // TODO: migrate all antehandler tests to middleware tests. - // https://github.com/cosmos/cosmos-sdk/issues/9585 - anteDecorators := []sdk.AnteDecorator{ - ante.NewSetUpContextDecorator(), - ante.NewRejectExtensionOptionsDecorator(), - ante.NewMempoolFeeDecorator(), - ante.NewValidateBasicDecorator(), - ante.NewTxTimeoutHeightDecorator(), - ante.NewValidateMemoDecorator(suite.app.AccountKeeper), - ante.NewConsumeGasForTxSizeDecorator(suite.app.AccountKeeper), - ante.NewDeductFeeDecorator(suite.app.AccountKeeper, suite.app.BankKeeper, suite.app.FeeGrantKeeper), - // SetPubKeyDecorator must be called before all signature verification decorators - ante.NewSetPubKeyDecorator(suite.app.AccountKeeper), - ante.NewValidateSigCountDecorator(suite.app.AccountKeeper), - ante.NewSigGasConsumeDecorator(suite.app.AccountKeeper, ante.DefaultSigVerificationGasConsumer), - ante.NewSigVerificationDecorator(suite.app.AccountKeeper, encodingConfig.TxConfig.SignModeHandler()), - ante.NewIncrementSequenceDecorator(suite.app.AccountKeeper), - } - - suite.anteHandler = sdk.ChainAnteDecorators(anteDecorators...) -} - -// CreateTestAccounts creates `numAccs` accounts, and return all relevant -// information about them including their private keys. -func (suite *AnteTestSuite) CreateTestAccounts(numAccs int) []TestAccount { - var accounts []TestAccount - - for i := 0; i < numAccs; i++ { - priv, _, addr := testdata.KeyTestPubAddr() - acc := suite.app.AccountKeeper.NewAccountWithAddress(suite.ctx, addr) - err := acc.SetAccountNumber(uint64(i)) - suite.Require().NoError(err) - suite.app.AccountKeeper.SetAccount(suite.ctx, acc) - someCoins := sdk.Coins{ - sdk.NewInt64Coin("atom", 10000000), - } - err = suite.app.BankKeeper.MintCoins(suite.ctx, minttypes.ModuleName, someCoins) - suite.Require().NoError(err) - - err = suite.app.BankKeeper.SendCoinsFromModuleToAccount(suite.ctx, minttypes.ModuleName, addr, someCoins) - suite.Require().NoError(err) - - accounts = append(accounts, TestAccount{acc, priv}) - } - - return accounts -} - -// CreateTestTx is a helper function to create a tx given multiple inputs. -func (suite *AnteTestSuite) CreateTestTx(privs []cryptotypes.PrivKey, accNums []uint64, accSeqs []uint64, chainID string) (xauthsigning.Tx, error) { - // First round: we gather all the signer infos. We use the "set empty - // signature" hack to do that. - var sigsV2 []signing.SignatureV2 - for i, priv := range privs { - sigV2 := signing.SignatureV2{ - PubKey: priv.PubKey(), - Data: &signing.SingleSignatureData{ - SignMode: suite.clientCtx.TxConfig.SignModeHandler().DefaultMode(), - Signature: nil, - }, - Sequence: accSeqs[i], - } - - sigsV2 = append(sigsV2, sigV2) - } - err := suite.txBuilder.SetSignatures(sigsV2...) - if err != nil { - return nil, err - } - - // Second round: all signer infos are set, so each signer can sign. - sigsV2 = []signing.SignatureV2{} - for i, priv := range privs { - signerData := xauthsigning.SignerData{ - ChainID: chainID, - AccountNumber: accNums[i], - Sequence: accSeqs[i], - } - sigV2, err := tx.SignWithPrivKey( - suite.clientCtx.TxConfig.SignModeHandler().DefaultMode(), signerData, - suite.txBuilder, priv, suite.clientCtx.TxConfig, accSeqs[i]) - if err != nil { - return nil, err - } - - sigsV2 = append(sigsV2, sigV2) - } - err = suite.txBuilder.SetSignatures(sigsV2...) - if err != nil { - return nil, err - } - - return suite.txBuilder.GetTx(), nil -} - -// TestCase represents a test case used in test tables. -type TestCase struct { - desc string - malleate func() - simulate bool - expPass bool - expErr error -} - -// CreateTestTx is a helper function to create a tx given multiple inputs. -func (suite *AnteTestSuite) RunTestCase(privs []cryptotypes.PrivKey, msgs []sdk.Msg, feeAmount sdk.Coins, gasLimit uint64, accNums, accSeqs []uint64, chainID string, tc TestCase) { - suite.Run(fmt.Sprintf("Case %s", tc.desc), func() { - suite.Require().NoError(suite.txBuilder.SetMsgs(msgs...)) - suite.txBuilder.SetFeeAmount(feeAmount) - suite.txBuilder.SetGasLimit(gasLimit) - - // Theoretically speaking, ante handler unit tests should only test - // ante handlers, but here we sometimes also test the tx creation - // process. - tx, txErr := suite.CreateTestTx(privs, accNums, accSeqs, chainID) - newCtx, anteErr := suite.anteHandler(suite.ctx, tx, tc.simulate) - - if tc.expPass { - suite.Require().NoError(txErr) - suite.Require().NoError(anteErr) - suite.Require().NotNil(newCtx) - - suite.ctx = newCtx - } else { - switch { - case txErr != nil: - suite.Require().Error(txErr) - suite.Require().True(errors.Is(txErr, tc.expErr)) - - case anteErr != nil: - suite.Require().Error(anteErr) - suite.Require().True(errors.Is(anteErr, tc.expErr)) - - default: - suite.Fail("expected one of txErr,anteErr to be an error") - } - } - }) -} - -func TestAnteTestSuite(t *testing.T) { - suite.Run(t, new(AnteTestSuite)) -} diff --git a/x/auth/middleware/basic.go b/x/auth/middleware/basic.go new file mode 100644 index 000000000000..04ea10bf416e --- /dev/null +++ b/x/auth/middleware/basic.go @@ -0,0 +1,358 @@ +package middleware + +import ( + "context" + + "github.com/cosmos/cosmos-sdk/codec/legacy" + "github.com/cosmos/cosmos-sdk/crypto/keys/multisig" + cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types" + sdk "github.com/cosmos/cosmos-sdk/types" + sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" + "github.com/cosmos/cosmos-sdk/types/tx" + "github.com/cosmos/cosmos-sdk/types/tx/signing" + "github.com/cosmos/cosmos-sdk/x/auth/migrations/legacytx" + authsigning "github.com/cosmos/cosmos-sdk/x/auth/signing" + abci "github.com/tendermint/tendermint/abci/types" +) + +type validateBasicTxHandler struct { + next tx.Handler +} + +// ValidateBasicMiddleware will call tx.ValidateBasic, msg.ValidateBasic(for each msg inside tx) +// and return any non-nil error. +// If ValidateBasic passes, middleware calls next middleware in chain. Note, +// validateBasicTxHandler will not get executed on ReCheckTx since it +// is not dependent on application state. +func ValidateBasicMiddleware(txh tx.Handler) tx.Handler { + return validateBasicTxHandler{ + next: txh, + } +} + +var _ tx.Handler = validateBasicTxHandler{} + +// validateBasicTxMsgs executes basic validator calls for messages. +func validateBasicTxMsgs(msgs []sdk.Msg) error { + if len(msgs) == 0 { + return sdkerrors.Wrap(sdkerrors.ErrInvalidRequest, "must contain at least one message") + } + + for _, msg := range msgs { + err := msg.ValidateBasic() + if err != nil { + return err + } + } + + return nil +} + +// CheckTx implements tx.Handler.CheckTx. +func (txh validateBasicTxHandler) CheckTx(ctx context.Context, tx sdk.Tx, req abci.RequestCheckTx) (abci.ResponseCheckTx, error) { + // no need to validate basic on recheck tx, call next middleware + if req.Type == abci.CheckTxType_Recheck { + return txh.next.CheckTx(ctx, tx, req) + } + + if err := validateBasicTxMsgs(tx.GetMsgs()); err != nil { + return abci.ResponseCheckTx{}, err + } + + if err := tx.ValidateBasic(); err != nil { + return abci.ResponseCheckTx{}, err + } + + return txh.next.CheckTx(ctx, tx, req) +} + +// DeliverTx implements tx.Handler.DeliverTx. +func (txh validateBasicTxHandler) DeliverTx(ctx context.Context, tx sdk.Tx, req abci.RequestDeliverTx) (abci.ResponseDeliverTx, error) { + if err := tx.ValidateBasic(); err != nil { + return abci.ResponseDeliverTx{}, err + } + + if err := validateBasicTxMsgs(tx.GetMsgs()); err != nil { + return abci.ResponseDeliverTx{}, err + } + + return txh.next.DeliverTx(ctx, tx, req) +} + +// SimulateTx implements tx.Handler.SimulateTx. +func (txh validateBasicTxHandler) SimulateTx(ctx context.Context, sdkTx sdk.Tx, req tx.RequestSimulateTx) (tx.ResponseSimulateTx, error) { + if err := sdkTx.ValidateBasic(); err != nil { + return tx.ResponseSimulateTx{}, err + } + + if err := validateBasicTxMsgs(sdkTx.GetMsgs()); err != nil { + return tx.ResponseSimulateTx{}, err + } + + return txh.next.SimulateTx(ctx, sdkTx, req) +} + +var _ tx.Handler = txTimeoutHeightTxHandler{} + +type txTimeoutHeightTxHandler struct { + next tx.Handler +} + +// TxTimeoutHeightMiddleware defines a middleware that checks for a +// tx height timeout. +func TxTimeoutHeightMiddleware(txh tx.Handler) tx.Handler { + return txTimeoutHeightTxHandler{ + next: txh, + } +} + +func checkTimeout(ctx context.Context, tx sdk.Tx) error { + sdkCtx := sdk.UnwrapSDKContext(ctx) + timeoutTx, ok := tx.(sdk.TxWithTimeoutHeight) + if !ok { + return sdkerrors.Wrap(sdkerrors.ErrTxDecode, "expected tx to implement TxWithTimeoutHeight") + } + + timeoutHeight := timeoutTx.GetTimeoutHeight() + if timeoutHeight > 0 && uint64(sdkCtx.BlockHeight()) > timeoutHeight { + return sdkerrors.Wrapf( + sdkerrors.ErrTxTimeoutHeight, "block height: %d, timeout height: %d", sdkCtx.BlockHeight(), timeoutHeight, + ) + } + + return nil +} + +// CheckTx implements tx.Handler.CheckTx. +func (txh txTimeoutHeightTxHandler) CheckTx(ctx context.Context, tx sdk.Tx, req abci.RequestCheckTx) (abci.ResponseCheckTx, error) { + if err := checkTimeout(ctx, tx); err != nil { + return abci.ResponseCheckTx{}, err + } + + return txh.next.CheckTx(ctx, tx, req) +} + +// DeliverTx implements tx.Handler.DeliverTx. +func (txh txTimeoutHeightTxHandler) DeliverTx(ctx context.Context, tx sdk.Tx, req abci.RequestDeliverTx) (abci.ResponseDeliverTx, error) { + if err := checkTimeout(ctx, tx); err != nil { + return abci.ResponseDeliverTx{}, err + } + + return txh.next.DeliverTx(ctx, tx, req) +} + +// SimulateTx implements tx.Handler.SimulateTx. +func (txh txTimeoutHeightTxHandler) SimulateTx(ctx context.Context, sdkTx sdk.Tx, req tx.RequestSimulateTx) (tx.ResponseSimulateTx, error) { + if err := checkTimeout(ctx, sdkTx); err != nil { + return tx.ResponseSimulateTx{}, err + } + + return txh.next.SimulateTx(ctx, sdkTx, req) +} + +type validateMemoTxHandler struct { + ak AccountKeeper + next tx.Handler +} + +// ValidateMemoMiddleware will validate memo given the parameters passed in +// If memo is too large middleware returns with error, otherwise call next middleware +// CONTRACT: Tx must implement TxWithMemo interface +func ValidateMemoMiddleware(ak AccountKeeper) tx.Middleware { + return func(txHandler tx.Handler) tx.Handler { + return validateMemoTxHandler{ + ak: ak, + next: txHandler, + } + } +} + +var _ tx.Handler = validateMemoTxHandler{} + +func (vmm validateMemoTxHandler) checkForValidMemo(ctx context.Context, tx sdk.Tx) error { + sdkCtx := sdk.UnwrapSDKContext(ctx) + memoTx, ok := tx.(sdk.TxWithMemo) + if !ok { + return sdkerrors.Wrap(sdkerrors.ErrTxDecode, "invalid transaction type") + } + + params := vmm.ak.GetParams(sdkCtx) + + memoLength := len(memoTx.GetMemo()) + if uint64(memoLength) > params.MaxMemoCharacters { + return sdkerrors.Wrapf(sdkerrors.ErrMemoTooLarge, + "maximum number of characters is %d but received %d characters", + params.MaxMemoCharacters, memoLength, + ) + } + + return nil +} + +// CheckTx implements tx.Handler.CheckTx method. +func (vmm validateMemoTxHandler) CheckTx(ctx context.Context, tx sdk.Tx, req abci.RequestCheckTx) (abci.ResponseCheckTx, error) { + if err := vmm.checkForValidMemo(ctx, tx); err != nil { + return abci.ResponseCheckTx{}, err + } + + return vmm.next.CheckTx(ctx, tx, req) +} + +// DeliverTx implements tx.Handler.DeliverTx method. +func (vmm validateMemoTxHandler) DeliverTx(ctx context.Context, tx sdk.Tx, req abci.RequestDeliverTx) (abci.ResponseDeliverTx, error) { + if err := vmm.checkForValidMemo(ctx, tx); err != nil { + return abci.ResponseDeliverTx{}, err + } + + return vmm.next.DeliverTx(ctx, tx, req) +} + +// SimulateTx implements tx.Handler.SimulateTx method. +func (vmm validateMemoTxHandler) SimulateTx(ctx context.Context, sdkTx sdk.Tx, req tx.RequestSimulateTx) (tx.ResponseSimulateTx, error) { + if err := vmm.checkForValidMemo(ctx, sdkTx); err != nil { + return tx.ResponseSimulateTx{}, err + } + + return vmm.next.SimulateTx(ctx, sdkTx, req) +} + +var _ tx.Handler = consumeTxSizeGasTxHandler{} + +type consumeTxSizeGasTxHandler struct { + ak AccountKeeper + next tx.Handler +} + +// ConsumeTxSizeGasMiddleware will take in parameters and consume gas proportional +// to the size of tx before calling next middleware. Note, the gas costs will be +// slightly over estimated due to the fact that any given signing account may need +// to be retrieved from state. +// +// CONTRACT: If simulate=true, then signatures must either be completely filled +// in or empty. +// CONTRACT: To use this middleware, signatures of transaction must be represented +// as legacytx.StdSignature otherwise simulate mode will incorrectly estimate gas cost. +func ConsumeTxSizeGasMiddleware(ak AccountKeeper) tx.Middleware { + return func(txHandler tx.Handler) tx.Handler { + return consumeTxSizeGasTxHandler{ + ak: ak, + next: txHandler, + } + } +} + +func (cgts consumeTxSizeGasTxHandler) simulateSigGasCost(ctx context.Context, tx sdk.Tx) error { + sdkCtx := sdk.UnwrapSDKContext(ctx) + params := cgts.ak.GetParams(sdkCtx) + + sigTx, ok := tx.(authsigning.SigVerifiableTx) + if !ok { + return sdkerrors.Wrap(sdkerrors.ErrTxDecode, "invalid tx type") + } + + // in simulate mode, each element should be a nil signature + sigs, err := sigTx.GetSignaturesV2() + if err != nil { + return err + } + n := len(sigs) + + for i, signer := range sigTx.GetSigners() { + // if signature is already filled in, no need to simulate gas cost + if i < n && !isIncompleteSignature(sigs[i].Data) { + continue + } + + var pubkey cryptotypes.PubKey + + acc := cgts.ak.GetAccount(sdkCtx, signer) + + // use placeholder simSecp256k1Pubkey if sig is nil + if acc == nil || acc.GetPubKey() == nil { + pubkey = simSecp256k1Pubkey + } else { + pubkey = acc.GetPubKey() + } + + // use stdsignature to mock the size of a full signature + simSig := legacytx.StdSignature{ //nolint:staticcheck // this will be removed when proto is ready + Signature: simSecp256k1Sig[:], + PubKey: pubkey, + } + + sigBz := legacy.Cdc.MustMarshal(simSig) + cost := sdk.Gas(len(sigBz) + 6) + + // If the pubkey is a multi-signature pubkey, then we estimate for the maximum + // number of signers. + if _, ok := pubkey.(*multisig.LegacyAminoPubKey); ok { + cost *= params.TxSigLimit + } + + sdkCtx.GasMeter().ConsumeGas(params.TxSizeCostPerByte*cost, "txSize") + } + + return nil +} + +func (cgts consumeTxSizeGasTxHandler) consumeTxSizeGas(ctx context.Context, tx sdk.Tx, txBytes []byte, simulate bool) error { + sdkCtx := sdk.UnwrapSDKContext(ctx) + params := cgts.ak.GetParams(sdkCtx) + sdkCtx.GasMeter().ConsumeGas(params.TxSizeCostPerByte*sdk.Gas(len(txBytes)), "txSize") + + return nil +} + +// CheckTx implements tx.Handler.CheckTx. +func (cgts consumeTxSizeGasTxHandler) CheckTx(ctx context.Context, tx sdk.Tx, req abci.RequestCheckTx) (abci.ResponseCheckTx, error) { + if err := cgts.consumeTxSizeGas(ctx, tx, req.GetTx(), false); err != nil { + return abci.ResponseCheckTx{}, err + } + + return cgts.next.CheckTx(ctx, tx, req) +} + +// DeliverTx implements tx.Handler.DeliverTx. +func (cgts consumeTxSizeGasTxHandler) DeliverTx(ctx context.Context, tx sdk.Tx, req abci.RequestDeliverTx) (abci.ResponseDeliverTx, error) { + if err := cgts.consumeTxSizeGas(ctx, tx, req.GetTx(), false); err != nil { + return abci.ResponseDeliverTx{}, err + } + + return cgts.next.DeliverTx(ctx, tx, req) +} + +// SimulateTx implements tx.Handler.SimulateTx. +func (cgts consumeTxSizeGasTxHandler) SimulateTx(ctx context.Context, sdkTx sdk.Tx, req tx.RequestSimulateTx) (tx.ResponseSimulateTx, error) { + if err := cgts.consumeTxSizeGas(ctx, sdkTx, req.TxBytes, true); err != nil { + return tx.ResponseSimulateTx{}, err + } + + if err := cgts.simulateSigGasCost(ctx, sdkTx); err != nil { + return tx.ResponseSimulateTx{}, err + } + + return cgts.next.SimulateTx(ctx, sdkTx, req) +} + +// isIncompleteSignature tests whether SignatureData is fully filled in for simulation purposes +func isIncompleteSignature(data signing.SignatureData) bool { + if data == nil { + return true + } + + switch data := data.(type) { + case *signing.SingleSignatureData: + return len(data.Signature) == 0 + case *signing.MultiSignatureData: + if len(data.Signatures) == 0 { + return true + } + for _, s := range data.Signatures { + if isIncompleteSignature(s) { + return true + } + } + } + + return false +} diff --git a/x/auth/middleware/basic_test.go b/x/auth/middleware/basic_test.go new file mode 100644 index 000000000000..8e1ad1db6a30 --- /dev/null +++ b/x/auth/middleware/basic_test.go @@ -0,0 +1,222 @@ +package middleware_test + +import ( + "strings" + + cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types" + "github.com/cosmos/cosmos-sdk/crypto/types/multisig" + "github.com/cosmos/cosmos-sdk/testutil/testdata" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/types/tx" + "github.com/cosmos/cosmos-sdk/types/tx/signing" + "github.com/cosmos/cosmos-sdk/x/auth/middleware" + "github.com/tendermint/tendermint/abci/types" +) + +func (s *MWTestSuite) TestValidateBasic() { + ctx := s.SetupTest(true) // setup + txBuilder := s.clientCtx.TxConfig.NewTxBuilder() + + txHandler := middleware.ComposeMiddlewares(noopTxHandler{}, middleware.ValidateBasicMiddleware) + + // keys and addresses + priv1, _, addr1 := testdata.KeyTestPubAddr() + + // msg and signatures + msg := testdata.NewTestMsg(addr1) + feeAmount := testdata.NewTestFeeAmount() + gasLimit := testdata.NewTestGasLimit() + s.Require().NoError(txBuilder.SetMsgs(msg)) + txBuilder.SetFeeAmount(feeAmount) + txBuilder.SetGasLimit(gasLimit) + + privs, accNums, accSeqs := []cryptotypes.PrivKey{}, []uint64{}, []uint64{} + invalidTx, _, err := s.createTestTx(txBuilder, privs, accNums, accSeqs, ctx.ChainID()) + s.Require().NoError(err) + + _, err = txHandler.DeliverTx(sdk.WrapSDKContext(ctx), invalidTx, types.RequestDeliverTx{}) + s.Require().NotNil(err, "Did not error on invalid tx") + + privs, accNums, accSeqs = []cryptotypes.PrivKey{priv1}, []uint64{0}, []uint64{0} + validTx, _, err := s.createTestTx(txBuilder, privs, accNums, accSeqs, ctx.ChainID()) + s.Require().NoError(err) + + _, err = txHandler.DeliverTx(sdk.WrapSDKContext(ctx), validTx, types.RequestDeliverTx{}) + s.Require().Nil(err, "ValidateBasicMiddleware returned error on valid tx. err: %v", err) + + // test middleware skips on recheck + ctx = ctx.WithIsReCheckTx(true) + + // middleware should skip processing invalidTx on recheck and thus return nil-error + _, err = txHandler.DeliverTx(sdk.WrapSDKContext(ctx), invalidTx, types.RequestDeliverTx{}) + s.Require().Nil(err, "ValidateBasicMiddleware ran on ReCheck") +} + +func (s *MWTestSuite) TestValidateMemo() { + ctx := s.SetupTest(true) // setup + txBuilder := s.clientCtx.TxConfig.NewTxBuilder() + txHandler := middleware.ComposeMiddlewares(noopTxHandler{}, middleware.ValidateMemoMiddleware(s.app.AccountKeeper)) + + // keys and addresses + priv1, _, addr1 := testdata.KeyTestPubAddr() + + // msg and signatures + msg := testdata.NewTestMsg(addr1) + feeAmount := testdata.NewTestFeeAmount() + gasLimit := testdata.NewTestGasLimit() + s.Require().NoError(txBuilder.SetMsgs(msg)) + txBuilder.SetFeeAmount(feeAmount) + txBuilder.SetGasLimit(gasLimit) + + privs, accNums, accSeqs := []cryptotypes.PrivKey{priv1}, []uint64{0}, []uint64{0} + txBuilder.SetMemo(strings.Repeat("01234567890", 500)) + invalidTx, _, err := s.createTestTx(txBuilder, privs, accNums, accSeqs, ctx.ChainID()) + s.Require().NoError(err) + + // require that long memos get rejected + _, err = txHandler.DeliverTx(sdk.WrapSDKContext(ctx), invalidTx, types.RequestDeliverTx{}) + + s.Require().NotNil(err, "Did not error on tx with high memo") + + txBuilder.SetMemo(strings.Repeat("01234567890", 10)) + validTx, _, err := s.createTestTx(txBuilder, privs, accNums, accSeqs, ctx.ChainID()) + s.Require().NoError(err) + + // require small memos pass ValidateMemo middleware + _, err = txHandler.DeliverTx(sdk.WrapSDKContext(ctx), validTx, types.RequestDeliverTx{}) + s.Require().Nil(err, "ValidateBasicMiddleware returned error on valid tx. err: %v", err) +} + +func (s *MWTestSuite) TestConsumeGasForTxSize() { + ctx := s.SetupTest(true) // setup + txBuilder := s.clientCtx.TxConfig.NewTxBuilder() + + txHandler := middleware.ComposeMiddlewares(noopTxHandler{}, middleware.ConsumeTxSizeGasMiddleware(s.app.AccountKeeper)) + + // keys and addresses + priv1, _, addr1 := testdata.KeyTestPubAddr() + + // msg and signatures + msg := testdata.NewTestMsg(addr1) + feeAmount := testdata.NewTestFeeAmount() + gasLimit := testdata.NewTestGasLimit() + + testCases := []struct { + name string + sigV2 signing.SignatureV2 + }{ + {"SingleSignatureData", signing.SignatureV2{PubKey: priv1.PubKey()}}, + {"MultiSignatureData", signing.SignatureV2{PubKey: priv1.PubKey(), Data: multisig.NewMultisig(2)}}, + } + + for _, tc := range testCases { + s.Run(tc.name, func() { + txBuilder = s.clientCtx.TxConfig.NewTxBuilder() + s.Require().NoError(txBuilder.SetMsgs(msg)) + txBuilder.SetFeeAmount(feeAmount) + txBuilder.SetGasLimit(gasLimit) + txBuilder.SetMemo(strings.Repeat("01234567890", 10)) + + privs, accNums, accSeqs := []cryptotypes.PrivKey{priv1}, []uint64{0}, []uint64{0} + testTx, _, err := s.createTestTx(txBuilder, privs, accNums, accSeqs, ctx.ChainID()) + s.Require().NoError(err) + + txBytes, err := s.clientCtx.TxConfig.TxJSONEncoder()(testTx) + s.Require().Nil(err, "Cannot marshal tx: %v", err) + + params := s.app.AccountKeeper.GetParams(ctx) + expectedGas := sdk.Gas(len(txBytes)) * params.TxSizeCostPerByte + + // Set ctx with TxBytes manually + ctx = ctx.WithTxBytes(txBytes) + + // track how much gas is necessary to retrieve parameters + beforeGas := ctx.GasMeter().GasConsumed() + s.app.AccountKeeper.GetParams(ctx) + afterGas := ctx.GasMeter().GasConsumed() + expectedGas += afterGas - beforeGas + + beforeGas = ctx.GasMeter().GasConsumed() + _, err = txHandler.DeliverTx(sdk.WrapSDKContext(ctx), testTx, types.RequestDeliverTx{Tx: txBytes}) + + s.Require().Nil(err, "ConsumeTxSizeGasMiddleware returned error: %v", err) + + // require that middleware consumes expected amount of gas + consumedGas := ctx.GasMeter().GasConsumed() - beforeGas + s.Require().Equal(expectedGas, consumedGas, "Middleware did not consume the correct amount of gas") + + // simulation must not underestimate gas of this middleware even with nil signatures + txBuilder, err := s.clientCtx.TxConfig.WrapTxBuilder(testTx) + s.Require().NoError(err) + s.Require().NoError(txBuilder.SetSignatures(tc.sigV2)) + testTx = txBuilder.GetTx() + + simTxBytes, err := s.clientCtx.TxConfig.TxJSONEncoder()(testTx) + s.Require().Nil(err, "Cannot marshal tx: %v", err) + // require that simulated tx is smaller than tx with signatures + s.Require().True(len(simTxBytes) < len(txBytes), "simulated tx still has signatures") + + // Set s.ctx with smaller simulated TxBytes manually + ctx = ctx.WithTxBytes(simTxBytes) + + beforeSimGas := ctx.GasMeter().GasConsumed() + + // run txhandler in simulate mode + _, err = txHandler.SimulateTx(sdk.WrapSDKContext(ctx), testTx, tx.RequestSimulateTx{TxBytes: simTxBytes}) + consumedSimGas := ctx.GasMeter().GasConsumed() - beforeSimGas + + // require that txhandler passes and does not underestimate middleware cost + s.Require().Nil(err, "ConsumeTxSizeGasMiddleware returned error: %v", err) + s.Require().True(consumedSimGas >= expectedGas, "Simulate mode underestimates gas on Middleware. Simulated cost: %d, expected cost: %d", consumedSimGas, expectedGas) + }) + } +} + +func (s *MWTestSuite) TestTxHeightTimeoutMiddleware() { + ctx := s.SetupTest(true) + + txHandler := middleware.ComposeMiddlewares(noopTxHandler{}, middleware.TxTimeoutHeightMiddleware) + + // keys and addresses + priv1, _, addr1 := testdata.KeyTestPubAddr() + + // msg and signatures + msg := testdata.NewTestMsg(addr1) + feeAmount := testdata.NewTestFeeAmount() + gasLimit := testdata.NewTestGasLimit() + + testCases := []struct { + name string + timeout uint64 + height int64 + expectErr bool + }{ + {"default value", 0, 10, false}, + {"no timeout (greater height)", 15, 10, false}, + {"no timeout (same height)", 10, 10, false}, + {"timeout (smaller height)", 9, 10, true}, + } + + for _, tc := range testCases { + tc := tc + + s.Run(tc.name, func() { + txBuilder := s.clientCtx.TxConfig.NewTxBuilder() + + s.Require().NoError(txBuilder.SetMsgs(msg)) + + txBuilder.SetFeeAmount(feeAmount) + txBuilder.SetGasLimit(gasLimit) + txBuilder.SetMemo(strings.Repeat("01234567890", 10)) + txBuilder.SetTimeoutHeight(tc.timeout) + + privs, accNums, accSeqs := []cryptotypes.PrivKey{priv1}, []uint64{0}, []uint64{0} + testTx, _, err := s.createTestTx(txBuilder, privs, accNums, accSeqs, ctx.ChainID()) + s.Require().NoError(err) + + ctx := ctx.WithBlockHeight(tc.height) + _, err = txHandler.SimulateTx(sdk.WrapSDKContext(ctx), testTx, tx.RequestSimulateTx{}) + s.Require().Equal(tc.expectErr, err != nil, err) + }) + } +} diff --git a/x/auth/ante/expected_keepers.go b/x/auth/middleware/expected_keepers.go similarity index 85% rename from x/auth/ante/expected_keepers.go rename to x/auth/middleware/expected_keepers.go index 4dbbbd21c713..33bb6339c1f3 100644 --- a/x/auth/ante/expected_keepers.go +++ b/x/auth/middleware/expected_keepers.go @@ -1,4 +1,4 @@ -package ante +package middleware import ( sdk "github.com/cosmos/cosmos-sdk/types" @@ -6,7 +6,7 @@ import ( ) // AccountKeeper defines the contract needed for AccountKeeper related APIs. -// Interface provides support to use non-sdk AccountKeeper for AnteHandler's decorators. +// Interface provides support to use non-sdk AccountKeeper for TxHandler's middlewares. type AccountKeeper interface { GetParams(ctx sdk.Context) (params types.Params) GetAccount(ctx sdk.Context, addr sdk.AccAddress) types.AccountI diff --git a/x/auth/middleware/ext.go b/x/auth/middleware/ext.go new file mode 100644 index 000000000000..3fec1f674a47 --- /dev/null +++ b/x/auth/middleware/ext.go @@ -0,0 +1,71 @@ +package middleware + +import ( + "context" + + abci "github.com/tendermint/tendermint/abci/types" + + codectypes "github.com/cosmos/cosmos-sdk/codec/types" + sdk "github.com/cosmos/cosmos-sdk/types" + sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" + "github.com/cosmos/cosmos-sdk/types/tx" +) + +type HasExtensionOptionsTx interface { + GetExtensionOptions() []*codectypes.Any + GetNonCriticalExtensionOptions() []*codectypes.Any +} + +type rejectExtensionOptionsTxHandler struct { + next tx.Handler +} + +// RejectExtensionOptionsMiddleware creates a new rejectExtensionOptionsMiddleware. +// rejectExtensionOptionsMiddleware is a middleware that rejects all extension +// options which can optionally be included in protobuf transactions. Users that +// need extension options should create a custom middleware chain that handles +// needed extension options properly and rejects unknown ones. +func RejectExtensionOptionsMiddleware(txh tx.Handler) tx.Handler { + return rejectExtensionOptionsTxHandler{ + next: txh, + } +} + +var _ tx.Handler = rejectExtensionOptionsTxHandler{} + +func checkExtOpts(tx sdk.Tx) error { + if hasExtOptsTx, ok := tx.(HasExtensionOptionsTx); ok { + if len(hasExtOptsTx.GetExtensionOptions()) != 0 { + return sdkerrors.ErrUnknownExtensionOptions + } + } + + return nil +} + +// CheckTx implements tx.Handler.CheckTx. +func (txh rejectExtensionOptionsTxHandler) CheckTx(ctx context.Context, tx sdk.Tx, req abci.RequestCheckTx) (abci.ResponseCheckTx, error) { + if err := checkExtOpts(tx); err != nil { + return abci.ResponseCheckTx{}, err + } + + return txh.next.CheckTx(ctx, tx, req) +} + +// DeliverTx implements tx.Handler.DeliverTx. +func (txh rejectExtensionOptionsTxHandler) DeliverTx(ctx context.Context, tx sdk.Tx, req abci.RequestDeliverTx) (abci.ResponseDeliverTx, error) { + if err := checkExtOpts(tx); err != nil { + return abci.ResponseDeliverTx{}, err + } + + return txh.next.DeliverTx(ctx, tx, req) +} + +// SimulateTx implements tx.Handler.SimulateTx method. +func (txh rejectExtensionOptionsTxHandler) SimulateTx(ctx context.Context, sdkTx sdk.Tx, req tx.RequestSimulateTx) (tx.ResponseSimulateTx, error) { + if err := checkExtOpts(sdkTx); err != nil { + return tx.ResponseSimulateTx{}, err + } + + return txh.next.SimulateTx(ctx, sdkTx, req) +} diff --git a/x/auth/middleware/ext_test.go b/x/auth/middleware/ext_test.go new file mode 100644 index 000000000000..a2fd323e3bb8 --- /dev/null +++ b/x/auth/middleware/ext_test.go @@ -0,0 +1,36 @@ +package middleware_test + +import ( + "github.com/cosmos/cosmos-sdk/codec/types" + "github.com/cosmos/cosmos-sdk/testutil/testdata" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/x/auth/middleware" + "github.com/cosmos/cosmos-sdk/x/auth/tx" + abci "github.com/tendermint/tendermint/abci/types" +) + +func (s *MWTestSuite) TestRejectExtensionOptionsMiddleware() { + ctx := s.SetupTest(true) // setup + txBuilder := s.clientCtx.TxConfig.NewTxBuilder() + + txHandler := middleware.ComposeMiddlewares(noopTxHandler{}, middleware.RejectExtensionOptionsMiddleware) + + // no extension options should not trigger an error + theTx := txBuilder.GetTx() + _, err := txHandler.CheckTx(sdk.WrapSDKContext(ctx), theTx, abci.RequestCheckTx{}) + s.Require().NoError(err) + + extOptsTxBldr, ok := txBuilder.(tx.ExtensionOptionsTxBuilder) + if !ok { + // if we can't set extension options, this middleware doesn't apply and we're done + return + } + + // setting any extension option should cause an error + any, err := types.NewAnyWithValue(testdata.NewTestMsg()) + s.Require().NoError(err) + extOptsTxBldr.SetExtensionOptions(any) + theTx = txBuilder.GetTx() + _, err = txHandler.CheckTx(sdk.WrapSDKContext(ctx), theTx, abci.RequestCheckTx{}) + s.Require().EqualError(err, "unknown extension options") +} diff --git a/x/auth/middleware/fee.go b/x/auth/middleware/fee.go new file mode 100644 index 000000000000..7285d530cfb2 --- /dev/null +++ b/x/auth/middleware/fee.go @@ -0,0 +1,194 @@ +package middleware + +import ( + "context" + "fmt" + + sdk "github.com/cosmos/cosmos-sdk/types" + sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" + + "github.com/cosmos/cosmos-sdk/types/tx" + "github.com/cosmos/cosmos-sdk/x/auth/types" + abci "github.com/tendermint/tendermint/abci/types" +) + +var _ tx.Handler = mempoolFeeTxHandler{} + +type mempoolFeeTxHandler struct { + next tx.Handler +} + +// MempoolFeeMiddleware will check if the transaction's fee is at least as large +// as the local validator's minimum gasFee (defined in validator config). +// If fee is too low, middleware returns error and tx is rejected from mempool. +// Note this only applies when ctx.CheckTx = true +// If fee is high enough or not CheckTx, then call next middleware +// CONTRACT: Tx must implement FeeTx to use MempoolFeeMiddleware +func MempoolFeeMiddleware(txh tx.Handler) tx.Handler { + return mempoolFeeTxHandler{ + next: txh, + } +} + +// CheckTx implements tx.Handler.CheckTx. +func (txh mempoolFeeTxHandler) CheckTx(ctx context.Context, tx sdk.Tx, req abci.RequestCheckTx) (abci.ResponseCheckTx, error) { + sdkCtx := sdk.UnwrapSDKContext(ctx) + + feeTx, ok := tx.(sdk.FeeTx) + if !ok { + return abci.ResponseCheckTx{}, sdkerrors.Wrap(sdkerrors.ErrTxDecode, "Tx must be a FeeTx") + } + + feeCoins := feeTx.GetFee() + gas := feeTx.GetGas() + + // Ensure that the provided fees meet a minimum threshold for the validator, + // if this is a CheckTx. This is only for local mempool purposes, and thus + // is only ran on check tx. + minGasPrices := sdkCtx.MinGasPrices() + if !minGasPrices.IsZero() { + requiredFees := make(sdk.Coins, len(minGasPrices)) + + // Determine the required fees by multiplying each required minimum gas + // price by the gas limit, where fee = ceil(minGasPrice * gasLimit). + glDec := sdk.NewDec(int64(gas)) + for i, gp := range minGasPrices { + fee := gp.Amount.Mul(glDec) + requiredFees[i] = sdk.NewCoin(gp.Denom, fee.Ceil().RoundInt()) + } + + if !feeCoins.IsAnyGTE(requiredFees) { + return abci.ResponseCheckTx{}, sdkerrors.Wrapf(sdkerrors.ErrInsufficientFee, "insufficient fees; got: %s required: %s", feeCoins, requiredFees) + } + } + + return txh.next.CheckTx(ctx, tx, req) +} + +// DeliverTx implements tx.Handler.DeliverTx. +func (txh mempoolFeeTxHandler) DeliverTx(ctx context.Context, tx sdk.Tx, req abci.RequestDeliverTx) (abci.ResponseDeliverTx, error) { + return txh.next.DeliverTx(ctx, tx, req) +} + +// SimulateTx implements tx.Handler.SimulateTx. +func (txh mempoolFeeTxHandler) SimulateTx(ctx context.Context, tx sdk.Tx, req tx.RequestSimulateTx) (tx.ResponseSimulateTx, error) { + return txh.next.SimulateTx(ctx, tx, req) +} + +var _ tx.Handler = deductFeeTxHandler{} + +type deductFeeTxHandler struct { + accountKeeper AccountKeeper + bankKeeper types.BankKeeper + feegrantKeeper FeegrantKeeper + next tx.Handler +} + +// DeductFeeMiddleware deducts fees from the first signer of the tx +// If the first signer does not have the funds to pay for the fees, return with InsufficientFunds error +// Call next middleware if fees successfully deducted +// CONTRACT: Tx must implement FeeTx interface to use deductFeeTxHandler +func DeductFeeMiddleware(ak AccountKeeper, bk types.BankKeeper, fk FeegrantKeeper) tx.Middleware { + return func(txh tx.Handler) tx.Handler { + return deductFeeTxHandler{ + accountKeeper: ak, + bankKeeper: bk, + feegrantKeeper: fk, + next: txh, + } + } +} + +func (dfd deductFeeTxHandler) checkDeductFee(ctx context.Context, tx sdk.Tx) error { + sdkCtx := sdk.UnwrapSDKContext(ctx) + feeTx, ok := tx.(sdk.FeeTx) + if !ok { + return sdkerrors.Wrap(sdkerrors.ErrTxDecode, "Tx must be a FeeTx") + } + + if addr := dfd.accountKeeper.GetModuleAddress(types.FeeCollectorName); addr == nil { + panic(fmt.Sprintf("%s module account has not been set", types.FeeCollectorName)) + } + + fee := feeTx.GetFee() + feePayer := feeTx.FeePayer() + feeGranter := feeTx.FeeGranter() + + deductFeesFrom := feePayer + + // if feegranter set deduct fee from feegranter account. + // this works with only when feegrant enabled. + if feeGranter != nil { + if dfd.feegrantKeeper == nil { + return sdkerrors.Wrap(sdkerrors.ErrInvalidRequest, "fee grants are not enabled") + } else if !feeGranter.Equals(feePayer) { + err := dfd.feegrantKeeper.UseGrantedFees(sdkCtx, feeGranter, feePayer, fee, tx.GetMsgs()) + + if err != nil { + return sdkerrors.Wrapf(err, "%s not allowed to pay fees from %s", feeGranter, feePayer) + } + } + + deductFeesFrom = feeGranter + } + + deductFeesFromAcc := dfd.accountKeeper.GetAccount(sdkCtx, deductFeesFrom) + if deductFeesFromAcc == nil { + return sdkerrors.Wrapf(sdkerrors.ErrUnknownAddress, "fee payer address: %s does not exist", deductFeesFrom) + } + + // deduct the fees + if !feeTx.GetFee().IsZero() { + err := DeductFees(dfd.bankKeeper, sdkCtx, deductFeesFromAcc, feeTx.GetFee()) + if err != nil { + return err + } + } + + events := sdk.Events{sdk.NewEvent(sdk.EventTypeTx, + sdk.NewAttribute(sdk.AttributeKeyFee, feeTx.GetFee().String()), + )} + sdkCtx.EventManager().EmitEvents(events) + + return nil +} + +// CheckTx implements tx.Handler.CheckTx. +func (dfd deductFeeTxHandler) CheckTx(ctx context.Context, tx sdk.Tx, req abci.RequestCheckTx) (abci.ResponseCheckTx, error) { + if err := dfd.checkDeductFee(ctx, tx); err != nil { + return abci.ResponseCheckTx{}, err + } + + return dfd.next.CheckTx(ctx, tx, req) +} + +// DeliverTx implements tx.Handler.DeliverTx. +func (dfd deductFeeTxHandler) DeliverTx(ctx context.Context, tx sdk.Tx, req abci.RequestDeliverTx) (abci.ResponseDeliverTx, error) { + if err := dfd.checkDeductFee(ctx, tx); err != nil { + return abci.ResponseDeliverTx{}, err + } + + return dfd.next.DeliverTx(ctx, tx, req) +} + +func (dfd deductFeeTxHandler) SimulateTx(ctx context.Context, sdkTx sdk.Tx, req tx.RequestSimulateTx) (tx.ResponseSimulateTx, error) { + if err := dfd.checkDeductFee(ctx, sdkTx); err != nil { + return tx.ResponseSimulateTx{}, err + } + + return dfd.next.SimulateTx(ctx, sdkTx, req) +} + +// DeductFees deducts fees from the given account. +func DeductFees(bankKeeper types.BankKeeper, ctx sdk.Context, acc types.AccountI, fees sdk.Coins) error { + if !fees.IsValid() { + return sdkerrors.Wrapf(sdkerrors.ErrInsufficientFee, "invalid fee amount: %s", fees) + } + + err := bankKeeper.SendCoinsFromAccountToModule(ctx, acc.GetAddress(), types.FeeCollectorName, fees) + if err != nil { + return sdkerrors.Wrapf(sdkerrors.ErrInsufficientFunds, err.Error()) + } + + return nil +} diff --git a/x/auth/middleware/fee_test.go b/x/auth/middleware/fee_test.go new file mode 100644 index 000000000000..219d790d0d61 --- /dev/null +++ b/x/auth/middleware/fee_test.go @@ -0,0 +1,99 @@ +package middleware_test + +import ( + cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types" + "github.com/cosmos/cosmos-sdk/testutil/testdata" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/x/auth/middleware" + "github.com/cosmos/cosmos-sdk/x/bank/testutil" + abci "github.com/tendermint/tendermint/abci/types" +) + +func (s *MWTestSuite) TestEnsureMempoolFees() { + ctx := s.SetupTest(true) // setup + txBuilder := s.clientCtx.TxConfig.NewTxBuilder() + + txHandler := middleware.ComposeMiddlewares(noopTxHandler{}, middleware.MempoolFeeMiddleware) + + // keys and addresses + priv1, _, addr1 := testdata.KeyTestPubAddr() + + // msg and signatures + msg := testdata.NewTestMsg(addr1) + feeAmount := testdata.NewTestFeeAmount() + gasLimit := testdata.NewTestGasLimit() + s.Require().NoError(txBuilder.SetMsgs(msg)) + txBuilder.SetFeeAmount(feeAmount) + txBuilder.SetGasLimit(gasLimit) + + privs, accNums, accSeqs := []cryptotypes.PrivKey{priv1}, []uint64{0}, []uint64{0} + tx, _, err := s.createTestTx(txBuilder, privs, accNums, accSeqs, ctx.ChainID()) + s.Require().NoError(err) + + // Set high gas price so standard test fee fails + atomPrice := sdk.NewDecCoinFromDec("atom", sdk.NewDec(200).Quo(sdk.NewDec(100000))) + highGasPrice := []sdk.DecCoin{atomPrice} + ctx = ctx.WithMinGasPrices(highGasPrice) + + // txHandler errors with insufficient fees + _, err = txHandler.CheckTx(sdk.WrapSDKContext(ctx), tx, abci.RequestCheckTx{}) + s.Require().NotNil(err, "Middleware should have errored on too low fee for local gasPrice") + + // txHandler should not error since we do not check minGasPrice in DeliverTx + _, err = txHandler.DeliverTx(sdk.WrapSDKContext(ctx), tx, abci.RequestDeliverTx{}) + s.Require().Nil(err, "MempoolFeeMiddleware returned error in DeliverTx") + + atomPrice = sdk.NewDecCoinFromDec("atom", sdk.NewDec(0).Quo(sdk.NewDec(100000))) + lowGasPrice := []sdk.DecCoin{atomPrice} + ctx = ctx.WithMinGasPrices(lowGasPrice) + + _, err = txHandler.CheckTx(sdk.WrapSDKContext(ctx), tx, abci.RequestCheckTx{}) + s.Require().Nil(err, "Middleware should not have errored on fee higher than local gasPrice") +} + +func (s *MWTestSuite) TestDeductFees() { + ctx := s.SetupTest(false) // setup + txBuilder := s.clientCtx.TxConfig.NewTxBuilder() + txHandler := middleware.ComposeMiddlewares( + noopTxHandler{}, + middleware.DeductFeeMiddleware( + s.app.AccountKeeper, + s.app.BankKeeper, + s.app.FeeGrantKeeper, + ), + ) + + // keys and addresses + priv1, _, addr1 := testdata.KeyTestPubAddr() + + // msg and signatures + msg := testdata.NewTestMsg(addr1) + feeAmount := testdata.NewTestFeeAmount() + gasLimit := testdata.NewTestGasLimit() + s.Require().NoError(txBuilder.SetMsgs(msg)) + txBuilder.SetFeeAmount(feeAmount) + txBuilder.SetGasLimit(gasLimit) + + privs, accNums, accSeqs := []cryptotypes.PrivKey{priv1}, []uint64{0}, []uint64{0} + tx, _, err := s.createTestTx(txBuilder, privs, accNums, accSeqs, ctx.ChainID()) + s.Require().NoError(err) + + // Set account with insufficient funds + acc := s.app.AccountKeeper.NewAccountWithAddress(ctx, addr1) + s.app.AccountKeeper.SetAccount(ctx, acc) + coins := sdk.NewCoins(sdk.NewCoin("atom", sdk.NewInt(10))) + err = testutil.FundAccount(s.app.BankKeeper, ctx, addr1, coins) + s.Require().NoError(err) + + _, err = txHandler.DeliverTx(sdk.WrapSDKContext(ctx), tx, abci.RequestDeliverTx{}) + s.Require().NotNil(err, "Tx did not error when fee payer had insufficient funds") + + // Set account with sufficient funds + s.app.AccountKeeper.SetAccount(ctx, acc) + err = testutil.FundAccount(s.app.BankKeeper, ctx, addr1, sdk.NewCoins(sdk.NewCoin("atom", sdk.NewInt(200)))) + s.Require().NoError(err) + + _, err = txHandler.DeliverTx(sdk.WrapSDKContext(ctx), tx, abci.RequestDeliverTx{}) + + s.Require().Nil(err, "Tx errored after account has been set with sufficient funds") +} diff --git a/x/auth/ante/feegrant_test.go b/x/auth/middleware/feegrant_test.go similarity index 82% rename from x/auth/ante/feegrant_test.go rename to x/auth/middleware/feegrant_test.go index 7c03e3dbefd6..d45b44160496 100644 --- a/x/auth/ante/feegrant_test.go +++ b/x/auth/middleware/feegrant_test.go @@ -1,10 +1,11 @@ -package ante_test +package middleware_test import ( "math/rand" "testing" "time" + abci "github.com/tendermint/tendermint/abci/types" "github.com/tendermint/tendermint/crypto" "github.com/cosmos/cosmos-sdk/client" @@ -15,7 +16,7 @@ import ( sdk "github.com/cosmos/cosmos-sdk/types" "github.com/cosmos/cosmos-sdk/types/simulation" "github.com/cosmos/cosmos-sdk/types/tx/signing" - "github.com/cosmos/cosmos-sdk/x/auth/ante" + "github.com/cosmos/cosmos-sdk/x/auth/middleware" authsign "github.com/cosmos/cosmos-sdk/x/auth/signing" "github.com/cosmos/cosmos-sdk/x/auth/tx" authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" @@ -23,19 +24,20 @@ import ( "github.com/cosmos/cosmos-sdk/x/feegrant" ) -func (suite *AnteTestSuite) TestDeductFeesNoDelegation() { - suite.SetupTest(false) - // setup - app, ctx := suite.app, suite.ctx +func (s *MWTestSuite) TestDeductFeesNoDelegation() { + ctx := s.SetupTest(false) // setup + app := s.app protoTxCfg := tx.NewTxConfig(codec.NewProtoCodec(app.InterfaceRegistry()), tx.DefaultSignModes) - // this just tests our handler - dfd := ante.NewDeductFeeDecorator(app.AccountKeeper, app.BankKeeper, app.FeeGrantKeeper) - feeAnteHandler := sdk.ChainAnteDecorators(dfd) - - // this tests the whole stack - anteHandlerStack := suite.anteHandler + txHandler := middleware.ComposeMiddlewares( + noopTxHandler{}, + middleware.DeductFeeMiddleware( + s.app.AccountKeeper, + s.app.BankKeeper, + s.app.FeeGrantKeeper, + ), + ) // keys and addresses priv1, _, addr1 := testdata.KeyTestPubAddr() @@ -45,24 +47,24 @@ func (suite *AnteTestSuite) TestDeductFeesNoDelegation() { priv5, _, addr5 := testdata.KeyTestPubAddr() // Set addr1 with insufficient funds - err := testutil.FundAccount(suite.app.BankKeeper, suite.ctx, addr1, []sdk.Coin{sdk.NewCoin("atom", sdk.NewInt(10))}) - suite.Require().NoError(err) + err := testutil.FundAccount(s.app.BankKeeper, ctx, addr1, []sdk.Coin{sdk.NewCoin("atom", sdk.NewInt(10))}) + s.Require().NoError(err) // Set addr2 with more funds - err = testutil.FundAccount(suite.app.BankKeeper, suite.ctx, addr2, []sdk.Coin{sdk.NewCoin("atom", sdk.NewInt(99999))}) - suite.Require().NoError(err) + err = testutil.FundAccount(s.app.BankKeeper, ctx, addr2, []sdk.Coin{sdk.NewCoin("atom", sdk.NewInt(99999))}) + s.Require().NoError(err) // grant fee allowance from `addr2` to `addr3` (plenty to pay) err = app.FeeGrantKeeper.GrantAllowance(ctx, addr2, addr3, &feegrant.BasicAllowance{ SpendLimit: sdk.NewCoins(sdk.NewInt64Coin("atom", 500)), }) - suite.Require().NoError(err) + s.Require().NoError(err) // grant low fee allowance (20atom), to check the tx requesting more than allowed. err = app.FeeGrantKeeper.GrantAllowance(ctx, addr2, addr4, &feegrant.BasicAllowance{ SpendLimit: sdk.NewCoins(sdk.NewInt64Coin("atom", 20)), }) - suite.Require().NoError(err) + s.Require().NoError(err) cases := map[string]struct { signerKey cryptotypes.PrivKey @@ -133,7 +135,7 @@ func (suite *AnteTestSuite) TestDeductFeesNoDelegation() { for name, stc := range cases { tc := stc // to make scopelint happy - suite.T().Run(name, func(t *testing.T) { + s.T().Run(name, func(t *testing.T) { fee := sdk.NewCoins(sdk.NewInt64Coin("atom", tc.fee)) msgs := []sdk.Msg{testdata.NewTestMsg(tc.signer)} @@ -144,19 +146,22 @@ func (suite *AnteTestSuite) TestDeductFeesNoDelegation() { } tx, err := genTxWithFeeGranter(protoTxCfg, msgs, fee, helpers.DefaultGenTxGas, ctx.ChainID(), accNums, seqs, tc.feeAccount, privs...) - suite.Require().NoError(err) - _, err = feeAnteHandler(ctx, tx, false) // tests only feegrant ante + s.Require().NoError(err) + + // tests only feegrant middleware + _, err = txHandler.DeliverTx(sdk.WrapSDKContext(ctx), tx, abci.RequestDeliverTx{}) if tc.valid { - suite.Require().NoError(err) + s.Require().NoError(err) } else { - suite.Require().Error(err) + s.Require().Error(err) } - _, err = anteHandlerStack(ctx, tx, false) // tests while stack + // tests while stack + _, err = s.txHandler.DeliverTx(sdk.WrapSDKContext(ctx), tx, abci.RequestDeliverTx{}) if tc.valid { - suite.Require().NoError(err) + s.Require().NoError(err) } else { - suite.Require().Error(err) + s.Require().Error(err) } }) } diff --git a/x/auth/middleware/gas_test.go b/x/auth/middleware/gas_test.go index 287160b10c3d..30534ad7e8ad 100644 --- a/x/auth/middleware/gas_test.go +++ b/x/auth/middleware/gas_test.go @@ -70,7 +70,7 @@ func (s *MWTestSuite) TestSetup() { if tc.expErr { s.Require().EqualError(err, tc.errorStr) } else { - s.Require().Nil(err, "SetUpContextDecorator returned error") + s.Require().Nil(err, "SetUpContextMiddleware returned error") s.Require().Equal(tc.expGasLimit, uint64(res.GasWanted)) } }) diff --git a/x/auth/middleware/legacy_ante.go b/x/auth/middleware/legacy_ante.go deleted file mode 100644 index 14f682082994..000000000000 --- a/x/auth/middleware/legacy_ante.go +++ /dev/null @@ -1,115 +0,0 @@ -package middleware - -import ( - "context" - - abci "github.com/tendermint/tendermint/abci/types" - - sdk "github.com/cosmos/cosmos-sdk/types" - sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" - "github.com/cosmos/cosmos-sdk/types/tx" -) - -type legacyAnteTxHandler struct { - anteHandler sdk.AnteHandler - inner tx.Handler -} - -func newLegacyAnteMiddleware(anteHandler sdk.AnteHandler) tx.Middleware { - return func(txHandler tx.Handler) tx.Handler { - return legacyAnteTxHandler{ - anteHandler: anteHandler, - inner: txHandler, - } - } -} - -var _ tx.Handler = legacyAnteTxHandler{} - -// CheckTx implements tx.Handler.CheckTx method. -func (txh legacyAnteTxHandler) CheckTx(ctx context.Context, tx sdk.Tx, req abci.RequestCheckTx) (abci.ResponseCheckTx, error) { - sdkCtx, err := txh.runAnte(ctx, tx, req.Tx, false) - if err != nil { - return abci.ResponseCheckTx{}, err - } - - return txh.inner.CheckTx(sdk.WrapSDKContext(sdkCtx), tx, req) -} - -// DeliverTx implements tx.Handler.DeliverTx method. -func (txh legacyAnteTxHandler) DeliverTx(ctx context.Context, tx sdk.Tx, req abci.RequestDeliverTx) (abci.ResponseDeliverTx, error) { - sdkCtx, err := txh.runAnte(ctx, tx, req.Tx, false) - if err != nil { - return abci.ResponseDeliverTx{}, err - } - - return txh.inner.DeliverTx(sdk.WrapSDKContext(sdkCtx), tx, req) -} - -// SimulateTx implements tx.Handler.SimulateTx method. -func (txh legacyAnteTxHandler) SimulateTx(ctx context.Context, sdkTx sdk.Tx, req tx.RequestSimulateTx) (tx.ResponseSimulateTx, error) { - sdkCtx, err := txh.runAnte(ctx, sdkTx, req.TxBytes, true) - if err != nil { - return tx.ResponseSimulateTx{}, err - } - - return txh.inner.SimulateTx(sdk.WrapSDKContext(sdkCtx), sdkTx, req) -} - -func (txh legacyAnteTxHandler) runAnte(ctx context.Context, tx sdk.Tx, txBytes []byte, isSimulate bool) (sdk.Context, error) { - err := validateBasicTxMsgs(tx.GetMsgs()) - if err != nil { - return sdk.Context{}, err - } - - sdkCtx := sdk.UnwrapSDKContext(ctx) - if txh.anteHandler == nil { - return sdkCtx, nil - } - - ms := sdkCtx.MultiStore() - - // Branch context before AnteHandler call in case it aborts. - // This is required for both CheckTx and DeliverTx. - // Ref: https://github.com/cosmos/cosmos-sdk/issues/2772 - // - // NOTE: Alternatively, we could require that AnteHandler ensures that - // writes do not happen if aborted/failed. This may have some - // performance benefits, but it'll be more difficult to get right. - anteCtx, msCache := cacheTxContext(sdkCtx, txBytes) - anteCtx = anteCtx.WithEventManager(sdk.NewEventManager()) - newCtx, err := txh.anteHandler(anteCtx, tx, isSimulate) - if err != nil { - return sdk.Context{}, err - } - - if !newCtx.IsZero() { - // At this point, newCtx.MultiStore() is a store branch, or something else - // replaced by the AnteHandler. We want the original multistore. - // - // Also, in the case of the tx aborting, we need to track gas consumed via - // the instantiated gas meter in the AnteHandler, so we update the context - // prior to returning. - sdkCtx = newCtx.WithMultiStore(ms) - } - - msCache.Write() - - return sdkCtx, nil -} - -// validateBasicTxMsgs executes basic validator calls for messages. -func validateBasicTxMsgs(msgs []sdk.Msg) error { - if len(msgs) == 0 { - return sdkerrors.Wrap(sdkerrors.ErrInvalidRequest, "must contain at least one message") - } - - for _, msg := range msgs { - err := msg.ValidateBasic() - if err != nil { - return err - } - } - - return nil -} diff --git a/x/auth/middleware/middleware.go b/x/auth/middleware/middleware.go index 80ad606cd219..60820bb46cc5 100644 --- a/x/auth/middleware/middleware.go +++ b/x/auth/middleware/middleware.go @@ -2,7 +2,11 @@ package middleware import ( sdk "github.com/cosmos/cosmos-sdk/types" + sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" "github.com/cosmos/cosmos-sdk/types/tx" + "github.com/cosmos/cosmos-sdk/types/tx/signing" + authsigning "github.com/cosmos/cosmos-sdk/x/auth/signing" + "github.com/cosmos/cosmos-sdk/x/auth/types" ) // ComposeMiddlewares compose multiple middlewares on top of a tx.Handler. The @@ -35,12 +39,33 @@ type TxHandlerOptions struct { LegacyRouter sdk.Router MsgServiceRouter *MsgServiceRouter - LegacyAnteHandler sdk.AnteHandler + AccountKeeper AccountKeeper + BankKeeper types.BankKeeper + FeegrantKeeper FeegrantKeeper + SignModeHandler authsigning.SignModeHandler + SigGasConsumer func(meter sdk.GasMeter, sig signing.SignatureV2, params types.Params) error } // NewDefaultTxHandler defines a TxHandler middleware stacks that should work // for most applications. func NewDefaultTxHandler(options TxHandlerOptions) (tx.Handler, error) { + if options.AccountKeeper == nil { + return nil, sdkerrors.Wrap(sdkerrors.ErrLogic, "account keeper is required for compose middlewares") + } + + if options.BankKeeper == nil { + return nil, sdkerrors.Wrap(sdkerrors.ErrLogic, "bank keeper is required for compose middlewares") + } + + if options.SignModeHandler == nil { + return nil, sdkerrors.Wrap(sdkerrors.ErrLogic, "sign mode handler is required for compose middlewares") + } + + var sigGasConsumer = options.SigGasConsumer + if sigGasConsumer == nil { + sigGasConsumer = DefaultSigVerificationGasConsumer + } + return ComposeMiddlewares( NewRunMsgsTxHandler(options.MsgServiceRouter, options.LegacyRouter), // Set a new GasMeter on sdk.Context. @@ -55,8 +80,19 @@ func NewDefaultTxHandler(options TxHandlerOptions) (tx.Handler, error) { // Choose which events to index in Tendermint. Make sure no events are // emitted outside of this middleware. NewIndexEventsTxMiddleware(options.IndexEvents), - // Temporary middleware to bundle antehandlers. - // TODO Remove in https://github.com/cosmos/cosmos-sdk/issues/9585. - newLegacyAnteMiddleware(options.LegacyAnteHandler), + // Reject all extension options which can optionally be included in the + // tx. + RejectExtensionOptionsMiddleware, + MempoolFeeMiddleware, + ValidateBasicMiddleware, + TxTimeoutHeightMiddleware, + ValidateMemoMiddleware(options.AccountKeeper), + ConsumeTxSizeGasMiddleware(options.AccountKeeper), + DeductFeeMiddleware(options.AccountKeeper, options.BankKeeper, options.FeegrantKeeper), + SetPubKeyMiddleware(options.AccountKeeper), + ValidateSigCountMiddleware(options.AccountKeeper), + SigGasConsumeMiddleware(options.AccountKeeper, sigGasConsumer), + SigVerificationMiddleware(options.AccountKeeper, options.SignModeHandler), + IncrementSequenceMiddleware(options.AccountKeeper), ), nil } diff --git a/x/auth/ante/ante_test.go b/x/auth/middleware/middleware_test.go similarity index 67% rename from x/auth/ante/ante_test.go rename to x/auth/middleware/middleware_test.go index 3ce43a5a17e3..e3b04983b6f5 100644 --- a/x/auth/ante/ante_test.go +++ b/x/auth/middleware/middleware_test.go @@ -1,4 +1,4 @@ -package ante_test +package middleware_test import ( "encoding/json" @@ -7,8 +7,6 @@ import ( "strings" "testing" - "github.com/stretchr/testify/require" - "github.com/cosmos/cosmos-sdk/crypto/keys/ed25519" kmultisig "github.com/cosmos/cosmos-sdk/crypto/keys/multisig" "github.com/cosmos/cosmos-sdk/crypto/keys/secp256k1" @@ -17,18 +15,21 @@ import ( sdk "github.com/cosmos/cosmos-sdk/types" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" "github.com/cosmos/cosmos-sdk/types/tx/signing" - "github.com/cosmos/cosmos-sdk/x/auth/ante" + "github.com/cosmos/cosmos-sdk/x/auth/middleware" "github.com/cosmos/cosmos-sdk/x/auth/types" "github.com/cosmos/cosmos-sdk/x/bank/testutil" minttypes "github.com/cosmos/cosmos-sdk/x/mint/types" + "github.com/stretchr/testify/require" + abci "github.com/tendermint/tendermint/abci/types" ) // Test that simulate transaction accurately estimates gas cost -func (suite *AnteTestSuite) TestSimulateGasCost() { - suite.SetupTest(false) // reset +func (s *MWTestSuite) TestSimulateGasCost() { + ctx := s.SetupTest(false) // reset + txBuilder := s.clientCtx.TxConfig.NewTxBuilder() // Same data for every test cases - accounts := suite.CreateTestAccounts(3) + accounts := s.createTestAccounts(ctx, 3) msgs := []sdk.Msg{ testdata.NewTestMsg(accounts[0].acc.GetAddress(), accounts[1].acc.GetAddress()), testdata.NewTestMsg(accounts[2].acc.GetAddress(), accounts[0].acc.GetAddress()), @@ -44,8 +45,8 @@ func (suite *AnteTestSuite) TestSimulateGasCost() { { "tx with 150atom fee", func() { - suite.txBuilder.SetFeeAmount(feeAmount) - suite.txBuilder.SetGasLimit(gasLimit) + txBuilder.SetFeeAmount(feeAmount) + txBuilder.SetGasLimit(gasLimit) }, true, true, @@ -54,11 +55,11 @@ func (suite *AnteTestSuite) TestSimulateGasCost() { { "with previously estimated gas", func() { - simulatedGas := suite.ctx.GasMeter().GasConsumed() + simulatedGas := ctx.GasMeter().GasConsumed() accSeqs = []uint64{1, 1, 1} - suite.txBuilder.SetFeeAmount(feeAmount) - suite.txBuilder.SetGasLimit(simulatedGas) + txBuilder.SetFeeAmount(feeAmount) + txBuilder.SetGasLimit(simulatedGas) }, false, true, @@ -67,18 +68,18 @@ func (suite *AnteTestSuite) TestSimulateGasCost() { } for _, tc := range testCases { - suite.Run(fmt.Sprintf("Case %s", tc.desc), func() { - suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder() + s.Run(fmt.Sprintf("Case %s", tc.desc), func() { tc.malleate() - suite.RunTestCase(privs, msgs, feeAmount, gasLimit, accNums, accSeqs, suite.ctx.ChainID(), tc) + s.runTestCase(ctx, txBuilder, privs, msgs, feeAmount, gasLimit, accNums, accSeqs, ctx.ChainID(), tc) }) } } -// Test various error cases in the AnteHandler control flow. -func (suite *AnteTestSuite) TestAnteHandlerSigErrors() { - suite.SetupTest(false) // reset +// Test various error cases in the TxHandler control flow. +func (s *MWTestSuite) TestTxHandlerSigErrors() { + ctx := s.SetupTest(false) // reset + txBuilder := s.clientCtx.TxConfig.NewTxBuilder() // Same data for every test cases priv0, _, addr0 := testdata.KeyTestPubAddr() @@ -105,12 +106,12 @@ func (suite *AnteTestSuite) TestAnteHandlerSigErrors() { privs, accNums, accSeqs = []cryptotypes.PrivKey{}, []uint64{}, []uint64{} // Create tx manually to test the tx's signers - suite.Require().NoError(suite.txBuilder.SetMsgs(msgs...)) - tx, err := suite.CreateTestTx(privs, accNums, accSeqs, suite.ctx.ChainID()) - suite.Require().NoError(err) + s.Require().NoError(txBuilder.SetMsgs(msgs...)) + tx, _, err := s.createTestTx(txBuilder, privs, accNums, accSeqs, ctx.ChainID()) + s.Require().NoError(err) // tx.GetSigners returns addresses in correct order: addr1, addr2, addr3 expectedSigners := []sdk.AccAddress{addr0, addr1, addr2} - suite.Require().Equal(expectedSigners, tx.GetSigners()) + s.Require().Equal(expectedSigners, tx.GetSigners()) }, false, false, @@ -137,12 +138,12 @@ func (suite *AnteTestSuite) TestAnteHandlerSigErrors() { { "save the first account, but second is still unrecognized", func() { - acc1 := suite.app.AccountKeeper.NewAccountWithAddress(suite.ctx, addr0) - suite.app.AccountKeeper.SetAccount(suite.ctx, acc1) - err := suite.app.BankKeeper.MintCoins(suite.ctx, minttypes.ModuleName, feeAmount) - suite.Require().NoError(err) - err = suite.app.BankKeeper.SendCoinsFromModuleToAccount(suite.ctx, minttypes.ModuleName, addr0, feeAmount) - suite.Require().NoError(err) + acc1 := s.app.AccountKeeper.NewAccountWithAddress(ctx, addr0) + s.app.AccountKeeper.SetAccount(ctx, acc1) + err := s.app.BankKeeper.MintCoins(ctx, minttypes.ModuleName, feeAmount) + s.Require().NoError(err) + err = s.app.BankKeeper.SendCoinsFromModuleToAccount(ctx, minttypes.ModuleName, addr0, feeAmount) + s.Require().NoError(err) }, false, false, @@ -151,21 +152,21 @@ func (suite *AnteTestSuite) TestAnteHandlerSigErrors() { } for _, tc := range testCases { - suite.Run(fmt.Sprintf("Case %s", tc.desc), func() { - suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder() + s.Run(fmt.Sprintf("Case %s", tc.desc), func() { tc.malleate() - suite.RunTestCase(privs, msgs, feeAmount, gasLimit, accNums, accSeqs, suite.ctx.ChainID(), tc) + s.runTestCase(ctx, txBuilder, privs, msgs, feeAmount, gasLimit, accNums, accSeqs, ctx.ChainID(), tc) }) } } // Test logic around account number checking with one signer and many signers. -func (suite *AnteTestSuite) TestAnteHandlerAccountNumbers() { - suite.SetupTest(false) // reset +func (s *MWTestSuite) TestTxHandlerAccountNumbers() { + ctx := s.SetupTest(false) // reset + txBuilder := s.clientCtx.TxConfig.NewTxBuilder() // Same data for every test cases - accounts := suite.CreateTestAccounts(2) + accounts := s.createTestAccounts(ctx, 2) feeAmount := testdata.NewTestFeeAmount() gasLimit := testdata.NewTestGasLimit() @@ -232,22 +233,22 @@ func (suite *AnteTestSuite) TestAnteHandlerAccountNumbers() { } for _, tc := range testCases { - suite.Run(fmt.Sprintf("Case %s", tc.desc), func() { - suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder() + s.Run(fmt.Sprintf("Case %s", tc.desc), func() { tc.malleate() - suite.RunTestCase(privs, msgs, feeAmount, gasLimit, accNums, accSeqs, suite.ctx.ChainID(), tc) + s.runTestCase(ctx, txBuilder, privs, msgs, feeAmount, gasLimit, accNums, accSeqs, ctx.ChainID(), tc) }) } } // Test logic around account number checking with many signers when BlockHeight is 0. -func (suite *AnteTestSuite) TestAnteHandlerAccountNumbersAtBlockHeightZero() { - suite.SetupTest(false) // setup - suite.ctx = suite.ctx.WithBlockHeight(0) +func (s *MWTestSuite) TestTxHandlerAccountNumbersAtBlockHeightZero() { + ctx := s.SetupTest(false) // setup + ctx = ctx.WithBlockHeight(0) + txBuilder := s.clientCtx.TxConfig.NewTxBuilder() // Same data for every test cases - accounts := suite.CreateTestAccounts(2) + accounts := s.createTestAccounts(ctx, 2) feeAmount := testdata.NewTestFeeAmount() gasLimit := testdata.NewTestGasLimit() @@ -316,21 +317,21 @@ func (suite *AnteTestSuite) TestAnteHandlerAccountNumbersAtBlockHeightZero() { } for _, tc := range testCases { - suite.Run(fmt.Sprintf("Case %s", tc.desc), func() { - suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder() + s.Run(fmt.Sprintf("Case %s", tc.desc), func() { tc.malleate() - suite.RunTestCase(privs, msgs, feeAmount, gasLimit, accNums, accSeqs, suite.ctx.ChainID(), tc) + s.runTestCase(ctx, txBuilder, privs, msgs, feeAmount, gasLimit, accNums, accSeqs, ctx.ChainID(), tc) }) } } // Test logic around sequence checking with one signer and many signers. -func (suite *AnteTestSuite) TestAnteHandlerSequences() { - suite.SetupTest(false) // setup +func (s *MWTestSuite) TestTxHandlerSequences() { + ctx := s.SetupTest(false) // setup + txBuilder := s.clientCtx.TxConfig.NewTxBuilder() // Same data for every test cases - accounts := suite.CreateTestAccounts(3) + accounts := s.createTestAccounts(ctx, 3) feeAmount := testdata.NewTestFeeAmount() gasLimit := testdata.NewTestGasLimit() @@ -428,24 +429,24 @@ func (suite *AnteTestSuite) TestAnteHandlerSequences() { } for _, tc := range testCases { - suite.Run(fmt.Sprintf("Case %s", tc.desc), func() { - suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder() + s.Run(fmt.Sprintf("Case %s", tc.desc), func() { tc.malleate() - suite.RunTestCase(privs, msgs, feeAmount, gasLimit, accNums, accSeqs, suite.ctx.ChainID(), tc) + s.runTestCase(ctx, txBuilder, privs, msgs, feeAmount, gasLimit, accNums, accSeqs, ctx.ChainID(), tc) }) } } // Test logic around fee deduction. -func (suite *AnteTestSuite) TestAnteHandlerFees() { - suite.SetupTest(false) // setup +func (s *MWTestSuite) TestTxHandlerFees() { + ctx := s.SetupTest(false) // setup + txBuilder := s.clientCtx.TxConfig.NewTxBuilder() // Same data for every test cases priv0, _, addr0 := testdata.KeyTestPubAddr() - acc1 := suite.app.AccountKeeper.NewAccountWithAddress(suite.ctx, addr0) - suite.app.AccountKeeper.SetAccount(suite.ctx, acc1) + acc1 := s.app.AccountKeeper.NewAccountWithAddress(ctx, addr0) + s.app.AccountKeeper.SetAccount(ctx, acc1) msgs := []sdk.Msg{testdata.NewTestMsg(addr0)} feeAmount := testdata.NewTestFeeAmount() gasLimit := testdata.NewTestGasLimit() @@ -470,8 +471,8 @@ func (suite *AnteTestSuite) TestAnteHandlerFees() { { "signer does not have enough funds to pay the fee", func() { - err := testutil.FundAccount(suite.app.BankKeeper, suite.ctx, addr0, sdk.NewCoins(sdk.NewInt64Coin("atom", 149))) - suite.Require().NoError(err) + err := testutil.FundAccount(s.app.BankKeeper, ctx, addr0, sdk.NewCoins(sdk.NewInt64Coin("atom", 149))) + s.Require().NoError(err) }, false, false, @@ -482,13 +483,13 @@ func (suite *AnteTestSuite) TestAnteHandlerFees() { func() { accNums = []uint64{acc1.GetAccountNumber()} - modAcc := suite.app.AccountKeeper.GetModuleAccount(suite.ctx, types.FeeCollectorName) + modAcc := s.app.AccountKeeper.GetModuleAccount(ctx, types.FeeCollectorName) - suite.Require().True(suite.app.BankKeeper.GetAllBalances(suite.ctx, modAcc.GetAddress()).Empty()) - require.True(sdk.IntEq(suite.T(), suite.app.BankKeeper.GetAllBalances(suite.ctx, addr0).AmountOf("atom"), sdk.NewInt(149))) + s.Require().True(s.app.BankKeeper.GetAllBalances(ctx, modAcc.GetAddress()).Empty()) + require.True(sdk.IntEq(s.T(), s.app.BankKeeper.GetAllBalances(ctx, addr0).AmountOf("atom"), sdk.NewInt(149))) - err := testutil.FundAccount(suite.app.BankKeeper, suite.ctx, addr0, sdk.NewCoins(sdk.NewInt64Coin("atom", 1))) - suite.Require().NoError(err) + err := testutil.FundAccount(s.app.BankKeeper, ctx, addr0, sdk.NewCoins(sdk.NewInt64Coin("atom", 1))) + s.Require().NoError(err) }, false, true, @@ -497,10 +498,10 @@ func (suite *AnteTestSuite) TestAnteHandlerFees() { { "signer doesn't have any more funds", func() { - modAcc := suite.app.AccountKeeper.GetModuleAccount(suite.ctx, types.FeeCollectorName) + modAcc := s.app.AccountKeeper.GetModuleAccount(ctx, types.FeeCollectorName) - require.True(sdk.IntEq(suite.T(), suite.app.BankKeeper.GetAllBalances(suite.ctx, modAcc.GetAddress()).AmountOf("atom"), sdk.NewInt(150))) - require.True(sdk.IntEq(suite.T(), suite.app.BankKeeper.GetAllBalances(suite.ctx, addr0).AmountOf("atom"), sdk.NewInt(0))) + require.True(sdk.IntEq(s.T(), s.app.BankKeeper.GetAllBalances(ctx, modAcc.GetAddress()).AmountOf("atom"), sdk.NewInt(150))) + require.True(sdk.IntEq(s.T(), s.app.BankKeeper.GetAllBalances(ctx, addr0).AmountOf("atom"), sdk.NewInt(0))) }, false, false, @@ -509,22 +510,21 @@ func (suite *AnteTestSuite) TestAnteHandlerFees() { } for _, tc := range testCases { - suite.Run(fmt.Sprintf("Case %s", tc.desc), func() { - - suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder() + s.Run(fmt.Sprintf("Case %s", tc.desc), func() { tc.malleate() - suite.RunTestCase(privs, msgs, feeAmount, gasLimit, accNums, accSeqs, suite.ctx.ChainID(), tc) + s.runTestCase(ctx, txBuilder, privs, msgs, feeAmount, gasLimit, accNums, accSeqs, ctx.ChainID(), tc) }) } } // Test logic around memo gas consumption. -func (suite *AnteTestSuite) TestAnteHandlerMemoGas() { - suite.SetupTest(false) // setup +func (s *MWTestSuite) TestTxHandlerMemoGas() { + ctx := s.SetupTest(false) // setup + txBuilder := s.clientCtx.TxConfig.NewTxBuilder() // Same data for every test cases - accounts := suite.CreateTestAccounts(1) + accounts := s.createTestAccounts(ctx, 1) msgs := []sdk.Msg{testdata.NewTestMsg(accounts[0].acc.GetAddress())} privs, accNums, accSeqs := []cryptotypes.PrivKey{accounts[0].priv}, []uint64{0}, []uint64{0} @@ -550,7 +550,7 @@ func (suite *AnteTestSuite) TestAnteHandlerMemoGas() { func() { feeAmount = sdk.NewCoins(sdk.NewInt64Coin("atom", 0)) gasLimit = 801 - suite.txBuilder.SetMemo("abcininasidniandsinasindiansdiansdinaisndiasndiadninsd") + txBuilder.SetMemo("abcininasidniandsinasindiansdiansdinaisndiasndiadninsd") }, false, false, @@ -561,7 +561,7 @@ func (suite *AnteTestSuite) TestAnteHandlerMemoGas() { func() { feeAmount = sdk.NewCoins(sdk.NewInt64Coin("atom", 0)) gasLimit = 50000 - suite.txBuilder.SetMemo(strings.Repeat("01234567890", 500)) + txBuilder.SetMemo(strings.Repeat("01234567890", 500)) }, false, false, @@ -572,7 +572,7 @@ func (suite *AnteTestSuite) TestAnteHandlerMemoGas() { func() { feeAmount = sdk.NewCoins(sdk.NewInt64Coin("atom", 0)) gasLimit = 50000 - suite.txBuilder.SetMemo(strings.Repeat("0123456789", 10)) + txBuilder.SetMemo(strings.Repeat("0123456789", 10)) }, false, true, @@ -581,20 +581,20 @@ func (suite *AnteTestSuite) TestAnteHandlerMemoGas() { } for _, tc := range testCases { - suite.Run(fmt.Sprintf("Case %s", tc.desc), func() { - suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder() + s.Run(fmt.Sprintf("Case %s", tc.desc), func() { tc.malleate() - suite.RunTestCase(privs, msgs, feeAmount, gasLimit, accNums, accSeqs, suite.ctx.ChainID(), tc) + s.runTestCase(ctx, txBuilder, privs, msgs, feeAmount, gasLimit, accNums, accSeqs, ctx.ChainID(), tc) }) } } -func (suite *AnteTestSuite) TestAnteHandlerMultiSigner() { - suite.SetupTest(false) // setup +func (s *MWTestSuite) TestTxHandlerMultiSigner() { + ctx := s.SetupTest(false) // setup + txBuilder := s.clientCtx.TxConfig.NewTxBuilder() // Same data for every test cases - accounts := suite.CreateTestAccounts(3) + accounts := s.createTestAccounts(ctx, 3) msg1 := testdata.NewTestMsg(accounts[0].acc.GetAddress(), accounts[1].acc.GetAddress()) msg2 := testdata.NewTestMsg(accounts[2].acc.GetAddress(), accounts[0].acc.GetAddress()) msg3 := testdata.NewTestMsg(accounts[1].acc.GetAddress(), accounts[2].acc.GetAddress()) @@ -615,7 +615,7 @@ func (suite *AnteTestSuite) TestAnteHandlerMultiSigner() { func() { msgs = []sdk.Msg{msg1, msg2, msg3} privs, accNums, accSeqs = []cryptotypes.PrivKey{accounts[0].priv, accounts[1].priv, accounts[2].priv}, []uint64{0, 1, 2}, []uint64{0, 0, 0} - suite.txBuilder.SetMemo("Check signers are in expected order and different account numbers works") + txBuilder.SetMemo("Check signers are in expected order and different account numbers works") }, false, true, @@ -654,20 +654,20 @@ func (suite *AnteTestSuite) TestAnteHandlerMultiSigner() { } for _, tc := range testCases { - suite.Run(fmt.Sprintf("Case %s", tc.desc), func() { - suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder() + s.Run(fmt.Sprintf("Case %s", tc.desc), func() { tc.malleate() - suite.RunTestCase(privs, msgs, feeAmount, gasLimit, accNums, accSeqs, suite.ctx.ChainID(), tc) + s.runTestCase(ctx, txBuilder, privs, msgs, feeAmount, gasLimit, accNums, accSeqs, ctx.ChainID(), tc) }) } } -func (suite *AnteTestSuite) TestAnteHandlerBadSignBytes() { - suite.SetupTest(false) // setup +func (s *MWTestSuite) TestTxHandlerBadSignBytes() { + ctx := s.SetupTest(true) // setup + txBuilder := s.clientCtx.TxConfig.NewTxBuilder() // Same data for every test cases - accounts := suite.CreateTestAccounts(2) + accounts := s.createTestAccounts(ctx, 2) msg0 := testdata.NewTestMsg(accounts[0].acc.GetAddress()) // Variable data per test case @@ -685,7 +685,7 @@ func (suite *AnteTestSuite) TestAnteHandlerBadSignBytes() { { "test good tx and signBytes", func() { - chainID = suite.ctx.ChainID() + chainID = ctx.ChainID() feeAmount = testdata.NewTestFeeAmount() gasLimit = testdata.NewTestGasLimit() msgs = []sdk.Msg{msg0} @@ -708,7 +708,7 @@ func (suite *AnteTestSuite) TestAnteHandlerBadSignBytes() { { "test wrong accSeqs", func() { - chainID = suite.ctx.ChainID() // Back to correct chainID + chainID = ctx.ChainID() // Back to correct chainID accSeqs = []uint64{2} }, false, @@ -780,20 +780,20 @@ func (suite *AnteTestSuite) TestAnteHandlerBadSignBytes() { } for _, tc := range testCases { - suite.Run(fmt.Sprintf("Case %s", tc.desc), func() { - suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder() + s.Run(fmt.Sprintf("Case %s", tc.desc), func() { tc.malleate() - suite.RunTestCase(privs, msgs, feeAmount, gasLimit, accNums, accSeqs, chainID, tc) + s.runTestCase(ctx, txBuilder, privs, msgs, feeAmount, gasLimit, accNums, accSeqs, chainID, tc) }) } } -func (suite *AnteTestSuite) TestAnteHandlerSetPubKey() { - suite.SetupTest(false) // setup +func (s *MWTestSuite) TestTxHandlerSetPubKey() { + ctx := s.SetupTest(true) // setup + txBuilder := s.clientCtx.TxConfig.NewTxBuilder() // Same data for every test cases - accounts := suite.CreateTestAccounts(2) + accounts := s.createTestAccounts(ctx, 2) feeAmount := testdata.NewTestFeeAmount() gasLimit := testdata.NewTestGasLimit() @@ -820,8 +820,8 @@ func (suite *AnteTestSuite) TestAnteHandlerSetPubKey() { "make sure public key has been set (tx itself should fail because of replay protection)", func() { // Make sure public key has been set from previous test. - acc0 := suite.app.AccountKeeper.GetAccount(suite.ctx, accounts[0].acc.GetAddress()) - suite.Require().Equal(acc0.GetPubKey(), accounts[0].priv.PubKey()) + acc0 := s.app.AccountKeeper.GetAccount(ctx, accounts[0].acc.GetAddress()) + s.Require().Equal(acc0.GetPubKey(), accounts[0].priv.PubKey()) }, false, false, @@ -841,30 +841,30 @@ func (suite *AnteTestSuite) TestAnteHandlerSetPubKey() { "make sure public key is not set, when tx has no pubkey or signature", func() { // Make sure public key has not been set from previous test. - acc1 := suite.app.AccountKeeper.GetAccount(suite.ctx, accounts[1].acc.GetAddress()) - suite.Require().Nil(acc1.GetPubKey()) + acc1 := s.app.AccountKeeper.GetAccount(ctx, accounts[1].acc.GetAddress()) + s.Require().Nil(acc1.GetPubKey()) privs, accNums, accSeqs = []cryptotypes.PrivKey{accounts[1].priv}, []uint64{1}, []uint64{0} msgs = []sdk.Msg{testdata.NewTestMsg(accounts[1].acc.GetAddress())} - suite.txBuilder.SetMsgs(msgs...) - suite.txBuilder.SetFeeAmount(feeAmount) - suite.txBuilder.SetGasLimit(gasLimit) + txBuilder.SetMsgs(msgs...) + txBuilder.SetFeeAmount(feeAmount) + txBuilder.SetGasLimit(gasLimit) // Manually create tx, and remove signature. - tx, err := suite.CreateTestTx(privs, accNums, accSeqs, suite.ctx.ChainID()) - suite.Require().NoError(err) - txBuilder, err := suite.clientCtx.TxConfig.WrapTxBuilder(tx) - suite.Require().NoError(err) - suite.Require().NoError(txBuilder.SetSignatures()) + tx, _, err := s.createTestTx(txBuilder, privs, accNums, accSeqs, ctx.ChainID()) + s.Require().NoError(err) + txBuilder, err := s.clientCtx.TxConfig.WrapTxBuilder(tx) + s.Require().NoError(err) + s.Require().NoError(txBuilder.SetSignatures()) - // Run anteHandler manually, expect ErrNoSignatures. - _, err = suite.anteHandler(suite.ctx, txBuilder.GetTx(), false) - suite.Require().Error(err) - suite.Require().True(errors.Is(err, sdkerrors.ErrNoSignatures)) + // Run txHandler manually, expect ErrNoSignatures. + _, err = s.txHandler.CheckTx(sdk.WrapSDKContext(ctx), txBuilder.GetTx(), abci.RequestCheckTx{}) + s.Require().Error(err) + s.Require().True(errors.Is(err, sdkerrors.ErrNoSignatures)) // Make sure public key has not been set. - acc1 = suite.app.AccountKeeper.GetAccount(suite.ctx, accounts[1].acc.GetAddress()) - suite.Require().Nil(acc1.GetPubKey()) + acc1 = s.app.AccountKeeper.GetAccount(ctx, accounts[1].acc.GetAddress()) + s.Require().Nil(acc1.GetPubKey()) // Set incorrect accSeq, to generate incorrect signature. privs, accNums, accSeqs = []cryptotypes.PrivKey{accounts[1].priv}, []uint64{1}, []uint64{1} @@ -876,10 +876,10 @@ func (suite *AnteTestSuite) TestAnteHandlerSetPubKey() { { "make sure previous public key has been set after wrong signature", func() { - // Make sure public key has been set, as SetPubKeyDecorator - // is called before all signature verification decorators. - acc1 := suite.app.AccountKeeper.GetAccount(suite.ctx, accounts[1].acc.GetAddress()) - suite.Require().Equal(acc1.GetPubKey(), accounts[1].priv.PubKey()) + // Make sure public key has been set, as SetPubKeyMiddleware + // is called before all signature verification middlewares. + acc1 := s.app.AccountKeeper.GetAccount(ctx, accounts[1].acc.GetAddress()) + s.Require().Equal(acc1.GetPubKey(), accounts[1].priv.PubKey()) }, false, false, @@ -888,11 +888,10 @@ func (suite *AnteTestSuite) TestAnteHandlerSetPubKey() { } for _, tc := range testCases { - suite.Run(fmt.Sprintf("Case %s", tc.desc), func() { - suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder() + s.Run(fmt.Sprintf("Case %s", tc.desc), func() { tc.malleate() - suite.RunTestCase(privs, msgs, feeAmount, gasLimit, accNums, accSeqs, suite.ctx.ChainID(), tc) + s.runTestCase(ctx, txBuilder, privs, msgs, feeAmount, gasLimit, accNums, accSeqs, ctx.ChainID(), tc) }) } } @@ -962,16 +961,17 @@ func TestCountSubkeys(t *testing.T) { } for _, tc := range testCases { t.Run(tc.name, func(T *testing.T) { - require.Equal(t, tc.want, ante.CountSubKeys(tc.args.pub)) + require.Equal(t, tc.want, middleware.CountSubKeys(tc.args.pub)) }) } } -func (suite *AnteTestSuite) TestAnteHandlerSigLimitExceeded() { - suite.SetupTest(false) // setup +func (s *MWTestSuite) TestTxHandlerSigLimitExceeded() { + ctx := s.SetupTest(false) // setup + txBuilder := s.clientCtx.TxConfig.NewTxBuilder() // Same data for every test cases - accounts := suite.CreateTestAccounts(8) + accounts := s.createTestAccounts(ctx, 8) var addrs []sdk.AccAddress var privs []cryptotypes.PrivKey for i := 0; i < 8; i++ { @@ -994,26 +994,25 @@ func (suite *AnteTestSuite) TestAnteHandlerSigLimitExceeded() { } for _, tc := range testCases { - suite.Run(fmt.Sprintf("Case %s", tc.desc), func() { - suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder() + s.Run(fmt.Sprintf("Case %s", tc.desc), func() { tc.malleate() - suite.RunTestCase(privs, msgs, feeAmount, gasLimit, accNums, accSeqs, suite.ctx.ChainID(), tc) + s.runTestCase(ctx, txBuilder, privs, msgs, feeAmount, gasLimit, accNums, accSeqs, ctx.ChainID(), tc) }) } } // Test custom SignatureVerificationGasConsumer -func (suite *AnteTestSuite) TestCustomSignatureVerificationGasConsumer() { - suite.SetupTest(false) // setup - - // setup an ante handler that only accepts PubKeyEd25519 - anteHandler, err := ante.NewAnteHandler( - ante.HandlerOptions{ - AccountKeeper: suite.app.AccountKeeper, - BankKeeper: suite.app.BankKeeper, - FeegrantKeeper: suite.app.FeeGrantKeeper, - SignModeHandler: suite.clientCtx.TxConfig.SignModeHandler(), +func (s *MWTestSuite) TestCustomSignatureVerificationGasConsumer() { + ctx := s.SetupTest(false) // setup + txBuilder := s.clientCtx.TxConfig.NewTxBuilder() + + txHandler, err := middleware.NewDefaultTxHandler( + middleware.TxHandlerOptions{ + AccountKeeper: s.app.AccountKeeper, + BankKeeper: s.app.BankKeeper, + FeegrantKeeper: s.app.FeeGrantKeeper, + SignModeHandler: s.clientCtx.TxConfig.SignModeHandler(), SigGasConsumer: func(meter sdk.GasMeter, sig signing.SignatureV2, params types.Params) error { switch pubkey := sig.PubKey.(type) { case *ed25519.PubKey: @@ -1025,19 +1024,19 @@ func (suite *AnteTestSuite) TestCustomSignatureVerificationGasConsumer() { }, }, ) + s.Require().NoError(err) - suite.Require().NoError(err) - suite.anteHandler = anteHandler + s.Require().NoError(err) // Same data for every test cases - accounts := suite.CreateTestAccounts(1) - feeAmount := testdata.NewTestFeeAmount() - gasLimit := testdata.NewTestGasLimit() + accounts := s.createTestAccounts(ctx, 1) + txBuilder.SetFeeAmount(testdata.NewTestFeeAmount()) + txBuilder.SetGasLimit(testdata.NewTestGasLimit()) + txBuilder.SetMsgs(testdata.NewTestMsg(accounts[0].acc.GetAddress())) // Variable data per test case var ( accNums []uint64 - msgs []sdk.Msg privs []cryptotypes.PrivKey accSeqs []uint64 ) @@ -1046,7 +1045,6 @@ func (suite *AnteTestSuite) TestCustomSignatureVerificationGasConsumer() { { "verify that an secp256k1 account gets rejected", func() { - msgs = []sdk.Msg{testdata.NewTestMsg(accounts[0].acc.GetAddress())} privs, accNums, accSeqs = []cryptotypes.PrivKey{accounts[0].priv}, []uint64{0}, []uint64{0} }, false, @@ -1056,54 +1054,57 @@ func (suite *AnteTestSuite) TestCustomSignatureVerificationGasConsumer() { } for _, tc := range testCases { - suite.Run(fmt.Sprintf("Case %s", tc.desc), func() { - suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder() + s.Run(fmt.Sprintf("Case %s", tc.desc), func() { tc.malleate() - suite.RunTestCase(privs, msgs, feeAmount, gasLimit, accNums, accSeqs, suite.ctx.ChainID(), tc) + tx, txBytes, err := s.createTestTx(txBuilder, privs, accNums, accSeqs, ctx.ChainID()) + s.Require().NoError(err) + _, err = txHandler.DeliverTx(sdk.WrapSDKContext(ctx), tx, abci.RequestDeliverTx{Tx: txBytes}) + s.Require().Error(err) + s.Require().True(errors.Is(err, tc.expErr)) }) } } -func (suite *AnteTestSuite) TestAnteHandlerReCheck() { - suite.SetupTest(false) // setup +func (s *MWTestSuite) TestTxHandlerReCheck() { + ctx := s.SetupTest(false) // setup // Set recheck=true - suite.ctx = suite.ctx.WithIsReCheckTx(true) - suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder() + ctx = ctx.WithIsReCheckTx(true) + txBuilder := s.clientCtx.TxConfig.NewTxBuilder() // Same data for every test cases - accounts := suite.CreateTestAccounts(1) + accounts := s.createTestAccounts(ctx, 1) feeAmount := testdata.NewTestFeeAmount() gasLimit := testdata.NewTestGasLimit() - suite.txBuilder.SetFeeAmount(feeAmount) - suite.txBuilder.SetGasLimit(gasLimit) + txBuilder.SetFeeAmount(feeAmount) + txBuilder.SetGasLimit(gasLimit) msg := testdata.NewTestMsg(accounts[0].acc.GetAddress()) msgs := []sdk.Msg{msg} - suite.Require().NoError(suite.txBuilder.SetMsgs(msgs...)) + s.Require().NoError(txBuilder.SetMsgs(msgs...)) - suite.txBuilder.SetMemo("thisisatestmemo") + txBuilder.SetMemo("thisisatestmemo") // test that operations skipped on recheck do not run privs, accNums, accSeqs := []cryptotypes.PrivKey{accounts[0].priv}, []uint64{0}, []uint64{0} - tx, err := suite.CreateTestTx(privs, accNums, accSeqs, suite.ctx.ChainID()) - suite.Require().NoError(err) + tx, _, err := s.createTestTx(txBuilder, privs, accNums, accSeqs, ctx.ChainID()) + s.Require().NoError(err) - // make signature array empty which would normally cause ValidateBasicDecorator and SigVerificationDecorator fail - // since these decorators don't run on recheck, the tx should pass the antehandler - txBuilder, err := suite.clientCtx.TxConfig.WrapTxBuilder(tx) - suite.Require().NoError(err) - suite.Require().NoError(txBuilder.SetSignatures()) + // make signature array empty which would normally cause ValidateBasicMiddleware and SigVerificationMiddleware fail + // since these middlewares don't run on recheck, the tx should pass the middleware + txBuilder, err = s.clientCtx.TxConfig.WrapTxBuilder(tx) + s.Require().NoError(err) + s.Require().NoError(txBuilder.SetSignatures()) - _, err = suite.anteHandler(suite.ctx, txBuilder.GetTx(), false) - suite.Require().Nil(err, "AnteHandler errored on recheck unexpectedly: %v", err) + _, err = s.txHandler.CheckTx(sdk.WrapSDKContext(ctx), txBuilder.GetTx(), abci.RequestCheckTx{Type: abci.CheckTxType_Recheck}) + s.Require().Nil(err, "TxHandler errored on recheck unexpectedly: %v", err) - tx, err = suite.CreateTestTx(privs, accNums, accSeqs, suite.ctx.ChainID()) - suite.Require().NoError(err) + tx, _, err = s.createTestTx(txBuilder, privs, accNums, accSeqs, ctx.ChainID()) + s.Require().NoError(err) txBytes, err := json.Marshal(tx) - suite.Require().Nil(err, "Error marshalling tx: %v", err) - suite.ctx = suite.ctx.WithTxBytes(txBytes) + s.Require().Nil(err, "Error marshalling tx: %v", err) + ctx = ctx.WithTxBytes(txBytes) // require that state machine param-dependent checking is still run on recheck since parameters can change between check and recheck testCases := []struct { @@ -1114,35 +1115,37 @@ func (suite *AnteTestSuite) TestAnteHandlerReCheck() { {"txsize check", types.NewParams(types.DefaultMaxMemoCharacters, types.DefaultTxSigLimit, 10000000, types.DefaultSigVerifyCostED25519, types.DefaultSigVerifyCostSecp256k1)}, {"sig verify cost check", types.NewParams(types.DefaultMaxMemoCharacters, types.DefaultTxSigLimit, types.DefaultTxSizeCostPerByte, types.DefaultSigVerifyCostED25519, 100000000)}, } + for _, tc := range testCases { // set testcase parameters - suite.app.AccountKeeper.SetParams(suite.ctx, tc.params) + s.app.AccountKeeper.SetParams(ctx, tc.params) - _, err := suite.anteHandler(suite.ctx, tx, false) + _, err = s.txHandler.CheckTx(sdk.WrapSDKContext(ctx), tx, abci.RequestCheckTx{Tx: txBytes, Type: abci.CheckTxType_Recheck}) - suite.Require().NotNil(err, "tx does not fail on recheck with updated params in test case: %s", tc.name) + s.Require().NotNil(err, "tx does not fail on recheck with updated params in test case: %s", tc.name) // reset parameters to default values - suite.app.AccountKeeper.SetParams(suite.ctx, types.DefaultParams()) + s.app.AccountKeeper.SetParams(ctx, types.DefaultParams()) } // require that local mempool fee check is still run on recheck since validator may change minFee between check and recheck - // create new minimum gas price so antehandler fails on recheck - suite.ctx = suite.ctx.WithMinGasPrices([]sdk.DecCoin{{ + // create new minimum gas price so txhandler fails on recheck + ctx = ctx.WithMinGasPrices([]sdk.DecCoin{{ Denom: "dnecoin", // fee does not have this denom Amount: sdk.NewDec(5), }}) - _, err = suite.anteHandler(suite.ctx, tx, false) - suite.Require().NotNil(err, "antehandler on recheck did not fail when mingasPrice was changed") + _, err = s.txHandler.CheckTx(sdk.WrapSDKContext(ctx), tx, abci.RequestCheckTx{}) + + s.Require().NotNil(err, "txhandler on recheck did not fail when mingasPrice was changed") // reset min gasprice - suite.ctx = suite.ctx.WithMinGasPrices(sdk.DecCoins{}) + ctx = ctx.WithMinGasPrices(sdk.DecCoins{}) - // remove funds for account so antehandler fails on recheck - suite.app.AccountKeeper.SetAccount(suite.ctx, accounts[0].acc) - balances := suite.app.BankKeeper.GetAllBalances(suite.ctx, accounts[0].acc.GetAddress()) - err = suite.app.BankKeeper.SendCoinsFromAccountToModule(suite.ctx, accounts[0].acc.GetAddress(), minttypes.ModuleName, balances) - suite.Require().NoError(err) + // remove funds for account so txhandler fails on recheck + s.app.AccountKeeper.SetAccount(ctx, accounts[0].acc) + balances := s.app.BankKeeper.GetAllBalances(ctx, accounts[0].acc.GetAddress()) + err = s.app.BankKeeper.SendCoinsFromAccountToModule(ctx, accounts[0].acc.GetAddress(), minttypes.ModuleName, balances) + s.Require().NoError(err) - _, err = suite.anteHandler(suite.ctx, tx, false) - suite.Require().NotNil(err, "antehandler on recheck did not fail once feePayer no longer has sufficient funds") + _, err = s.txHandler.CheckTx(sdk.WrapSDKContext(ctx), tx, abci.RequestCheckTx{}) + s.Require().NotNil(err, "txhandler on recheck did not fail once feePayer no longer has sufficient funds") } diff --git a/x/auth/middleware/msg_service_router_test.go b/x/auth/middleware/msg_service_router_test.go index 6223293617e4..ca6ec79b5b9a 100644 --- a/x/auth/middleware/msg_service_router_test.go +++ b/x/auth/middleware/msg_service_router_test.go @@ -1,22 +1,13 @@ package middleware_test import ( - "os" "testing" "github.com/stretchr/testify/require" - abci "github.com/tendermint/tendermint/abci/types" - "github.com/tendermint/tendermint/libs/log" - tmproto "github.com/tendermint/tendermint/proto/tendermint/types" - dbm "github.com/tendermint/tm-db" - "github.com/cosmos/cosmos-sdk/baseapp" - "github.com/cosmos/cosmos-sdk/client/tx" "github.com/cosmos/cosmos-sdk/simapp" "github.com/cosmos/cosmos-sdk/testutil/testdata" - "github.com/cosmos/cosmos-sdk/types/tx/signing" "github.com/cosmos/cosmos-sdk/x/auth/middleware" - authsigning "github.com/cosmos/cosmos-sdk/x/auth/signing" ) func TestRegisterMsgService(t *testing.T) { @@ -62,63 +53,3 @@ func TestRegisterMsgServiceTwice(t *testing.T) { ) }) } - -func TestMsgService(t *testing.T) { - priv, _, _ := testdata.KeyTestPubAddr() - encCfg := simapp.MakeTestEncodingConfig() - testdata.RegisterInterfaces(encCfg.InterfaceRegistry) - db := dbm.NewMemDB() - app := baseapp.NewBaseApp("test", log.NewTMLogger(log.NewSyncWriter(os.Stdout)), db, encCfg.TxConfig.TxDecoder()) - app.SetInterfaceRegistry(encCfg.InterfaceRegistry) - msr := middleware.NewMsgServiceRouter(encCfg.InterfaceRegistry) - txHandler, err := middleware.NewDefaultTxHandler(middleware.TxHandlerOptions{ - MsgServiceRouter: msr, - }) - require.NoError(t, err) - app.SetTxHandler(txHandler) - testdata.RegisterMsgServer( - msr, - testdata.MsgServerImpl{}, - ) - _ = app.BeginBlock(abci.RequestBeginBlock{Header: tmproto.Header{Height: 1}}) - - msg := testdata.MsgCreateDog{Dog: &testdata.Dog{Name: "Spot"}} - txBuilder := encCfg.TxConfig.NewTxBuilder() - txBuilder.SetFeeAmount(testdata.NewTestFeeAmount()) - txBuilder.SetGasLimit(testdata.NewTestGasLimit()) - err = txBuilder.SetMsgs(&msg) - require.NoError(t, err) - - // First round: we gather all the signer infos. We use the "set empty - // signature" hack to do that. - sigV2 := signing.SignatureV2{ - PubKey: priv.PubKey(), - Data: &signing.SingleSignatureData{ - SignMode: encCfg.TxConfig.SignModeHandler().DefaultMode(), - Signature: nil, - }, - Sequence: 0, - } - - err = txBuilder.SetSignatures(sigV2) - require.NoError(t, err) - - // Second round: all signer infos are set, so each signer can sign. - signerData := authsigning.SignerData{ - ChainID: "test", - AccountNumber: 0, - Sequence: 0, - } - sigV2, err = tx.SignWithPrivKey( - encCfg.TxConfig.SignModeHandler().DefaultMode(), signerData, - txBuilder, priv, encCfg.TxConfig, 0) - require.NoError(t, err) - err = txBuilder.SetSignatures(sigV2) - require.NoError(t, err) - - // Send the tx to the app - txBytes, err := encCfg.TxConfig.TxEncoder()(txBuilder.GetTx()) - require.NoError(t, err) - res := app.DeliverTx(abci.RequestDeliverTx{Tx: txBytes}) - require.Equal(t, abci.CodeTypeOK, res.Code, "res=%+v", res) -} diff --git a/x/auth/middleware/run_msgs.go b/x/auth/middleware/run_msgs.go index d1ae2960369b..e9073b59a9af 100644 --- a/x/auth/middleware/run_msgs.go +++ b/x/auth/middleware/run_msgs.go @@ -83,6 +83,7 @@ func (txh runMsgsTxHandler) runMsgs(sdkCtx sdk.Context, msgs []sdk.Msg, txBytes Data: make([]*sdk.MsgData, 0, len(msgs)), } + // NOTE: GasWanted is determined by the Gas TxHandler and GasUsed by the GasMeter. for i, msg := range msgs { var ( msgResult *sdk.Result diff --git a/x/auth/middleware/run_msgs_test.go b/x/auth/middleware/run_msgs_test.go new file mode 100644 index 000000000000..909de101d46d --- /dev/null +++ b/x/auth/middleware/run_msgs_test.go @@ -0,0 +1,36 @@ +package middleware_test + +import ( + "github.com/tendermint/tendermint/abci/types" + + cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types" + "github.com/cosmos/cosmos-sdk/testutil/testdata" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/x/auth/middleware" +) + +func (s *MWTestSuite) TestRunMsgs() { + ctx := s.SetupTest(true) // setup + + msr := middleware.NewMsgServiceRouter(s.clientCtx.InterfaceRegistry) + testdata.RegisterMsgServer(msr, testdata.MsgServerImpl{}) + txHandler := middleware.NewRunMsgsTxHandler(msr, nil) + + priv, _, _ := testdata.KeyTestPubAddr() + txBuilder := s.clientCtx.TxConfig.NewTxBuilder() + txBuilder.SetMsgs(&testdata.MsgCreateDog{Dog: &testdata.Dog{Name: "Spot"}}) + privs, accNums, accSeqs := []cryptotypes.PrivKey{priv}, []uint64{0}, []uint64{0} + tx, _, err := s.createTestTx(txBuilder, privs, accNums, accSeqs, ctx.ChainID()) + s.Require().NoError(err) + txBytes, err := s.clientCtx.TxConfig.TxEncoder()(tx) + s.Require().NoError(err) + + res, err := txHandler.DeliverTx(sdk.WrapSDKContext(ctx), tx, types.RequestDeliverTx{Tx: txBytes}) + s.Require().NoError(err) + s.Require().NotEmpty(res.Data) + var txMsgData sdk.TxMsgData + err = s.clientCtx.Codec.Unmarshal(res.Data, &txMsgData) + s.Require().NoError(err) + s.Require().Len(txMsgData.Data, 1) + s.Require().Equal(sdk.MsgTypeURL(&testdata.MsgCreateDog{}), txMsgData.Data[0].MsgType) +} diff --git a/x/auth/ante/sigverify.go b/x/auth/middleware/sigverify.go similarity index 52% rename from x/auth/ante/sigverify.go rename to x/auth/middleware/sigverify.go index 5097478da237..9ae464932684 100644 --- a/x/auth/ante/sigverify.go +++ b/x/auth/middleware/sigverify.go @@ -1,9 +1,9 @@ -package ante +package middleware import ( "bytes" + "context" "encoding/base64" - "encoding/hex" "fmt" "github.com/cosmos/cosmos-sdk/crypto/keys/ed25519" @@ -14,10 +14,11 @@ import ( "github.com/cosmos/cosmos-sdk/crypto/types/multisig" sdk "github.com/cosmos/cosmos-sdk/types" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" + "github.com/cosmos/cosmos-sdk/types/tx" "github.com/cosmos/cosmos-sdk/types/tx/signing" - "github.com/cosmos/cosmos-sdk/x/auth/migrations/legacytx" authsigning "github.com/cosmos/cosmos-sdk/x/auth/signing" "github.com/cosmos/cosmos-sdk/x/auth/types" + abci "github.com/tendermint/tendermint/abci/types" ) var ( @@ -25,44 +26,42 @@ var ( key = make([]byte, secp256k1.PubKeySize) simSecp256k1Pubkey = &secp256k1.PubKey{Key: key} simSecp256k1Sig [64]byte - - _ authsigning.SigVerifiableTx = (*legacytx.StdTx)(nil) // assert StdTx implements SigVerifiableTx ) -func init() { - // This decodes a valid hex string into a sepc256k1Pubkey for use in transaction simulation - bz, _ := hex.DecodeString("035AD6810A47F073553FF30D2FCC7E0D3B1C0B74B61A1AAA2582344037151E143A") - copy(key, bz) - simSecp256k1Pubkey.Key = key -} - // SignatureVerificationGasConsumer is the type of function that is used to both // consume gas when verifying signatures and also to accept or reject different types of pubkeys // This is where apps can define their own PubKey type SignatureVerificationGasConsumer = func(meter sdk.GasMeter, sig signing.SignatureV2, params types.Params) error -// SetPubKeyDecorator sets PubKeys in context for any signer which does not already have pubkey set -// PubKeys must be set in context for all signers before any other sigverify decorators run -// CONTRACT: Tx must implement SigVerifiableTx interface -type SetPubKeyDecorator struct { - ak AccountKeeper +var _ tx.Handler = setPubKeyTxHandler{} + +type setPubKeyTxHandler struct { + ak AccountKeeper + next tx.Handler } -func NewSetPubKeyDecorator(ak AccountKeeper) SetPubKeyDecorator { - return SetPubKeyDecorator{ - ak: ak, +// SetPubKeyMiddleware sets PubKeys in context for any signer which does not already have pubkey set +// PubKeys must be set in context for all signers before any other sigverify middlewares run +// CONTRACT: Tx must implement SigVerifiableTx interface +func SetPubKeyMiddleware(ak AccountKeeper) tx.Middleware { + return func(txh tx.Handler) tx.Handler { + return setPubKeyTxHandler{ + ak: ak, + next: txh, + } } } -func (spkd SetPubKeyDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bool, next sdk.AnteHandler) (sdk.Context, error) { +func (spkm setPubKeyTxHandler) setPubKey(ctx context.Context, tx sdk.Tx, simulate bool) error { + sdkCtx := sdk.UnwrapSDKContext(ctx) sigTx, ok := tx.(authsigning.SigVerifiableTx) if !ok { - return ctx, sdkerrors.Wrap(sdkerrors.ErrTxDecode, "invalid tx type") + return sdkerrors.Wrap(sdkerrors.ErrTxDecode, "invalid tx type") } pubkeys, err := sigTx.GetPubKeys() if err != nil { - return ctx, err + return err } signers := sigTx.GetSigners() @@ -76,13 +75,13 @@ func (spkd SetPubKeyDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate b } // Only make check if simulate=false if !simulate && !bytes.Equal(pk.Address(), signers[i]) { - return ctx, sdkerrors.Wrapf(sdkerrors.ErrInvalidPubKey, + return sdkerrors.Wrapf(sdkerrors.ErrInvalidPubKey, "pubKey does not match signer address %s with signer index: %d", signers[i], i) } - acc, err := GetSignerAcc(ctx, spkd.ak, signers[i]) + acc, err := GetSignerAcc(sdkCtx, spkm.ak, signers[i]) if err != nil { - return ctx, err + return err } // account already has pubkey set,no need to reset if acc.GetPubKey() != nil { @@ -90,9 +89,9 @@ func (spkd SetPubKeyDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate b } err = acc.SetPubKey(pk) if err != nil { - return ctx, sdkerrors.Wrap(sdkerrors.ErrInvalidPubKey, err.Error()) + return sdkerrors.Wrap(sdkerrors.ErrInvalidPubKey, err.Error()) } - spkd.ak.SetAccount(ctx, acc) + spkm.ak.SetAccount(sdkCtx, acc) } // Also emit the following events, so that txs can be indexed by these @@ -101,7 +100,7 @@ func (spkd SetPubKeyDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate b // - concat(address,"/",sequence) (via `tx.acc_seq='cosmos1abc...def/42'`). sigs, err := sigTx.GetSignaturesV2() if err != nil { - return ctx, err + return err } var events sdk.Events @@ -112,7 +111,7 @@ func (spkd SetPubKeyDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate b sigBzs, err := signatureDataToBz(sig.Data) if err != nil { - return ctx, err + return err } for _, sigBz := range sigBzs { events = append(events, sdk.NewEvent(sdk.EventTypeTx, @@ -121,37 +120,206 @@ func (spkd SetPubKeyDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate b } } - ctx.EventManager().EmitEvents(events) + sdkCtx.EventManager().EmitEvents(events) + + return nil +} + +// CheckTx implements tx.Handler.CheckTx. +func (spkm setPubKeyTxHandler) CheckTx(ctx context.Context, tx sdk.Tx, req abci.RequestCheckTx) (abci.ResponseCheckTx, error) { + if err := spkm.setPubKey(ctx, tx, false); err != nil { + return abci.ResponseCheckTx{}, err + } - return next(ctx, tx, simulate) + return spkm.next.CheckTx(ctx, tx, req) } -// Consume parameter-defined amount of gas for each signature according to the passed-in SignatureVerificationGasConsumer function -// before calling the next AnteHandler -// CONTRACT: Pubkeys are set in context for all signers before this decorator runs +// DeliverTx implements tx.Handler.DeliverTx. +func (spkm setPubKeyTxHandler) DeliverTx(ctx context.Context, tx sdk.Tx, req abci.RequestDeliverTx) (abci.ResponseDeliverTx, error) { + if err := spkm.setPubKey(ctx, tx, false); err != nil { + return abci.ResponseDeliverTx{}, err + } + return spkm.next.DeliverTx(ctx, tx, req) +} + +// SimulateTx implements tx.Handler.SimulateTx. +func (spkm setPubKeyTxHandler) SimulateTx(ctx context.Context, sdkTx sdk.Tx, req tx.RequestSimulateTx) (tx.ResponseSimulateTx, error) { + if err := spkm.setPubKey(ctx, sdkTx, true); err != nil { + return tx.ResponseSimulateTx{}, err + } + return spkm.next.SimulateTx(ctx, sdkTx, req) +} + +var _ tx.Handler = validateSigCountTxHandler{} + +type validateSigCountTxHandler struct { + ak AccountKeeper + next tx.Handler +} + +// ValidateSigCountMiddleware takes in Params and returns errors if there are too many signatures in the tx for the given params +// otherwise it calls next middleware +// Use this middleware to set parameterized limit on number of signatures in tx // CONTRACT: Tx must implement SigVerifiableTx interface -type SigGasConsumeDecorator struct { +func ValidateSigCountMiddleware(ak AccountKeeper) tx.Middleware { + return func(txh tx.Handler) tx.Handler { + return validateSigCountTxHandler{ + ak: ak, + next: txh, + } + } +} + +func (vscd validateSigCountTxHandler) checkSigCount(ctx context.Context, tx sdk.Tx) error { + sdkCtx := sdk.UnwrapSDKContext(ctx) + + sigTx, ok := tx.(authsigning.SigVerifiableTx) + if !ok { + return sdkerrors.Wrap(sdkerrors.ErrTxDecode, "Tx must be a sigTx") + } + + params := vscd.ak.GetParams(sdkCtx) + pubKeys, err := sigTx.GetPubKeys() + if err != nil { + return err + } + + sigCount := 0 + for _, pk := range pubKeys { + sigCount += CountSubKeys(pk) + if uint64(sigCount) > params.TxSigLimit { + return sdkerrors.Wrapf(sdkerrors.ErrTooManySignatures, + "signatures: %d, limit: %d", sigCount, params.TxSigLimit) + } + } + return nil +} + +// CheckTx implements tx.Handler.CheckTx. +func (vscd validateSigCountTxHandler) CheckTx(ctx context.Context, tx sdk.Tx, req abci.RequestCheckTx) (abci.ResponseCheckTx, error) { + if err := vscd.checkSigCount(ctx, tx); err != nil { + return abci.ResponseCheckTx{}, err + } + + return vscd.next.CheckTx(ctx, tx, req) +} + +// DeliverTx implements tx.Handler.DeliverTx. +func (vscd validateSigCountTxHandler) SimulateTx(ctx context.Context, sdkTx sdk.Tx, req tx.RequestSimulateTx) (tx.ResponseSimulateTx, error) { + if err := vscd.checkSigCount(ctx, sdkTx); err != nil { + return tx.ResponseSimulateTx{}, err + } + + return vscd.next.SimulateTx(ctx, sdkTx, req) +} + +// SimulateTx implements tx.Handler.SimulateTx. +func (vscd validateSigCountTxHandler) DeliverTx(ctx context.Context, tx sdk.Tx, req abci.RequestDeliverTx) (abci.ResponseDeliverTx, error) { + if err := vscd.checkSigCount(ctx, tx); err != nil { + return abci.ResponseDeliverTx{}, err + } + + return vscd.next.DeliverTx(ctx, tx, req) +} + +// DefaultSigVerificationGasConsumer is the default implementation of SignatureVerificationGasConsumer. It consumes gas +// for signature verification based upon the public key type. The cost is fetched from the given params and is matched +// by the concrete type. +func DefaultSigVerificationGasConsumer( + meter sdk.GasMeter, sig signing.SignatureV2, params types.Params, +) error { + pubkey := sig.PubKey + switch pubkey := pubkey.(type) { + case *ed25519.PubKey: + meter.ConsumeGas(params.SigVerifyCostED25519, "ante verify: ed25519") + return sdkerrors.Wrap(sdkerrors.ErrInvalidPubKey, "ED25519 public keys are unsupported") + + case *secp256k1.PubKey: + meter.ConsumeGas(params.SigVerifyCostSecp256k1, "ante verify: secp256k1") + return nil + + case *secp256r1.PubKey: + meter.ConsumeGas(params.SigVerifyCostSecp256r1(), "ante verify: secp256r1") + return nil + + case multisig.PubKey: + multisignature, ok := sig.Data.(*signing.MultiSignatureData) + if !ok { + return fmt.Errorf("expected %T, got, %T", &signing.MultiSignatureData{}, sig.Data) + } + err := ConsumeMultisignatureVerificationGas(meter, multisignature, pubkey, params, sig.Sequence) + if err != nil { + return err + } + return nil + + default: + return sdkerrors.Wrapf(sdkerrors.ErrInvalidPubKey, "unrecognized public key type: %T", pubkey) + } +} + +// ConsumeMultisignatureVerificationGas consumes gas from a GasMeter for verifying a multisig pubkey signature +func ConsumeMultisignatureVerificationGas( + meter sdk.GasMeter, sig *signing.MultiSignatureData, pubkey multisig.PubKey, + params types.Params, accSeq uint64, +) error { + + size := sig.BitArray.Count() + sigIndex := 0 + + for i := 0; i < size; i++ { + if !sig.BitArray.GetIndex(i) { + continue + } + sigV2 := signing.SignatureV2{ + PubKey: pubkey.GetPubKeys()[i], + Data: sig.Signatures[sigIndex], + Sequence: accSeq, + } + err := DefaultSigVerificationGasConsumer(meter, sigV2, params) + if err != nil { + return err + } + sigIndex++ + } + + return nil +} + +var _ tx.Handler = sigGasConsumeTxHandler{} + +type sigGasConsumeTxHandler struct { ak AccountKeeper sigGasConsumer SignatureVerificationGasConsumer + next tx.Handler } -func NewSigGasConsumeDecorator(ak AccountKeeper, sigGasConsumer SignatureVerificationGasConsumer) SigGasConsumeDecorator { - return SigGasConsumeDecorator{ - ak: ak, - sigGasConsumer: sigGasConsumer, +// SigGasConsumeMiddleware consumes parameter-defined amount of gas for each signature according to the passed-in SignatureVerificationGasConsumer function +// before calling the next middleware +// CONTRACT: Pubkeys are set in context for all signers before this middleware runs +// CONTRACT: Tx must implement SigVerifiableTx interface +func SigGasConsumeMiddleware(ak AccountKeeper, sigGasConsumer SignatureVerificationGasConsumer) tx.Middleware { + return func(h tx.Handler) tx.Handler { + return sigGasConsumeTxHandler{ + ak: ak, + sigGasConsumer: sigGasConsumer, + next: h, + } } } -func (sgcd SigGasConsumeDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bool, next sdk.AnteHandler) (newCtx sdk.Context, err error) { +func (sgcm sigGasConsumeTxHandler) sigGasConsume(ctx context.Context, tx sdk.Tx, simulate bool) error { + sdkCtx := sdk.UnwrapSDKContext(ctx) + sigTx, ok := tx.(authsigning.SigVerifiableTx) if !ok { - return ctx, sdkerrors.Wrap(sdkerrors.ErrTxDecode, "invalid transaction type") + return sdkerrors.Wrap(sdkerrors.ErrTxDecode, "invalid transaction type") } - params := sgcd.ak.GetParams(ctx) + params := sgcm.ak.GetParams(sdkCtx) sigs, err := sigTx.GetSignaturesV2() if err != nil { - return ctx, err + return err } // stdSigs contains the sequence number, account number, and signatures. @@ -159,9 +327,9 @@ func (sgcd SigGasConsumeDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simula signerAddrs := sigTx.GetSigners() for i, sig := range sigs { - signerAcc, err := GetSignerAcc(ctx, sgcd.ak, signerAddrs[i]) + signerAcc, err := GetSignerAcc(sdkCtx, sgcm.ak, signerAddrs[i]) if err != nil { - return ctx, err + return err } pubKey := signerAcc.GetPubKey() @@ -181,29 +349,62 @@ func (sgcd SigGasConsumeDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simula Sequence: sig.Sequence, } - err = sgcd.sigGasConsumer(ctx.GasMeter(), sig, params) + err = sgcm.sigGasConsumer(sdkCtx.GasMeter(), sig, params) if err != nil { - return ctx, err + return err } } - return next(ctx, tx, simulate) + return nil } -// Verify all signatures for a tx and return an error if any are invalid. Note, -// the SigVerificationDecorator decorator will not get executed on ReCheck. -// -// CONTRACT: Pubkeys are set in context for all signers before this decorator runs -// CONTRACT: Tx must implement SigVerifiableTx interface -type SigVerificationDecorator struct { +// CheckTx implements tx.Handler.CheckTx. +func (sgcm sigGasConsumeTxHandler) CheckTx(ctx context.Context, tx sdk.Tx, req abci.RequestCheckTx) (abci.ResponseCheckTx, error) { + if err := sgcm.sigGasConsume(ctx, tx, false); err != nil { + return abci.ResponseCheckTx{}, err + } + + return sgcm.next.CheckTx(ctx, tx, req) +} + +// DeliverTx implements tx.Handler.DeliverTx. +func (sgcm sigGasConsumeTxHandler) DeliverTx(ctx context.Context, tx sdk.Tx, req abci.RequestDeliverTx) (abci.ResponseDeliverTx, error) { + if err := sgcm.sigGasConsume(ctx, tx, false); err != nil { + return abci.ResponseDeliverTx{}, err + } + + return sgcm.next.DeliverTx(ctx, tx, req) +} + +// SimulateTx implements tx.Handler.SimulateTx. +func (sgcm sigGasConsumeTxHandler) SimulateTx(ctx context.Context, sdkTx sdk.Tx, req tx.RequestSimulateTx) (tx.ResponseSimulateTx, error) { + if err := sgcm.sigGasConsume(ctx, sdkTx, true); err != nil { + return tx.ResponseSimulateTx{}, err + } + + return sgcm.next.SimulateTx(ctx, sdkTx, req) +} + +var _ tx.Handler = sigVerificationTxHandler{} + +type sigVerificationTxHandler struct { ak AccountKeeper signModeHandler authsigning.SignModeHandler + next tx.Handler } -func NewSigVerificationDecorator(ak AccountKeeper, signModeHandler authsigning.SignModeHandler) SigVerificationDecorator { - return SigVerificationDecorator{ - ak: ak, - signModeHandler: signModeHandler, +// SigVerificationMiddleware verifies all signatures for a tx and return an error if any are invalid. Note, +// the sigVerificationTxHandler middleware will not get executed on ReCheck. +// +// CONTRACT: Pubkeys are set in context for all signers before this middleware runs +// CONTRACT: Tx must implement SigVerifiableTx interface +func SigVerificationMiddleware(ak AccountKeeper, signModeHandler authsigning.SignModeHandler) tx.Middleware { + return func(h tx.Handler) tx.Handler { + return sigVerificationTxHandler{ + ak: ak, + signModeHandler: signModeHandler, + next: h, + } } } @@ -211,7 +412,7 @@ func NewSigVerificationDecorator(ak AccountKeeper, signModeHandler authsigning.S // signers are using SIGN_MODE_LEGACY_AMINO_JSON. If this is the case // then the corresponding SignatureV2 struct will not have account sequence // explicitly set, and we should skip the explicit verification of sig.Sequence -// in the SigVerificationDecorator's AnteHandler function. +// in the SigVerificationMiddleware's middleware function. func OnlyLegacyAminoSigners(sigData signing.SignatureData) bool { switch v := sigData.(type) { case *signing.SingleSignatureData: @@ -228,53 +429,54 @@ func OnlyLegacyAminoSigners(sigData signing.SignatureData) bool { } } -func (svd SigVerificationDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bool, next sdk.AnteHandler) (newCtx sdk.Context, err error) { +func (svm sigVerificationTxHandler) sigVerify(ctx context.Context, tx sdk.Tx, isReCheckTx, simulate bool) error { + sdkCtx := sdk.UnwrapSDKContext(ctx) // no need to verify signatures on recheck tx - if ctx.IsReCheckTx() { - return next(ctx, tx, simulate) + if isReCheckTx { + return nil } sigTx, ok := tx.(authsigning.SigVerifiableTx) if !ok { - return ctx, sdkerrors.Wrap(sdkerrors.ErrTxDecode, "invalid transaction type") + return sdkerrors.Wrap(sdkerrors.ErrTxDecode, "invalid transaction type") } // stdSigs contains the sequence number, account number, and signatures. // When simulating, this would just be a 0-length slice. sigs, err := sigTx.GetSignaturesV2() if err != nil { - return ctx, err + return err } signerAddrs := sigTx.GetSigners() // check that signer length and signature length are the same if len(sigs) != len(signerAddrs) { - return ctx, sdkerrors.Wrapf(sdkerrors.ErrUnauthorized, "invalid number of signer; expected: %d, got %d", len(signerAddrs), len(sigs)) + return sdkerrors.Wrapf(sdkerrors.ErrUnauthorized, "invalid number of signer; expected: %d, got %d", len(signerAddrs), len(sigs)) } for i, sig := range sigs { - acc, err := GetSignerAcc(ctx, svd.ak, signerAddrs[i]) + acc, err := GetSignerAcc(sdkCtx, svm.ak, signerAddrs[i]) if err != nil { - return ctx, err + return err } // retrieve pubkey pubKey := acc.GetPubKey() if !simulate && pubKey == nil { - return ctx, sdkerrors.Wrap(sdkerrors.ErrInvalidPubKey, "pubkey on account is not set") + return sdkerrors.Wrap(sdkerrors.ErrInvalidPubKey, "pubkey on account is not set") } // Check account sequence number. if sig.Sequence != acc.GetSequence() { - return ctx, sdkerrors.Wrapf( + return sdkerrors.Wrapf( sdkerrors.ErrWrongSequence, "account sequence mismatch, expected %d, got %d", acc.GetSequence(), sig.Sequence, ) } // retrieve signer data - genesis := ctx.BlockHeight() == 0 - chainID := ctx.ChainID() + genesis := sdkCtx.BlockHeight() == 0 + chainID := sdkCtx.ChainID() var accNum uint64 if !genesis { accNum = acc.GetAccountNumber() @@ -286,7 +488,7 @@ func (svd SigVerificationDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simul } if !simulate { - err := authsigning.VerifySignature(pubKey, signerData, sig.Data, svd.signModeHandler, tx) + err := authsigning.VerifySignature(pubKey, signerData, sig.Data, svm.signModeHandler, tx) if err != nil { var errMsg string if OnlyLegacyAminoSigners(sig.Data) { @@ -296,153 +498,112 @@ func (svd SigVerificationDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simul } else { errMsg = fmt.Sprintf("signature verification failed; please verify account number (%d) and chain-id (%s)", accNum, chainID) } - return ctx, sdkerrors.Wrap(sdkerrors.ErrUnauthorized, errMsg) + return sdkerrors.Wrap(sdkerrors.ErrUnauthorized, errMsg) } } } - return next(ctx, tx, simulate) + return nil +} + +// CheckTx implements tx.Handler.CheckTx. +func (svd sigVerificationTxHandler) CheckTx(ctx context.Context, tx sdk.Tx, req abci.RequestCheckTx) (abci.ResponseCheckTx, error) { + if err := svd.sigVerify(ctx, tx, req.Type == abci.CheckTxType_Recheck, false); err != nil { + return abci.ResponseCheckTx{}, err + } + + return svd.next.CheckTx(ctx, tx, req) +} + +// DeliverTx implements tx.Handler.DeliverTx. +func (svd sigVerificationTxHandler) DeliverTx(ctx context.Context, tx sdk.Tx, req abci.RequestDeliverTx) (abci.ResponseDeliverTx, error) { + if err := svd.sigVerify(ctx, tx, false, false); err != nil { + return abci.ResponseDeliverTx{}, err + } + + return svd.next.DeliverTx(ctx, tx, req) +} + +// SimulateTx implements tx.Handler.SimulateTx. +func (svd sigVerificationTxHandler) SimulateTx(ctx context.Context, sdkTx sdk.Tx, req tx.RequestSimulateTx) (tx.ResponseSimulateTx, error) { + if err := svd.sigVerify(ctx, sdkTx, false, true); err != nil { + return tx.ResponseSimulateTx{}, err + } + + return svd.next.SimulateTx(ctx, sdkTx, req) +} + +var _ tx.Handler = incrementSequenceTxHandler{} + +type incrementSequenceTxHandler struct { + ak AccountKeeper + next tx.Handler } -// IncrementSequenceDecorator handles incrementing sequences of all signers. -// Use the IncrementSequenceDecorator decorator to prevent replay attacks. Note, -// there is no need to execute IncrementSequenceDecorator on RecheckTX since +// IncrementSequenceMiddleware handles incrementing sequences of all signers. +// Use the incrementSequenceTxHandler middleware to prevent replay attacks. Note, +// there is no need to execute incrementSequenceTxHandler on RecheckTX since // CheckTx would already bump the sequence number. // // NOTE: Since CheckTx and DeliverTx state are managed separately, subsequent and // sequential txs orginating from the same account cannot be handled correctly in // a reliable way unless sequence numbers are managed and tracked manually by a // client. It is recommended to instead use multiple messages in a tx. -type IncrementSequenceDecorator struct { - ak AccountKeeper -} - -func NewIncrementSequenceDecorator(ak AccountKeeper) IncrementSequenceDecorator { - return IncrementSequenceDecorator{ - ak: ak, +func IncrementSequenceMiddleware(ak AccountKeeper) tx.Middleware { + return func(h tx.Handler) tx.Handler { + return incrementSequenceTxHandler{ + ak: ak, + next: h, + } } } -func (isd IncrementSequenceDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bool, next sdk.AnteHandler) (sdk.Context, error) { +func (isd incrementSequenceTxHandler) incrementSeq(ctx context.Context, tx sdk.Tx) error { + sdkCtx := sdk.UnwrapSDKContext(ctx) sigTx, ok := tx.(authsigning.SigVerifiableTx) if !ok { - return ctx, sdkerrors.Wrap(sdkerrors.ErrTxDecode, "invalid transaction type") + return sdkerrors.Wrap(sdkerrors.ErrTxDecode, "invalid transaction type") } // increment sequence of all signers for _, addr := range sigTx.GetSigners() { - acc := isd.ak.GetAccount(ctx, addr) + acc := isd.ak.GetAccount(sdkCtx, addr) if err := acc.SetSequence(acc.GetSequence() + 1); err != nil { panic(err) } - isd.ak.SetAccount(ctx, acc) + isd.ak.SetAccount(sdkCtx, acc) } - return next(ctx, tx, simulate) -} - -// ValidateSigCountDecorator takes in Params and returns errors if there are too many signatures in the tx for the given params -// otherwise it calls next AnteHandler -// Use this decorator to set parameterized limit on number of signatures in tx -// CONTRACT: Tx must implement SigVerifiableTx interface -type ValidateSigCountDecorator struct { - ak AccountKeeper -} - -func NewValidateSigCountDecorator(ak AccountKeeper) ValidateSigCountDecorator { - return ValidateSigCountDecorator{ - ak: ak, - } + return nil } -func (vscd ValidateSigCountDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bool, next sdk.AnteHandler) (sdk.Context, error) { - sigTx, ok := tx.(authsigning.SigVerifiableTx) - if !ok { - return ctx, sdkerrors.Wrap(sdkerrors.ErrTxDecode, "Tx must be a sigTx") +// CheckTx implements tx.Handler.CheckTx. +func (isd incrementSequenceTxHandler) CheckTx(ctx context.Context, tx sdk.Tx, req abci.RequestCheckTx) (abci.ResponseCheckTx, error) { + if err := isd.incrementSeq(ctx, tx); err != nil { + return abci.ResponseCheckTx{}, err } - params := vscd.ak.GetParams(ctx) - pubKeys, err := sigTx.GetPubKeys() - if err != nil { - return ctx, err - } - - sigCount := 0 - for _, pk := range pubKeys { - sigCount += CountSubKeys(pk) - if uint64(sigCount) > params.TxSigLimit { - return ctx, sdkerrors.Wrapf(sdkerrors.ErrTooManySignatures, - "signatures: %d, limit: %d", sigCount, params.TxSigLimit) - } - } - - return next(ctx, tx, simulate) + return isd.next.CheckTx(ctx, tx, req) } -// DefaultSigVerificationGasConsumer is the default implementation of SignatureVerificationGasConsumer. It consumes gas -// for signature verification based upon the public key type. The cost is fetched from the given params and is matched -// by the concrete type. -func DefaultSigVerificationGasConsumer( - meter sdk.GasMeter, sig signing.SignatureV2, params types.Params, -) error { - pubkey := sig.PubKey - switch pubkey := pubkey.(type) { - case *ed25519.PubKey: - meter.ConsumeGas(params.SigVerifyCostED25519, "ante verify: ed25519") - return sdkerrors.Wrap(sdkerrors.ErrInvalidPubKey, "ED25519 public keys are unsupported") - - case *secp256k1.PubKey: - meter.ConsumeGas(params.SigVerifyCostSecp256k1, "ante verify: secp256k1") - return nil - - case *secp256r1.PubKey: - meter.ConsumeGas(params.SigVerifyCostSecp256r1(), "ante verify: secp256r1") - return nil - - case multisig.PubKey: - multisignature, ok := sig.Data.(*signing.MultiSignatureData) - if !ok { - return fmt.Errorf("expected %T, got, %T", &signing.MultiSignatureData{}, sig.Data) - } - err := ConsumeMultisignatureVerificationGas(meter, multisignature, pubkey, params, sig.Sequence) - if err != nil { - return err - } - return nil - - default: - return sdkerrors.Wrapf(sdkerrors.ErrInvalidPubKey, "unrecognized public key type: %T", pubkey) +// DeliverTx implements tx.Handler.DeliverTx. +func (isd incrementSequenceTxHandler) DeliverTx(ctx context.Context, tx sdk.Tx, req abci.RequestDeliverTx) (abci.ResponseDeliverTx, error) { + if err := isd.incrementSeq(ctx, tx); err != nil { + return abci.ResponseDeliverTx{}, err } -} -// ConsumeMultisignatureVerificationGas consumes gas from a GasMeter for verifying a multisig pubkey signature -func ConsumeMultisignatureVerificationGas( - meter sdk.GasMeter, sig *signing.MultiSignatureData, pubkey multisig.PubKey, - params types.Params, accSeq uint64, -) error { - - size := sig.BitArray.Count() - sigIndex := 0 + return isd.next.DeliverTx(ctx, tx, req) +} - for i := 0; i < size; i++ { - if !sig.BitArray.GetIndex(i) { - continue - } - sigV2 := signing.SignatureV2{ - PubKey: pubkey.GetPubKeys()[i], - Data: sig.Signatures[sigIndex], - Sequence: accSeq, - } - err := DefaultSigVerificationGasConsumer(meter, sigV2, params) - if err != nil { - return err - } - sigIndex++ +// SimulateTx implements tx.Handler.SimulateTx. +func (isd incrementSequenceTxHandler) SimulateTx(ctx context.Context, sdkTx sdk.Tx, req tx.RequestSimulateTx) (tx.ResponseSimulateTx, error) { + if err := isd.incrementSeq(ctx, sdkTx); err != nil { + return tx.ResponseSimulateTx{}, err } - return nil + return isd.next.SimulateTx(ctx, sdkTx, req) } // GetSignerAcc returns an account for a given address that is expected to sign diff --git a/x/auth/ante/sigverify_benchmark_test.go b/x/auth/middleware/sigverify_benchmark_test.go similarity index 89% rename from x/auth/ante/sigverify_benchmark_test.go rename to x/auth/middleware/sigverify_benchmark_test.go index 56e596fa6b55..dc635985170b 100644 --- a/x/auth/ante/sigverify_benchmark_test.go +++ b/x/auth/middleware/sigverify_benchmark_test.go @@ -1,4 +1,4 @@ -package ante_test +package middleware_test import ( "testing" @@ -10,7 +10,7 @@ import ( "github.com/cosmos/cosmos-sdk/crypto/keys/secp256r1" ) -// This benchmark is used to asses the ante.Secp256k1ToR1GasFactor value +// This benchmark is used to asses the middleware.Secp256k1ToR1GasFactor value func BenchmarkSig(b *testing.B) { require := require.New(b) msg := tmcrypto.CRandBytes(1000) diff --git a/x/auth/ante/sigverify_test.go b/x/auth/middleware/sigverify_test.go similarity index 53% rename from x/auth/ante/sigverify_test.go rename to x/auth/middleware/sigverify_test.go index 074f4c33afc1..59c46938f619 100644 --- a/x/auth/ante/sigverify_test.go +++ b/x/auth/middleware/sigverify_test.go @@ -1,10 +1,10 @@ -package ante_test +package middleware_test import ( "fmt" "github.com/cosmos/cosmos-sdk/client" - "github.com/cosmos/cosmos-sdk/codec" + "github.com/cosmos/cosmos-sdk/codec/legacy" "github.com/cosmos/cosmos-sdk/crypto/keys/ed25519" kmultisig "github.com/cosmos/cosmos-sdk/crypto/keys/multisig" "github.com/cosmos/cosmos-sdk/crypto/keys/secp256k1" @@ -14,16 +14,22 @@ import ( "github.com/cosmos/cosmos-sdk/simapp" "github.com/cosmos/cosmos-sdk/testutil/testdata" sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/types/tx" "github.com/cosmos/cosmos-sdk/types/tx/signing" - "github.com/cosmos/cosmos-sdk/x/auth/ante" + "github.com/cosmos/cosmos-sdk/x/auth/middleware" "github.com/cosmos/cosmos-sdk/x/auth/migrations/legacytx" "github.com/cosmos/cosmos-sdk/x/auth/types" + abci "github.com/tendermint/tendermint/abci/types" ) -func (suite *AnteTestSuite) TestSetPubKey() { - suite.SetupTest(true) // setup - require := suite.Require() - suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder() +func (s *MWTestSuite) TestSetPubKey() { + ctx := s.SetupTest(true) // setup + txBuilder := s.clientCtx.TxConfig.NewTxBuilder() + require := s.Require() + txHandler := middleware.ComposeMiddlewares( + noopTxHandler{}, + middleware.SetPubKeyMiddleware(s.app.AccountKeeper), + ) // keys and addresses priv1, pub1, addr1 := testdata.KeyTestPubAddr() @@ -36,35 +42,32 @@ func (suite *AnteTestSuite) TestSetPubKey() { msgs := make([]sdk.Msg, len(addrs)) // set accounts and create msg for each address for i, addr := range addrs { - acc := suite.app.AccountKeeper.NewAccountWithAddress(suite.ctx, addr) + acc := s.app.AccountKeeper.NewAccountWithAddress(ctx, addr) require.NoError(acc.SetAccountNumber(uint64(i))) - suite.app.AccountKeeper.SetAccount(suite.ctx, acc) + s.app.AccountKeeper.SetAccount(ctx, acc) msgs[i] = testdata.NewTestMsg(addr) } - require.NoError(suite.txBuilder.SetMsgs(msgs...)) - suite.txBuilder.SetFeeAmount(testdata.NewTestFeeAmount()) - suite.txBuilder.SetGasLimit(testdata.NewTestGasLimit()) + require.NoError(txBuilder.SetMsgs(msgs...)) + txBuilder.SetFeeAmount(testdata.NewTestFeeAmount()) + txBuilder.SetGasLimit(testdata.NewTestGasLimit()) privs, accNums, accSeqs := []cryptotypes.PrivKey{priv1, priv2, priv3}, []uint64{0, 1, 2}, []uint64{0, 0, 0} - tx, err := suite.CreateTestTx(privs, accNums, accSeqs, suite.ctx.ChainID()) + testTx, _, err := s.createTestTx(txBuilder, privs, accNums, accSeqs, ctx.ChainID()) require.NoError(err) - spkd := ante.NewSetPubKeyDecorator(suite.app.AccountKeeper) - antehandler := sdk.ChainAnteDecorators(spkd) - - ctx, err := antehandler(suite.ctx, tx, false) + _, err = txHandler.DeliverTx(sdk.WrapSDKContext(ctx), testTx, abci.RequestDeliverTx{}) require.NoError(err) - // Require that all accounts have pubkey set after Decorator runs + // Require that all accounts have pubkey set after middleware runs for i, addr := range addrs { - pk, err := suite.app.AccountKeeper.GetPubKey(ctx, addr) + pk, err := s.app.AccountKeeper.GetPubKey(ctx, addr) require.NoError(err, "Error on retrieving pubkey from account") require.True(pubs[i].Equals(pk), "Wrong Pubkey retrieved from AccountKeeper, idx=%d\nexpected=%s\n got=%s", i, pubs[i], pk) } } -func (suite *AnteTestSuite) TestConsumeSignatureVerificationGas() { +func (s *MWTestSuite) TestConsumeSignatureVerificationGas() { params := types.DefaultParams() msg := []byte{1, 2, 3, 4} cdc := simapp.MakeTestEncodingConfig().Amino @@ -78,9 +81,9 @@ func (suite *AnteTestSuite) TestConsumeSignatureVerificationGas() { for i := 0; i < len(pkSet1); i++ { stdSig := legacytx.StdSignature{PubKey: pkSet1[i], Signature: sigSet1[i]} sigV2, err := legacytx.StdSignatureToSignatureV2(cdc, stdSig) - suite.Require().NoError(err) + s.Require().NoError(err) err = multisig.AddSignatureV2(multisignature1, sigV2, pkSet1) - suite.Require().NoError(err) + s.Require().NoError(err) } type args struct { @@ -107,23 +110,30 @@ func (suite *AnteTestSuite) TestConsumeSignatureVerificationGas() { Data: tt.args.sig, Sequence: 0, // Arbitrary account sequence } - err := ante.DefaultSigVerificationGasConsumer(tt.args.meter, sigV2, tt.args.params) + err := middleware.DefaultSigVerificationGasConsumer(tt.args.meter, sigV2, tt.args.params) if tt.shouldErr { - suite.Require().NotNil(err) + s.Require().NotNil(err) } else { - suite.Require().Nil(err) - suite.Require().Equal(tt.gasConsumed, tt.args.meter.GasConsumed(), fmt.Sprintf("%d != %d", tt.gasConsumed, tt.args.meter.GasConsumed())) + s.Require().Nil(err) + s.Require().Equal(tt.gasConsumed, tt.args.meter.GasConsumed(), fmt.Sprintf("%d != %d", tt.gasConsumed, tt.args.meter.GasConsumed())) } } } -func (suite *AnteTestSuite) TestSigVerification() { - suite.SetupTest(true) // setup - suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder() +func (s *MWTestSuite) TestSigVerification() { + ctx := s.SetupTest(true) // setup // make block height non-zero to ensure account numbers part of signBytes - suite.ctx = suite.ctx.WithBlockHeight(1) + ctx = ctx.WithBlockHeight(1) + txHandler := middleware.ComposeMiddlewares( + noopTxHandler{}, + middleware.SetPubKeyMiddleware(s.app.AccountKeeper), + middleware.SigVerificationMiddleware( + s.app.AccountKeeper, + s.clientCtx.TxConfig.SignModeHandler(), + ), + ) // keys and addresses priv1, _, addr1 := testdata.KeyTestPubAddr() @@ -135,19 +145,15 @@ func (suite *AnteTestSuite) TestSigVerification() { msgs := make([]sdk.Msg, len(addrs)) // set accounts and create msg for each address for i, addr := range addrs { - acc := suite.app.AccountKeeper.NewAccountWithAddress(suite.ctx, addr) - suite.Require().NoError(acc.SetAccountNumber(uint64(i))) - suite.app.AccountKeeper.SetAccount(suite.ctx, acc) + acc := s.app.AccountKeeper.NewAccountWithAddress(ctx, addr) + s.Require().NoError(acc.SetAccountNumber(uint64(i))) + s.app.AccountKeeper.SetAccount(ctx, acc) msgs[i] = testdata.NewTestMsg(addr) } feeAmount := testdata.NewTestFeeAmount() gasLimit := testdata.NewTestGasLimit() - spkd := ante.NewSetPubKeyDecorator(suite.app.AccountKeeper) - svd := ante.NewSigVerificationDecorator(suite.app.AccountKeeper, suite.clientCtx.TxConfig.SignModeHandler()) - antehandler := sdk.ChainAnteDecorators(spkd, svd) - type testCase struct { name string privs []cryptotypes.PrivKey @@ -166,21 +172,25 @@ func (suite *AnteTestSuite) TestSigVerification() { {"no err on recheck", []cryptotypes.PrivKey{}, []uint64{}, []uint64{}, true, false}, } for i, tc := range testCases { - suite.ctx = suite.ctx.WithIsReCheckTx(tc.recheck) - suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder() // Create new txBuilder for each test + ctx = ctx.WithIsReCheckTx(tc.recheck) + txBuilder := s.clientCtx.TxConfig.NewTxBuilder() // Create new txBuilder for each test - suite.Require().NoError(suite.txBuilder.SetMsgs(msgs...)) - suite.txBuilder.SetFeeAmount(feeAmount) - suite.txBuilder.SetGasLimit(gasLimit) + s.Require().NoError(txBuilder.SetMsgs(msgs...)) + txBuilder.SetFeeAmount(feeAmount) + txBuilder.SetGasLimit(gasLimit) - tx, err := suite.CreateTestTx(tc.privs, tc.accNums, tc.accSeqs, suite.ctx.ChainID()) - suite.Require().NoError(err) + testTx, _, err := s.createTestTx(txBuilder, tc.privs, tc.accNums, tc.accSeqs, ctx.ChainID()) + s.Require().NoError(err) - _, err = antehandler(suite.ctx, tx, false) + if tc.recheck { + _, err = txHandler.CheckTx(sdk.WrapSDKContext(ctx), testTx, abci.RequestCheckTx{Type: abci.CheckTxType_Recheck}) + } else { + _, err = txHandler.CheckTx(sdk.WrapSDKContext(ctx), testTx, abci.RequestCheckTx{}) + } if tc.shouldErr { - suite.Require().NotNil(err, "TestCase %d: %s did not error as expected", i, tc.name) + s.Require().NotNil(err, "TestCase %d: %s did not error as expected", i, tc.name) } else { - suite.Require().Nil(err, "TestCase %d: %s errored unexpectedly. Err: %v", i, tc.name, err) + s.Require().Nil(err, "TestCase %d: %s errored unexpectedly. Err: %v", i, tc.name, err) } } } @@ -191,35 +201,23 @@ func (suite *AnteTestSuite) TestSigVerification() { // this, since it'll be handled by the test matrix. // In the meantime, we want to make double-sure amino compatibility works. // ref: https://github.com/cosmos/cosmos-sdk/issues/7229 -func (suite *AnteTestSuite) TestSigVerification_ExplicitAmino() { - suite.app, suite.ctx = createTestApp(suite.T(), true) - suite.ctx = suite.ctx.WithBlockHeight(1) +func (s *MWTestSuite) TestSigVerification_ExplicitAmino() { + ctx := s.SetupTest(true) + ctx = ctx.WithBlockHeight(1) // Set up TxConfig. - aminoCdc := codec.NewLegacyAmino() + aminoCdc := legacy.Cdc + aminoCdc.RegisterInterface((*sdk.Msg)(nil), nil) + aminoCdc.RegisterConcrete(&testdata.TestMsg{}, "testdata.TestMsg", nil) + // We're using TestMsg amino encoding in some tests, so register it here. txConfig := legacytx.StdTxConfig{Cdc: aminoCdc} - suite.clientCtx = client.Context{}. + s.clientCtx = client.Context{}. WithTxConfig(txConfig) - anteHandler, err := ante.NewAnteHandler( - ante.HandlerOptions{ - AccountKeeper: suite.app.AccountKeeper, - BankKeeper: suite.app.BankKeeper, - FeegrantKeeper: suite.app.FeeGrantKeeper, - SignModeHandler: txConfig.SignModeHandler(), - SigGasConsumer: ante.DefaultSigVerificationGasConsumer, - }, - ) - - suite.Require().NoError(err) - suite.anteHandler = anteHandler - - suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder() - // make block height non-zero to ensure account numbers part of signBytes - suite.ctx = suite.ctx.WithBlockHeight(1) + ctx = ctx.WithBlockHeight(1) // keys and addresses priv1, _, addr1 := testdata.KeyTestPubAddr() @@ -231,18 +229,23 @@ func (suite *AnteTestSuite) TestSigVerification_ExplicitAmino() { msgs := make([]sdk.Msg, len(addrs)) // set accounts and create msg for each address for i, addr := range addrs { - acc := suite.app.AccountKeeper.NewAccountWithAddress(suite.ctx, addr) - suite.Require().NoError(acc.SetAccountNumber(uint64(i))) - suite.app.AccountKeeper.SetAccount(suite.ctx, acc) + acc := s.app.AccountKeeper.NewAccountWithAddress(ctx, addr) + s.Require().NoError(acc.SetAccountNumber(uint64(i))) + s.app.AccountKeeper.SetAccount(ctx, acc) msgs[i] = testdata.NewTestMsg(addr) } feeAmount := testdata.NewTestFeeAmount() gasLimit := testdata.NewTestGasLimit() - spkd := ante.NewSetPubKeyDecorator(suite.app.AccountKeeper) - svd := ante.NewSigVerificationDecorator(suite.app.AccountKeeper, suite.clientCtx.TxConfig.SignModeHandler()) - antehandler := sdk.ChainAnteDecorators(spkd, svd) + txHandler := middleware.ComposeMiddlewares( + noopTxHandler{}, + middleware.SetPubKeyMiddleware(s.app.AccountKeeper), + middleware.SigVerificationMiddleware( + s.app.AccountKeeper, + s.clientCtx.TxConfig.SignModeHandler(), + ), + ) type testCase struct { name string @@ -252,6 +255,7 @@ func (suite *AnteTestSuite) TestSigVerification_ExplicitAmino() { recheck bool shouldErr bool } + testCases := []testCase{ {"no signers", []cryptotypes.PrivKey{}, []uint64{}, []uint64{}, false, true}, {"not enough signers", []cryptotypes.PrivKey{priv1, priv2}, []uint64{0, 1}, []uint64{0, 0}, false, true}, @@ -261,27 +265,32 @@ func (suite *AnteTestSuite) TestSigVerification_ExplicitAmino() { {"valid tx", []cryptotypes.PrivKey{priv1, priv2, priv3}, []uint64{0, 1, 2}, []uint64{0, 0, 0}, false, false}, {"no err on recheck", []cryptotypes.PrivKey{}, []uint64{}, []uint64{}, true, false}, } + for i, tc := range testCases { - suite.ctx = suite.ctx.WithIsReCheckTx(tc.recheck) - suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder() // Create new txBuilder for each test + ctx = ctx.WithIsReCheckTx(tc.recheck) + txBuilder := s.clientCtx.TxConfig.NewTxBuilder() // Create new txBuilder for each test - suite.Require().NoError(suite.txBuilder.SetMsgs(msgs...)) - suite.txBuilder.SetFeeAmount(feeAmount) - suite.txBuilder.SetGasLimit(gasLimit) + s.Require().NoError(txBuilder.SetMsgs(msgs...)) + txBuilder.SetFeeAmount(feeAmount) + txBuilder.SetGasLimit(gasLimit) - tx, err := suite.CreateTestTx(tc.privs, tc.accNums, tc.accSeqs, suite.ctx.ChainID()) - suite.Require().NoError(err) + testTx, _, err := s.createTestTx(txBuilder, tc.privs, tc.accNums, tc.accSeqs, ctx.ChainID()) + s.Require().NoError(err) - _, err = antehandler(suite.ctx, tx, false) + if tc.recheck { + _, err = txHandler.CheckTx(sdk.WrapSDKContext(ctx), testTx, abci.RequestCheckTx{Type: abci.CheckTxType_Recheck}) + } else { + _, err = txHandler.CheckTx(sdk.WrapSDKContext(ctx), testTx, abci.RequestCheckTx{}) + } if tc.shouldErr { - suite.Require().NotNil(err, "TestCase %d: %s did not error as expected", i, tc.name) + s.Require().NotNil(err, "TestCase %d: %s did not error as expected", i, tc.name) } else { - suite.Require().Nil(err, "TestCase %d: %s errored unexpectedly. Err: %v", i, tc.name, err) + s.Require().Nil(err, "TestCase %d: %s errored unexpectedly. Err: %v", i, tc.name, err) } } } -func (suite *AnteTestSuite) TestSigIntegration() { +func (s *MWTestSuite) TestSigIntegration() { // generate private keys privs := []cryptotypes.PrivKey{ secp256k1.GenPrivKey(), @@ -291,23 +300,23 @@ func (suite *AnteTestSuite) TestSigIntegration() { params := types.DefaultParams() initialSigCost := params.SigVerifyCostSecp256k1 - initialCost, err := suite.runSigDecorators(params, false, privs...) - suite.Require().Nil(err) + initialCost, err := s.runSigMiddlewares(params, false, privs...) + s.Require().Nil(err) params.SigVerifyCostSecp256k1 *= 2 - doubleCost, err := suite.runSigDecorators(params, false, privs...) - suite.Require().Nil(err) + doubleCost, err := s.runSigMiddlewares(params, false, privs...) + s.Require().Nil(err) - suite.Require().Equal(initialSigCost*uint64(len(privs)), doubleCost-initialCost) + s.Require().Equal(initialSigCost*uint64(len(privs)), doubleCost-initialCost) } -func (suite *AnteTestSuite) runSigDecorators(params types.Params, _ bool, privs ...cryptotypes.PrivKey) (sdk.Gas, error) { - suite.SetupTest(true) // setup - suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder() +func (s *MWTestSuite) runSigMiddlewares(params types.Params, _ bool, privs ...cryptotypes.PrivKey) (sdk.Gas, error) { + ctx := s.SetupTest(true) // setup + txBuilder := s.clientCtx.TxConfig.NewTxBuilder() // Make block-height non-zero to include accNum in SignBytes - suite.ctx = suite.ctx.WithBlockHeight(1) - suite.app.AccountKeeper.SetParams(suite.ctx, params) + ctx = ctx.WithBlockHeight(1) + s.app.AccountKeeper.SetParams(ctx, params) msgs := make([]sdk.Msg, len(privs)) accNums := make([]uint64, len(privs)) @@ -315,76 +324,89 @@ func (suite *AnteTestSuite) runSigDecorators(params types.Params, _ bool, privs // set accounts and create msg for each address for i, priv := range privs { addr := sdk.AccAddress(priv.PubKey().Address()) - acc := suite.app.AccountKeeper.NewAccountWithAddress(suite.ctx, addr) - suite.Require().NoError(acc.SetAccountNumber(uint64(i))) - suite.app.AccountKeeper.SetAccount(suite.ctx, acc) + acc := s.app.AccountKeeper.NewAccountWithAddress(ctx, addr) + s.Require().NoError(acc.SetAccountNumber(uint64(i))) + s.app.AccountKeeper.SetAccount(ctx, acc) msgs[i] = testdata.NewTestMsg(addr) accNums[i] = uint64(i) accSeqs[i] = uint64(0) } - suite.Require().NoError(suite.txBuilder.SetMsgs(msgs...)) + s.Require().NoError(txBuilder.SetMsgs(msgs...)) feeAmount := testdata.NewTestFeeAmount() gasLimit := testdata.NewTestGasLimit() - suite.txBuilder.SetFeeAmount(feeAmount) - suite.txBuilder.SetGasLimit(gasLimit) - - tx, err := suite.CreateTestTx(privs, accNums, accSeqs, suite.ctx.ChainID()) - suite.Require().NoError(err) - - spkd := ante.NewSetPubKeyDecorator(suite.app.AccountKeeper) - svgc := ante.NewSigGasConsumeDecorator(suite.app.AccountKeeper, ante.DefaultSigVerificationGasConsumer) - svd := ante.NewSigVerificationDecorator(suite.app.AccountKeeper, suite.clientCtx.TxConfig.SignModeHandler()) - antehandler := sdk.ChainAnteDecorators(spkd, svgc, svd) + txBuilder.SetFeeAmount(feeAmount) + txBuilder.SetGasLimit(gasLimit) + + testTx, _, err := s.createTestTx(txBuilder, privs, accNums, accSeqs, ctx.ChainID()) + s.Require().NoError(err) + + txHandler := middleware.ComposeMiddlewares( + noopTxHandler{}, + middleware.SetPubKeyMiddleware(s.app.AccountKeeper), + middleware.SigGasConsumeMiddleware(s.app.AccountKeeper, middleware.DefaultSigVerificationGasConsumer), + middleware.SigVerificationMiddleware( + s.app.AccountKeeper, + s.clientCtx.TxConfig.SignModeHandler(), + ), + ) - // Determine gas consumption of antehandler with default params - before := suite.ctx.GasMeter().GasConsumed() - ctx, err := antehandler(suite.ctx, tx, false) + // Determine gas consumption of txhandler with default params + before := ctx.GasMeter().GasConsumed() + _, err = txHandler.DeliverTx(sdk.WrapSDKContext(ctx), testTx, abci.RequestDeliverTx{}) after := ctx.GasMeter().GasConsumed() return after - before, err } -func (suite *AnteTestSuite) TestIncrementSequenceDecorator() { - suite.SetupTest(true) // setup - suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder() +func (s *MWTestSuite) TestIncrementSequenceMiddleware() { + ctx := s.SetupTest(true) // setup + txBuilder := s.clientCtx.TxConfig.NewTxBuilder() priv, _, addr := testdata.KeyTestPubAddr() - acc := suite.app.AccountKeeper.NewAccountWithAddress(suite.ctx, addr) - suite.Require().NoError(acc.SetAccountNumber(uint64(50))) - suite.app.AccountKeeper.SetAccount(suite.ctx, acc) + acc := s.app.AccountKeeper.NewAccountWithAddress(ctx, addr) + s.Require().NoError(acc.SetAccountNumber(uint64(50))) + s.app.AccountKeeper.SetAccount(ctx, acc) msgs := []sdk.Msg{testdata.NewTestMsg(addr)} - suite.Require().NoError(suite.txBuilder.SetMsgs(msgs...)) + s.Require().NoError(txBuilder.SetMsgs(msgs...)) privs := []cryptotypes.PrivKey{priv} - accNums := []uint64{suite.app.AccountKeeper.GetAccount(suite.ctx, addr).GetAccountNumber()} - accSeqs := []uint64{suite.app.AccountKeeper.GetAccount(suite.ctx, addr).GetSequence()} + accNums := []uint64{s.app.AccountKeeper.GetAccount(ctx, addr).GetAccountNumber()} + accSeqs := []uint64{s.app.AccountKeeper.GetAccount(ctx, addr).GetSequence()} feeAmount := testdata.NewTestFeeAmount() gasLimit := testdata.NewTestGasLimit() - suite.txBuilder.SetFeeAmount(feeAmount) - suite.txBuilder.SetGasLimit(gasLimit) + txBuilder.SetFeeAmount(feeAmount) + txBuilder.SetGasLimit(gasLimit) - tx, err := suite.CreateTestTx(privs, accNums, accSeqs, suite.ctx.ChainID()) - suite.Require().NoError(err) + testTx, _, err := s.createTestTx(txBuilder, privs, accNums, accSeqs, ctx.ChainID()) + s.Require().NoError(err) - isd := ante.NewIncrementSequenceDecorator(suite.app.AccountKeeper) - antehandler := sdk.ChainAnteDecorators(isd) + txHandler := middleware.ComposeMiddlewares( + noopTxHandler{}, + middleware.IncrementSequenceMiddleware(s.app.AccountKeeper), + ) testCases := []struct { ctx sdk.Context simulate bool expectedSeq uint64 }{ - {suite.ctx.WithIsReCheckTx(true), false, 1}, - {suite.ctx.WithIsCheckTx(true).WithIsReCheckTx(false), false, 2}, - {suite.ctx.WithIsReCheckTx(true), false, 3}, - {suite.ctx.WithIsReCheckTx(true), false, 4}, - {suite.ctx.WithIsReCheckTx(true), true, 5}, + {ctx.WithIsReCheckTx(true), false, 1}, + {ctx.WithIsCheckTx(true).WithIsReCheckTx(false), false, 2}, + {ctx.WithIsReCheckTx(true), false, 3}, + {ctx.WithIsReCheckTx(true), false, 4}, + {ctx.WithIsReCheckTx(true), true, 5}, } for i, tc := range testCases { - _, err := antehandler(tc.ctx, tx, tc.simulate) - suite.Require().NoError(err, "unexpected error; tc #%d, %v", i, tc) - suite.Require().Equal(tc.expectedSeq, suite.app.AccountKeeper.GetAccount(suite.ctx, addr).GetSequence()) + var err error + if tc.simulate { + _, err = txHandler.SimulateTx(sdk.WrapSDKContext(tc.ctx), testTx, tx.RequestSimulateTx{}) + } else { + _, err = txHandler.DeliverTx(sdk.WrapSDKContext(tc.ctx), testTx, abci.RequestDeliverTx{}) + } + + s.Require().NoError(err, "unexpected error; tc #%d, %v", i, tc) + s.Require().Equal(tc.expectedSeq, s.app.AccountKeeper.GetAccount(ctx, addr).GetSequence()) } } diff --git a/x/auth/middleware/testutil_test.go b/x/auth/middleware/testutil_test.go index a2ba1c3cf164..11cfaf72c37e 100644 --- a/x/auth/middleware/testutil_test.go +++ b/x/auth/middleware/testutil_test.go @@ -1,18 +1,24 @@ package middleware_test import ( + "errors" + "fmt" "testing" "github.com/stretchr/testify/suite" + "github.com/tendermint/tendermint/abci/types" tmproto "github.com/tendermint/tendermint/proto/tendermint/types" "github.com/cosmos/cosmos-sdk/client" "github.com/cosmos/cosmos-sdk/client/tx" + "github.com/cosmos/cosmos-sdk/codec" cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types" "github.com/cosmos/cosmos-sdk/simapp" "github.com/cosmos/cosmos-sdk/testutil/testdata" sdk "github.com/cosmos/cosmos-sdk/types" + txtypes "github.com/cosmos/cosmos-sdk/types/tx" "github.com/cosmos/cosmos-sdk/types/tx/signing" + "github.com/cosmos/cosmos-sdk/x/auth/middleware" xauthsigning "github.com/cosmos/cosmos-sdk/x/auth/signing" authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" minttypes "github.com/cosmos/cosmos-sdk/x/mint/types" @@ -30,12 +36,13 @@ type MWTestSuite struct { app *simapp.SimApp clientCtx client.Context + txHandler txtypes.Handler } // returns context and app with params set on account keeper func createTestApp(t *testing.T, isCheckTx bool) (*simapp.SimApp, sdk.Context) { app := simapp.Setup(t, isCheckTx) - ctx := app.BaseApp.NewContext(isCheckTx, tmproto.Header{}) + ctx := app.BaseApp.NewContext(isCheckTx, tmproto.Header{}).WithBlockGasMeter(sdk.NewInfiniteGasMeter()) app.AccountKeeper.SetParams(ctx, authtypes.DefaultParams()) return app, ctx @@ -54,14 +61,35 @@ func (s *MWTestSuite) SetupTest(isCheckTx bool) sdk.Context { testdata.RegisterInterfaces(encodingConfig.InterfaceRegistry) s.clientCtx = client.Context{}. - WithTxConfig(encodingConfig.TxConfig) + WithTxConfig(encodingConfig.TxConfig). + WithInterfaceRegistry(encodingConfig.InterfaceRegistry). + WithCodec(codec.NewAminoCodec(encodingConfig.Amino)) + + // We don't use simapp's own txHandler. For more flexibility (i.e. around + // using testdata), we create own own txHandler for this test suite. + msr := middleware.NewMsgServiceRouter(encodingConfig.InterfaceRegistry) + testdata.RegisterMsgServer(msr, testdata.MsgServerImpl{}) + legacyRouter := middleware.NewLegacyRouter() + legacyRouter.AddRoute(sdk.NewRoute((&testdata.TestMsg{}).Route(), func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) { return &sdk.Result{}, nil })) + txHandler, err := middleware.NewDefaultTxHandler(middleware.TxHandlerOptions{ + Debug: s.app.Trace(), + MsgServiceRouter: msr, + LegacyRouter: legacyRouter, + AccountKeeper: s.app.AccountKeeper, + BankKeeper: s.app.BankKeeper, + FeegrantKeeper: s.app.FeeGrantKeeper, + SignModeHandler: encodingConfig.TxConfig.SignModeHandler(), + SigGasConsumer: middleware.DefaultSigVerificationGasConsumer, + }) + s.Require().NoError(err) + s.txHandler = txHandler return ctx } -// CreatetestAccounts creates `numAccs` accounts, and return all relevant +// createTestAccounts creates `numAccs` accounts, and return all relevant // information about them including their private keys. -func (s *MWTestSuite) CreatetestAccounts(ctx sdk.Context, numAccs int) []testAccount { +func (s *MWTestSuite) createTestAccounts(ctx sdk.Context, numAccs int) []testAccount { var accounts []testAccount for i := 0; i < numAccs; i++ { @@ -137,6 +165,48 @@ func (s *MWTestSuite) createTestTx(txBuilder client.TxBuilder, privs []cryptotyp return txBuilder.GetTx(), txBytes, nil } +func (s *MWTestSuite) runTestCase(ctx sdk.Context, txBuilder client.TxBuilder, privs []cryptotypes.PrivKey, msgs []sdk.Msg, feeAmount sdk.Coins, gasLimit uint64, accNums, accSeqs []uint64, chainID string, tc TestCase) { + s.Run(fmt.Sprintf("Case %s", tc.desc), func() { + s.Require().NoError(txBuilder.SetMsgs(msgs...)) + txBuilder.SetFeeAmount(feeAmount) + txBuilder.SetGasLimit(gasLimit) + + // Theoretically speaking, middleware unit tests should only test + // middlewares, but here we sometimes also test the tx creation + // process. + tx, _, txErr := s.createTestTx(txBuilder, privs, accNums, accSeqs, chainID) + newCtx, txHandlerErr := s.txHandler.DeliverTx(sdk.WrapSDKContext(ctx), tx, types.RequestDeliverTx{}) + + if tc.expPass { + s.Require().NoError(txErr) + s.Require().NoError(txHandlerErr) + s.Require().NotNil(newCtx) + } else { + switch { + case txErr != nil: + s.Require().Error(txErr) + s.Require().True(errors.Is(txErr, tc.expErr)) + + case txHandlerErr != nil: + s.Require().Error(txHandlerErr) + s.Require().True(errors.Is(txHandlerErr, tc.expErr)) + + default: + s.Fail("expected one of txErr,txHandlerErr to be an error") + } + } + }) +} + +// TestCase represents a test case used in test tables. +type TestCase struct { + desc string + malleate func() + simulate bool + expPass bool + expErr error +} + func TestMWTestSuite(t *testing.T) { suite.Run(t, new(MWTestSuite)) } diff --git a/x/auth/signing/verify_test.go b/x/auth/signing/verify_test.go index 7a7f015dbef1..8ae55a3891a5 100644 --- a/x/auth/signing/verify_test.go +++ b/x/auth/signing/verify_test.go @@ -13,7 +13,7 @@ import ( "github.com/cosmos/cosmos-sdk/simapp" "github.com/cosmos/cosmos-sdk/testutil/testdata" sdk "github.com/cosmos/cosmos-sdk/types" - "github.com/cosmos/cosmos-sdk/x/auth/ante" + "github.com/cosmos/cosmos-sdk/x/auth/middleware" "github.com/cosmos/cosmos-sdk/x/auth/migrations/legacytx" "github.com/cosmos/cosmos-sdk/x/auth/signing" "github.com/cosmos/cosmos-sdk/x/auth/types" @@ -42,7 +42,8 @@ func TestVerifySignature(t *testing.T) { app.AccountKeeper.SetAccount(ctx, acc1) balances := sdk.NewCoins(sdk.NewInt64Coin("atom", 200)) require.NoError(t, testutil.FundAccount(app.BankKeeper, ctx, addr, balances)) - acc, err := ante.GetSignerAcc(ctx, app.AccountKeeper, addr) + acc, err := middleware.GetSignerAcc(ctx, app.AccountKeeper, addr) + require.NoError(t, err) require.NoError(t, testutil.FundAccount(app.BankKeeper, ctx, addr, balances)) msgs := []sdk.Msg{testdata.NewTestMsg(addr)} diff --git a/x/auth/tx/builder.go b/x/auth/tx/builder.go index 359c646087a5..6382f5272d1e 100644 --- a/x/auth/tx/builder.go +++ b/x/auth/tx/builder.go @@ -10,7 +10,7 @@ import ( sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" "github.com/cosmos/cosmos-sdk/types/tx" "github.com/cosmos/cosmos-sdk/types/tx/signing" - "github.com/cosmos/cosmos-sdk/x/auth/ante" + "github.com/cosmos/cosmos-sdk/x/auth/middleware" authsigning "github.com/cosmos/cosmos-sdk/x/auth/signing" ) @@ -31,10 +31,10 @@ type wrapper struct { } var ( - _ authsigning.Tx = &wrapper{} - _ client.TxBuilder = &wrapper{} - _ ante.HasExtensionOptionsTx = &wrapper{} - _ ExtensionOptionsTxBuilder = &wrapper{} + _ authsigning.Tx = &wrapper{} + _ client.TxBuilder = &wrapper{} + _ middleware.HasExtensionOptionsTx = &wrapper{} + _ ExtensionOptionsTxBuilder = &wrapper{} ) // ExtensionOptionsTxBuilder defines a TxBuilder that can also set extensions. diff --git a/x/feegrant/keeper/keeper.go b/x/feegrant/keeper/keeper.go index fd2614ccf72e..8e6750463217 100644 --- a/x/feegrant/keeper/keeper.go +++ b/x/feegrant/keeper/keeper.go @@ -8,7 +8,7 @@ import ( "github.com/cosmos/cosmos-sdk/codec" sdk "github.com/cosmos/cosmos-sdk/types" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" - "github.com/cosmos/cosmos-sdk/x/auth/ante" + "github.com/cosmos/cosmos-sdk/x/auth/middleware" "github.com/cosmos/cosmos-sdk/x/feegrant" ) @@ -20,7 +20,7 @@ type Keeper struct { authKeeper feegrant.AccountKeeper } -var _ ante.FeegrantKeeper = &Keeper{} +var _ middleware.FeegrantKeeper = &Keeper{} // NewKeeper creates a fee grant Keeper func NewKeeper(cdc codec.BinaryCodec, storeKey sdk.StoreKey, ak feegrant.AccountKeeper) Keeper {