Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(client/v2): support definitions of inner messages #22890

Merged
merged 12 commits into from
Dec 18, 2024
27 changes: 19 additions & 8 deletions api/cosmos/benchmark/module/v1/module.pulsar.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions client/v2/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ Ref: https://keepachangelog.com/en/1.0.0/
* [#20623](https://github.com/cosmos/cosmos-sdk/pull/20623) Extend client/v2 keyring interface with `KeyType` and `KeyInfo`.
* [#22282](https://github.com/cosmos/cosmos-sdk/pull/22282) Added custom broadcast logic.
* [#22775](https://github.com/cosmos/cosmos-sdk/pull/22775) Added interactive autocli prompt functionality, including message field prompting, validation helpers, and default value support.
* [#22890](https://github.com/cosmos/cosmos-sdk/pull/22890) Added support for flattening inner message fields in autocli as positional arguments.

### Improvements

Expand Down
95 changes: 70 additions & 25 deletions client/v2/autocli/flag/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"strconv"
"strings"

cosmos_proto "github.com/cosmos/cosmos-proto"
"github.com/spf13/cobra"
Expand Down Expand Up @@ -162,38 +163,32 @@ func (b *Builder) addMessageFlags(ctx *context.Context, flagSet *pflag.FlagSet,
messageBinder.hasOptional = true
}

field := fields.ByName(protoreflect.Name(arg.ProtoField))
if field == nil {
return nil, fmt.Errorf("can't find field %s on %s", arg.ProtoField, messageType.Descriptor().FullName())
}

_, hasValue, err := b.addFieldFlag(
ctx,
messageBinder.positionalFlagSet,
field,
&autocliv1.FlagOptions{Name: fmt.Sprintf("%d", i)},
namingOptions{},
)
if err != nil {
return nil, err
s := strings.Split(arg.ProtoField, ".")
if len(s) == 1 {
f, err := b.addFieldBindingToArgs(ctx, messageBinder, protoreflect.Name(arg.ProtoField), fields)
if err != nil {
return nil, err
}
messageBinder.positionalArgs = append(messageBinder.positionalArgs, f)
} else {
err := b.addFlattenFieldBindingToArgs(ctx, arg.ProtoField, s, messageType, messageBinder)
if err != nil {
return nil, err
}
}

messageBinder.positionalArgs = append(messageBinder.positionalArgs, fieldBinding{
field: field,
hasValue: hasValue,
})
}

totalArgs := len(messageBinder.positionalArgs)
switch {
case messageBinder.hasVarargs:
messageBinder.CobraArgs = cobra.MinimumNArgs(positionalArgsLen - 1)
messageBinder.mandatoryArgUntil = positionalArgsLen - 1
messageBinder.CobraArgs = cobra.MinimumNArgs(totalArgs - 1)
messageBinder.mandatoryArgUntil = totalArgs - 1
case messageBinder.hasOptional:
messageBinder.CobraArgs = cobra.RangeArgs(positionalArgsLen-1, positionalArgsLen)
messageBinder.mandatoryArgUntil = positionalArgsLen - 1
messageBinder.CobraArgs = cobra.RangeArgs(totalArgs-1, totalArgs)
messageBinder.mandatoryArgUntil = totalArgs - 1
default:
messageBinder.CobraArgs = cobra.ExactArgs(positionalArgsLen)
messageBinder.mandatoryArgUntil = positionalArgsLen
messageBinder.CobraArgs = cobra.ExactArgs(totalArgs)
messageBinder.mandatoryArgUntil = totalArgs
}

// validate flag options
Expand Down Expand Up @@ -273,6 +268,56 @@ func (b *Builder) addMessageFlags(ctx *context.Context, flagSet *pflag.FlagSet,
return messageBinder, nil
}

// addFlattenFieldBindingToArgs recursively adds field bindings for nested message fields to the message binder.
// It takes a slice of field names representing the path to the target field, where each element is a field name
// in the nested message structure. For example, ["foo", "bar", "baz"] would bind the "baz" field inside the "bar"
// message which is inside the "foo" message.
func (b *Builder) addFlattenFieldBindingToArgs(ctx *context.Context, path string, s []string, msg protoreflect.MessageType, messageBinder *MessageBinder) error {
fields := msg.Descriptor().Fields()
if len(s) == 1 {
f, err := b.addFieldBindingToArgs(ctx, messageBinder, protoreflect.Name(s[0]), fields)
if err != nil {
return err
}
f.path = path
messageBinder.positionalArgs = append(messageBinder.positionalArgs, f)
return nil
}
fd := fields.ByName(protoreflect.Name(s[0]))
var innerMsg protoreflect.MessageType
if fd.IsList() {
innerMsg = msg.New().Get(fd).List().NewElement().Message().Type()
} else {
innerMsg = msg.New().Get(fd).Message().Type()
}
return b.addFlattenFieldBindingToArgs(ctx, path, s[1:], innerMsg, messageBinder)
}

// addFieldBindingToArgs adds a fieldBinding for a positional argument to the message binder.
// The fieldBinding is appended to the positional arguments list in the message binder.
func (b *Builder) addFieldBindingToArgs(ctx *context.Context, messageBinder *MessageBinder, name protoreflect.Name, fields protoreflect.FieldDescriptors) (fieldBinding, error) {
field := fields.ByName(name)
if field == nil {
return fieldBinding{}, fmt.Errorf("can't find field %s", name) // TODO: it will improve error if msg.FullName() was included.`
}

_, hasValue, err := b.addFieldFlag(
ctx,
messageBinder.positionalFlagSet,
field,
&autocliv1.FlagOptions{Name: fmt.Sprintf("%d", len(messageBinder.positionalArgs))},
namingOptions{},
)
if err != nil {
return fieldBinding{}, err
}

return fieldBinding{
field: field,
hasValue: hasValue,
}, nil
}

// bindPageRequest create a flag for pagination
func (b *Builder) bindPageRequest(ctx *context.Context, flagSet *pflag.FlagSet, field protoreflect.FieldDescriptor) (HasValue, error) {
return b.addMessageFlags(
Expand Down
47 changes: 45 additions & 2 deletions client/v2/autocli/flag/messager_binder.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package flag

import (
"fmt"
"strings"

"github.com/spf13/cobra"
"github.com/spf13/pflag"
Expand Down Expand Up @@ -65,10 +66,18 @@ func (m MessageBinder) Bind(msg protoreflect.Message, positionalArgs []string) e
}
}

msgName := msg.Descriptor().Name()
// bind positional arg values to the message
for _, arg := range m.positionalArgs {
if err := arg.bind(msg); err != nil {
return err
if msgName == arg.field.Parent().Name() {
if err := arg.bind(msg); err != nil {
return err
}
} else {
s := strings.Split(arg.path, ".")
if err := m.bindNestedField(msg, arg, s); err != nil {
return err
}
}
}

Expand All @@ -82,6 +91,39 @@ func (m MessageBinder) Bind(msg protoreflect.Message, positionalArgs []string) e
return nil
}

// bindNestedField binds a field value to a nested message field. It handles cases where the field
// belongs to a nested message type by recursively traversing the message structure.
func (m *MessageBinder) bindNestedField(msg protoreflect.Message, arg fieldBinding, path []string) error {
if len(path) == 1 {
return arg.bind(msg)
}

name := protoreflect.Name(path[0])
fd := msg.Descriptor().Fields().ByName(name)
if fd == nil {
return fmt.Errorf("field %q not found", path[0])
}

var innerMsg protoreflect.Message
if fd.IsList() {
if msg.Get(fd).List().Len() == 0 {
l := msg.Mutable(fd).List()
elem := l.NewElement().Message().New()
l.Append(protoreflect.ValueOfMessage(elem))
msg.Set(msg.Descriptor().Fields().ByName(name), protoreflect.ValueOfList(l))
}
innerMsg = msg.Get(fd).List().Get(0).Message()
} else {
innerMsgValue := msg.Get(fd)
if !innerMsgValue.Message().IsValid() {
msg.Set(msg.Descriptor().Fields().ByName(name), protoreflect.ValueOfMessage(innerMsgValue.Message().New()))
}
innerMsg = msg.Get(msg.Descriptor().Fields().ByName(name)).Message()
}

return m.bindNestedField(innerMsg, arg, path[1:])
}

// Get calls BuildMessage and wraps the result in a protoreflect.Value.
func (m MessageBinder) Get(protoreflect.Value) (protoreflect.Value, error) {
msg, err := m.BuildMessage(nil)
Expand All @@ -91,6 +133,7 @@ func (m MessageBinder) Get(protoreflect.Value) (protoreflect.Value, error) {
type fieldBinding struct {
hasValue HasValue
field protoreflect.FieldDescriptor
path string
}

func (f fieldBinding) bind(msg protoreflect.Message) error {
Expand Down
27 changes: 27 additions & 0 deletions client/v2/autocli/msg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,33 @@ func TestMsg(t *testing.T) {
assertNormalizedJSONEqual(t, out.Bytes(), goldenLoad(t, "msg-output.golden"))
}

func TestMsgWithFlattenFields(t *testing.T) {
fixture := initFixture(t)

out, err := runCmd(fixture, buildCustomModuleMsgCommand(&autocliv1.ServiceCommandDescriptor{
Service: bankv1beta1.Msg_ServiceDesc.ServiceName,
RpcCommandOptions: []*autocliv1.RpcCommandOptions{
{
RpcMethod: "UpdateParams",
PositionalArgs: []*autocliv1.PositionalArgDescriptor{
{ProtoField: "authority"},
{ProtoField: "params.send_enabled.denom"},
{ProtoField: "params.send_enabled.enabled"},
{ProtoField: "params.default_send_enabled"},
},
},
},
EnhanceCustomCommand: true,
}), "update-params",
"cosmos1y74p8wyy4enfhfn342njve6cjmj5c8dtl6emdk", "stake", "true", "true",
"--generate-only",
"--output", "json",
"--chain-id", fixture.chainID,
)
assert.NilError(t, err)
assertNormalizedJSONEqual(t, out.Bytes(), goldenLoad(t, "flatten-output.golden"))
}

func goldenLoad(t *testing.T, filename string) []byte {
t.Helper()
content, err := os.ReadFile(filepath.Join("testdata", filename))
Expand Down
1 change: 1 addition & 0 deletions client/v2/autocli/testdata/flatten-output.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"body":{"messages":[{"@type":"/cosmos.bank.v1beta1.MsgUpdateParams","authority":"cosmos1y74p8wyy4enfhfn342njve6cjmj5c8dtl6emdk","params":{"send_enabled":[{"denom":"stake","enabled":true}],"default_send_enabled":true}}],"memo":"","timeout_height":"0","unordered":false,"timeout_timestamp":"1970-01-01T00:00:00Z","extension_options":[],"non_critical_extension_options":[]},"auth_info":{"signer_infos":[],"fee":{"amount":[],"gas_limit":"200000","payer":"","granter":""},"tip":null},"signatures":[]}
3 changes: 1 addition & 2 deletions proto/cosmos/autocli/v1/options.proto
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ syntax = "proto3";
package cosmos.autocli.v1;

import "cosmos_proto/cosmos.proto";

option go_package = "cosmossdk.io/api/cosmos/base/cli/v1;cliv1";

// ModuleOptions describes the CLI options for a Cosmos SDK module.
Expand All @@ -16,7 +17,6 @@ message ModuleOptions {

// ServiceCommandDescriptor describes a CLI command based on a protobuf service.
message ServiceCommandDescriptor {

// service is the fully qualified name of the protobuf service to build
// the command from. It can be left empty if sub_commands are used instead
// which may be the case if a module provides multiple tx and/or query services.
Expand Down Expand Up @@ -103,7 +103,6 @@ message RpcCommandOptions {
// kebab-case name of the field. Fields can be turned into positional arguments
// instead by using RpcCommandOptions.positional_args.
message FlagOptions {

// name is an alternate name to use for the field flag.
string name = 1;

Expand Down
Loading