Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix AcceptListGrpcQuerier concurrency issues #2065

Merged
merged 4 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ require (
github.com/distribution/reference v0.5.0
github.com/rs/zerolog v1.33.0
github.com/spf13/viper v1.19.0
golang.org/x/sync v0.10.0
google.golang.org/genproto/googleapis/api v0.0.0-20240814211410-ddb44dafa142
)

Expand Down Expand Up @@ -201,7 +202,6 @@ require (
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
golang.org/x/net v0.30.0 // indirect
golang.org/x/oauth2 v0.23.0 // indirect
golang.org/x/sync v0.10.0 // indirect
golang.org/x/sys v0.28.0 // indirect
golang.org/x/term v0.27.0 // indirect
golang.org/x/text v0.21.0 // indirect
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/query_plugin_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -952,8 +952,8 @@ func TestAcceptListStargateQuerier(t *testing.T) {

addrs := app.AddTestAddrsIncremental(wasmApp, ctx, 2, sdkmath.NewInt(1_000_000))
accepted := wasmKeeper.AcceptedQueries{
"/cosmos.auth.v1beta1.Query/Account": &authtypes.QueryAccountResponse{},
"/no/route/to/this": &authtypes.QueryAccountResponse{},
"/cosmos.auth.v1beta1.Query/Account": func() proto.Message { return &authtypes.QueryAccountResponse{} },
"/no/route/to/this": func() proto.Message { return &authtypes.QueryAccountResponse{} },
}

marshal := func(pb proto.Message) []byte {
Expand Down
17 changes: 12 additions & 5 deletions x/wasm/keeper/query_plugins.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ func RejectGrpcQuerier(ctx sdk.Context, request *wasmvmtypes.GrpcQuery) (proto.M
// WithQueryPlugins(&QueryPlugins{Grpc: AcceptListGrpcQuerier(acceptList, queryRouter, codec)})
func AcceptListGrpcQuerier(acceptList AcceptedQueries, queryRouter GRPCQueryRouter, codec codec.Codec) func(ctx sdk.Context, request *wasmvmtypes.GrpcQuery) (proto.Message, error) {
return func(ctx sdk.Context, request *wasmvmtypes.GrpcQuery) (proto.Message, error) {
protoResponse, accepted := acceptList[request.Path]
protoResponseFn, accepted := acceptList[request.Path]
if !accepted {
return nil, wasmvmtypes.UnsupportedRequest{Kind: fmt.Sprintf("'%s' path is not allowed from the contract", request.Path)}
}
Expand All @@ -364,6 +364,7 @@ func AcceptListGrpcQuerier(acceptList AcceptedQueries, queryRouter GRPCQueryRout
return nil, err
}

protoResponse := protoResponseFn()
// decode the query response into the expected protobuf message
err = codec.Unmarshal(res.Value, protoResponse)
if err != nil {
Expand All @@ -381,10 +382,15 @@ func RejectStargateQuerier() func(ctx sdk.Context, request *wasmvmtypes.Stargate
}
}

// AcceptedQueries define accepted Stargate or gRPC queries as a map with path as key and response type as value.
// AcceptedQueries defines accepted Stargate or gRPC queries as a map where the key is the query path
// and the value is a function returning a proto.Message.
//
// For example:
// acceptList["/cosmos.auth.v1beta1.Query/Account"]= &authtypes.QueryAccountResponse{}
type AcceptedQueries map[string]proto.Message
//
// acceptList["/cosmos.auth.v1beta1.Query/Account"] = func() proto.Message {
// return &authtypes.QueryAccountResponse{}
// }
type AcceptedQueries map[string]func() proto.Message

// AcceptListStargateQuerier supports a preconfigured set of stargate queries only.
// All arguments must be non nil.
Expand All @@ -396,7 +402,7 @@ type AcceptedQueries map[string]proto.Message
// WithQueryPlugins(&QueryPlugins{Stargate: AcceptListStargateQuerier(acceptList, queryRouter, codec)})
func AcceptListStargateQuerier(acceptList AcceptedQueries, queryRouter GRPCQueryRouter, codec codec.Codec) func(ctx sdk.Context, request *wasmvmtypes.StargateQuery) ([]byte, error) {
return func(ctx sdk.Context, request *wasmvmtypes.StargateQuery) ([]byte, error) {
protoResponse, accepted := acceptList[request.Path]
protoResponseFn, accepted := acceptList[request.Path]
if !accepted {
return nil, wasmvmtypes.UnsupportedRequest{Kind: fmt.Sprintf("'%s' path is not allowed from the contract", request.Path)}
}
Expand All @@ -414,6 +420,7 @@ func AcceptListStargateQuerier(acceptList AcceptedQueries, queryRouter GRPCQuery
return nil, err
}

protoResponse := protoResponseFn()
return ConvertProtoToJSONMarshal(codec, protoResponse, res.Value)
}
}
Expand Down
113 changes: 113 additions & 0 deletions x/wasm/keeper/query_plugins_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,20 @@ package keeper_test
import (
"context"
"encoding/json"
"fmt"
"math"
"sync/atomic"
"testing"

wasmvmtypes "github.com/CosmWasm/wasmvm/v2/types"
abci "github.com/cometbft/cometbft/abci/types"
cmtproto "github.com/cometbft/cometbft/proto/tendermint/types"
dbm "github.com/cosmos/cosmos-db"
"github.com/cosmos/gogoproto/proto"
channeltypes "github.com/cosmos/ibc-go/v8/modules/core/04-channel/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"

errorsmod "cosmossdk.io/errors"
"cosmossdk.io/log"
Expand All @@ -20,6 +25,8 @@ import (
storemetrics "cosmossdk.io/store/metrics"
storetypes "cosmossdk.io/store/types"

"github.com/cosmos/cosmos-sdk/baseapp"
"github.com/cosmos/cosmos-sdk/codec"
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/types/query"
banktypes "github.com/cosmos/cosmos-sdk/x/bank/types"
Expand Down Expand Up @@ -682,3 +689,109 @@ func TestConvertSDKDecCoinToWasmDecCoin(t *testing.T) {
})
}
}

var _ keeper.GRPCQueryRouter = mockedQueryRouter{}

type mockedQueryRouter struct {
codec codec.Codec
}

func (m mockedQueryRouter) Route(_ string) baseapp.GRPCQueryHandler {
return func(ctx sdk.Context, req *abci.RequestQuery) (*abci.ResponseQuery, error) {
balanceReq := &banktypes.QueryBalanceRequest{}
if err := m.codec.Unmarshal(req.Data, balanceReq); err != nil {
return nil, err
}

coin := sdk.NewInt64Coin(balanceReq.Denom, 1)
balanceRes := &banktypes.QueryBalanceResponse{
Balance: &coin,
}

resValue, err := m.codec.Marshal(balanceRes)
if err != nil {
return nil, err
}

return &abci.ResponseQuery{
Value: resValue,
}, nil
}
}

func TestGRPCQuerier(t *testing.T) {
const (
denom1 = "denom1"
denom2 = "denom2"
)
_, keepers := keeper.CreateTestInput(t, false, keeper.AvailableCapabilities)
cdc := keepers.EncodingConfig.Codec

acceptedQueries := keeper.AcceptedQueries{
"/bank.Balance": func() proto.Message { return &banktypes.QueryBalanceResponse{} },
}

router := mockedQueryRouter{
codec: cdc,
}
querier := keeper.AcceptListStargateQuerier(acceptedQueries, router, keepers.EncodingConfig.Codec)

addr := keeper.RandomBech32AccountAddress(t)

eg := errgroup.Group{}
errorsCount := atomic.Uint64{}
for range 50 {
for _, denom := range []string{denom1, denom2} {
denom := denom // copy
eg.Go(func() error {
queryReq := &banktypes.QueryBalanceRequest{
Address: addr,
Denom: denom,
}
grpcData, err := cdc.Marshal(queryReq)
if err != nil {
return err
}

wasmGrpcReq := &wasmvmtypes.StargateQuery{
Data: grpcData,
Path: "/bank.Balance",
}

wasmGrpcRes, err := querier(sdk.Context{}, wasmGrpcReq)
if err != nil {
return err
}

queryRes := &banktypes.QueryBalanceResponse{}
if err := cdc.UnmarshalJSON(wasmGrpcRes, queryRes); err != nil {
return err
}

expectedCoin := sdk.NewInt64Coin(denom, 1)
if queryRes.Balance == nil {
fmt.Printf(
"Error: expected %s, got nil\n",
expectedCoin.String(),
)
errorsCount.Add(1)
return nil
}
if queryRes.Balance.String() != expectedCoin.String() {
fmt.Printf(
"Error: expected %s, got %s\n",
expectedCoin.String(),
queryRes.Balance.String(),
)
errorsCount.Add(1)
return nil
}

return nil
})
}
}

require.NoError(t, eg.Wait())
require.Zero(t, errorsCount.Load())
}
Loading