From 460b947c33a301aaceab02529cd9f89fe4e5211d Mon Sep 17 00:00:00 2001 From: Pino' Surace <95283998+pinosu@users.noreply.github.com> Date: Fri, 20 Dec 2024 18:39:53 +0100 Subject: [PATCH] Fix AcceptListGrpcQuerier concurrency issues (#2065) * Fix AcceptListGrpcQuerier concurrency issues * Fix tests * Fix * Fix --- go.mod | 2 +- .../query_plugin_integration_test.go | 4 +- x/wasm/keeper/query_plugins.go | 17 ++- x/wasm/keeper/query_plugins_test.go | 113 ++++++++++++++++++ 4 files changed, 128 insertions(+), 8 deletions(-) diff --git a/go.mod b/go.mod index a08131a4c7..74e9892faf 100644 --- a/go.mod +++ b/go.mod @@ -51,6 +51,7 @@ require ( github.com/distribution/reference v0.5.0 github.com/rs/zerolog v1.33.0 github.com/spf13/viper v1.19.0 + golang.org/x/sync v0.10.0 google.golang.org/genproto/googleapis/api v0.0.0-20240814211410-ddb44dafa142 ) @@ -201,7 +202,6 @@ require ( golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect golang.org/x/net v0.30.0 // indirect golang.org/x/oauth2 v0.23.0 // indirect - golang.org/x/sync v0.10.0 // indirect golang.org/x/sys v0.28.0 // indirect golang.org/x/term v0.27.0 // indirect golang.org/x/text v0.21.0 // indirect diff --git a/tests/integration/query_plugin_integration_test.go b/tests/integration/query_plugin_integration_test.go index dea1eb3109..e4f025ec19 100644 --- a/tests/integration/query_plugin_integration_test.go +++ b/tests/integration/query_plugin_integration_test.go @@ -952,8 +952,8 @@ func TestAcceptListStargateQuerier(t *testing.T) { addrs := app.AddTestAddrsIncremental(wasmApp, ctx, 2, sdkmath.NewInt(1_000_000)) accepted := wasmKeeper.AcceptedQueries{ - "/cosmos.auth.v1beta1.Query/Account": &authtypes.QueryAccountResponse{}, - "/no/route/to/this": &authtypes.QueryAccountResponse{}, + "/cosmos.auth.v1beta1.Query/Account": func() proto.Message { return &authtypes.QueryAccountResponse{} }, + "/no/route/to/this": func() proto.Message { return &authtypes.QueryAccountResponse{} }, } marshal := func(pb proto.Message) []byte { diff --git a/x/wasm/keeper/query_plugins.go b/x/wasm/keeper/query_plugins.go index 2bb2246968..8bc0eb943b 100644 --- a/x/wasm/keeper/query_plugins.go +++ b/x/wasm/keeper/query_plugins.go @@ -346,7 +346,7 @@ func RejectGrpcQuerier(ctx sdk.Context, request *wasmvmtypes.GrpcQuery) (proto.M // WithQueryPlugins(&QueryPlugins{Grpc: AcceptListGrpcQuerier(acceptList, queryRouter, codec)}) func AcceptListGrpcQuerier(acceptList AcceptedQueries, queryRouter GRPCQueryRouter, codec codec.Codec) func(ctx sdk.Context, request *wasmvmtypes.GrpcQuery) (proto.Message, error) { return func(ctx sdk.Context, request *wasmvmtypes.GrpcQuery) (proto.Message, error) { - protoResponse, accepted := acceptList[request.Path] + protoResponseFn, accepted := acceptList[request.Path] if !accepted { return nil, wasmvmtypes.UnsupportedRequest{Kind: fmt.Sprintf("'%s' path is not allowed from the contract", request.Path)} } @@ -364,6 +364,7 @@ func AcceptListGrpcQuerier(acceptList AcceptedQueries, queryRouter GRPCQueryRout return nil, err } + protoResponse := protoResponseFn() // decode the query response into the expected protobuf message err = codec.Unmarshal(res.Value, protoResponse) if err != nil { @@ -381,10 +382,15 @@ func RejectStargateQuerier() func(ctx sdk.Context, request *wasmvmtypes.Stargate } } -// AcceptedQueries define accepted Stargate or gRPC queries as a map with path as key and response type as value. +// AcceptedQueries defines accepted Stargate or gRPC queries as a map where the key is the query path +// and the value is a function returning a proto.Message. +// // For example: -// acceptList["/cosmos.auth.v1beta1.Query/Account"]= &authtypes.QueryAccountResponse{} -type AcceptedQueries map[string]proto.Message +// +// acceptList["/cosmos.auth.v1beta1.Query/Account"] = func() proto.Message { +// return &authtypes.QueryAccountResponse{} +// } +type AcceptedQueries map[string]func() proto.Message // AcceptListStargateQuerier supports a preconfigured set of stargate queries only. // All arguments must be non nil. @@ -396,7 +402,7 @@ type AcceptedQueries map[string]proto.Message // WithQueryPlugins(&QueryPlugins{Stargate: AcceptListStargateQuerier(acceptList, queryRouter, codec)}) func AcceptListStargateQuerier(acceptList AcceptedQueries, queryRouter GRPCQueryRouter, codec codec.Codec) func(ctx sdk.Context, request *wasmvmtypes.StargateQuery) ([]byte, error) { return func(ctx sdk.Context, request *wasmvmtypes.StargateQuery) ([]byte, error) { - protoResponse, accepted := acceptList[request.Path] + protoResponseFn, accepted := acceptList[request.Path] if !accepted { return nil, wasmvmtypes.UnsupportedRequest{Kind: fmt.Sprintf("'%s' path is not allowed from the contract", request.Path)} } @@ -414,6 +420,7 @@ func AcceptListStargateQuerier(acceptList AcceptedQueries, queryRouter GRPCQuery return nil, err } + protoResponse := protoResponseFn() return ConvertProtoToJSONMarshal(codec, protoResponse, res.Value) } } diff --git a/x/wasm/keeper/query_plugins_test.go b/x/wasm/keeper/query_plugins_test.go index 76dd64aa64..ade98dd60a 100644 --- a/x/wasm/keeper/query_plugins_test.go +++ b/x/wasm/keeper/query_plugins_test.go @@ -3,15 +3,20 @@ package keeper_test import ( "context" "encoding/json" + "fmt" "math" + "sync/atomic" "testing" wasmvmtypes "github.com/CosmWasm/wasmvm/v2/types" + abci "github.com/cometbft/cometbft/abci/types" cmtproto "github.com/cometbft/cometbft/proto/tendermint/types" dbm "github.com/cosmos/cosmos-db" + "github.com/cosmos/gogoproto/proto" channeltypes "github.com/cosmos/ibc-go/v8/modules/core/04-channel/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" errorsmod "cosmossdk.io/errors" "cosmossdk.io/log" @@ -20,6 +25,8 @@ import ( storemetrics "cosmossdk.io/store/metrics" storetypes "cosmossdk.io/store/types" + "github.com/cosmos/cosmos-sdk/baseapp" + "github.com/cosmos/cosmos-sdk/codec" sdk "github.com/cosmos/cosmos-sdk/types" "github.com/cosmos/cosmos-sdk/types/query" banktypes "github.com/cosmos/cosmos-sdk/x/bank/types" @@ -682,3 +689,109 @@ func TestConvertSDKDecCoinToWasmDecCoin(t *testing.T) { }) } } + +var _ keeper.GRPCQueryRouter = mockedQueryRouter{} + +type mockedQueryRouter struct { + codec codec.Codec +} + +func (m mockedQueryRouter) Route(_ string) baseapp.GRPCQueryHandler { + return func(ctx sdk.Context, req *abci.RequestQuery) (*abci.ResponseQuery, error) { + balanceReq := &banktypes.QueryBalanceRequest{} + if err := m.codec.Unmarshal(req.Data, balanceReq); err != nil { + return nil, err + } + + coin := sdk.NewInt64Coin(balanceReq.Denom, 1) + balanceRes := &banktypes.QueryBalanceResponse{ + Balance: &coin, + } + + resValue, err := m.codec.Marshal(balanceRes) + if err != nil { + return nil, err + } + + return &abci.ResponseQuery{ + Value: resValue, + }, nil + } +} + +func TestGRPCQuerier(t *testing.T) { + const ( + denom1 = "denom1" + denom2 = "denom2" + ) + _, keepers := keeper.CreateTestInput(t, false, keeper.AvailableCapabilities) + cdc := keepers.EncodingConfig.Codec + + acceptedQueries := keeper.AcceptedQueries{ + "/bank.Balance": func() proto.Message { return &banktypes.QueryBalanceResponse{} }, + } + + router := mockedQueryRouter{ + codec: cdc, + } + querier := keeper.AcceptListStargateQuerier(acceptedQueries, router, keepers.EncodingConfig.Codec) + + addr := keeper.RandomBech32AccountAddress(t) + + eg := errgroup.Group{} + errorsCount := atomic.Uint64{} + for range 50 { + for _, denom := range []string{denom1, denom2} { + denom := denom // copy + eg.Go(func() error { + queryReq := &banktypes.QueryBalanceRequest{ + Address: addr, + Denom: denom, + } + grpcData, err := cdc.Marshal(queryReq) + if err != nil { + return err + } + + wasmGrpcReq := &wasmvmtypes.StargateQuery{ + Data: grpcData, + Path: "/bank.Balance", + } + + wasmGrpcRes, err := querier(sdk.Context{}, wasmGrpcReq) + if err != nil { + return err + } + + queryRes := &banktypes.QueryBalanceResponse{} + if err := cdc.UnmarshalJSON(wasmGrpcRes, queryRes); err != nil { + return err + } + + expectedCoin := sdk.NewInt64Coin(denom, 1) + if queryRes.Balance == nil { + fmt.Printf( + "Error: expected %s, got nil\n", + expectedCoin.String(), + ) + errorsCount.Add(1) + return nil + } + if queryRes.Balance.String() != expectedCoin.String() { + fmt.Printf( + "Error: expected %s, got %s\n", + expectedCoin.String(), + queryRes.Balance.String(), + ) + errorsCount.Add(1) + return nil + } + + return nil + }) + } + } + + require.NoError(t, eg.Wait()) + require.Zero(t, errorsCount.Load()) +}