Skip to content

Commit

Permalink
testing/protocmp: add Message.Unwrap
Browse files Browse the repository at this point in the history
The Unwrap method returns the original concrete message value.
In theory this allows users to mutate the original message when the
cmp documentation says that all options should be mutation free.
If users want to disregard this documented restriction, they can
already do so in a number of different ways.

Updates #1347

Change-Id: I65225681ab5dbce0763a140fd02666a4ab542a04
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/340489
Trust: Joe Tsai <[email protected]>
Reviewed-by: Damien Neil <[email protected]>
  • Loading branch information
dsnet authored and neild committed Aug 6, 2021
1 parent 05be61f commit 5aec41b
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 31 deletions.
4 changes: 2 additions & 2 deletions testing/protocmp/reflect.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func (m reflectMessage) Range(f func(fd protoreflect.FieldDescriptor, v protoref
}

// Range over populated extension fields.
for _, xd := range m[messageTypeKey].(messageType).xds {
for _, xd := range m[messageTypeKey].(messageMeta).xds {
if m.Has(xd) && !f(xd, m.Get(xd)) {
return
}
Expand All @@ -91,7 +91,7 @@ func (m reflectMessage) Get(fd protoreflect.FieldDescriptor) protoreflect.Value
return protoreflect.ValueOfMap(reflectMap{})
case fd.Message() != nil:
return protoreflect.ValueOfMessage(reflectMessage{
messageTypeKey: messageType{md: m.Descriptor()},
messageTypeKey: messageMeta{md: fd.Message()},
})
default:
return fd.Default()
Expand Down
12 changes: 6 additions & 6 deletions testing/protocmp/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,11 +297,11 @@ func (f *nameFilters) filterFieldName(m Message, k string) bool {
return true // treat missing fields as already filtered
}
var fd protoreflect.FieldDescriptor
switch mt := m[messageTypeKey].(messageType); {
switch mm := m[messageTypeKey].(messageMeta); {
case protoreflect.Name(k).IsValid():
fd = mt.md.Fields().ByTextName(k)
fd = mm.md.Fields().ByTextName(k)
default:
fd = mt.xds[k]
fd = mm.xds[k]
}
if fd != nil {
return f.names[fd.FullName()]
Expand Down Expand Up @@ -376,11 +376,11 @@ func isDefaultScalar(m Message, k string) bool {
}

var fd protoreflect.FieldDescriptor
switch mt := m[messageTypeKey].(messageType); {
switch mm := m[messageTypeKey].(messageMeta); {
case protoreflect.Name(k).IsValid():
fd = mt.md.Fields().ByTextName(k)
fd = mm.md.Fields().ByTextName(k)
default:
fd = mt.xds[k]
fd = mm.xds[k]
}
if fd == nil || !fd.Default().IsValid() {
return false
Expand Down
31 changes: 23 additions & 8 deletions testing/protocmp/xform.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,20 +68,28 @@ func (e Enum) String() string {
}

const (
messageTypeKey = "@type"
// messageTypeKey indicates the protobuf message type.
// The value type is always messageMeta.
// From the public API, it presents itself as only the type, but the
// underlying data structure holds arbitrary metadata about the message.
messageTypeKey = "@type"

// messageInvalidKey indicates that the message is invalid.
// The value is always the boolean "true".
messageInvalidKey = "@invalid"
)

type messageType struct {
type messageMeta struct {
m proto.Message
md protoreflect.MessageDescriptor
xds map[string]protoreflect.ExtensionDescriptor
}

func (t messageType) String() string {
func (t messageMeta) String() string {
return string(t.md.FullName())
}

func (t1 messageType) Equal(t2 messageType) bool {
func (t1 messageMeta) Equal(t2 messageMeta) bool {
return t1.md.FullName() == t2.md.FullName()
}

Expand Down Expand Up @@ -109,11 +117,18 @@ func (t1 messageType) Equal(t2 messageType) bool {
// Message values must not be created by or mutated by users.
type Message map[string]interface{}

// Unwrap returns the original message value.
// It returns nil if this Message was not constructed from another message.
func (m Message) Unwrap() proto.Message {
mm, _ := m[messageTypeKey].(messageMeta)
return mm.m
}

// Descriptor return the message descriptor.
// It returns nil for a zero Message value.
func (m Message) Descriptor() protoreflect.MessageDescriptor {
mt, _ := m[messageTypeKey].(messageType)
return mt.md
mm, _ := m[messageTypeKey].(messageMeta)
return mm.md
}

// ProtoReflect returns a reflective view of m.
Expand Down Expand Up @@ -201,7 +216,7 @@ func Transform(...option) cmp.Option {
case m == nil:
return nil
case !m.IsValid():
return Message{messageTypeKey: messageType{md: m.Descriptor()}, messageInvalidKey: true}
return Message{messageTypeKey: messageMeta{m: m.Interface(), md: m.Descriptor()}, messageInvalidKey: true}
default:
return transformMessage(m)
}
Expand All @@ -218,7 +233,7 @@ func isMessageType(t reflect.Type) bool {

func transformMessage(m protoreflect.Message) Message {
mx := Message{}
mt := messageType{md: m.Descriptor(), xds: make(map[string]protoreflect.FieldDescriptor)}
mt := messageMeta{m: m.Interface(), md: m.Descriptor(), xds: make(map[string]protoreflect.FieldDescriptor)}

// Handle known and extension fields.
m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
Expand Down
33 changes: 18 additions & 15 deletions testing/protocmp/xform_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func TestTransform(t *testing.T) {
OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{A: proto.Int32(5)},
},
want: Message{
messageTypeKey: messageTypeOf(&testpb.TestAllTypes{}),
messageTypeKey: messageMetaOf(&testpb.TestAllTypes{}),
"optional_bool": bool(false),
"optional_int32": int32(-32),
"optional_int64": int64(-64),
Expand All @@ -51,7 +51,7 @@ func TestTransform(t *testing.T) {
"optional_string": string("string"),
"optional_bytes": []byte("bytes"),
"optional_nested_enum": enumOf(testpb.TestAllTypes_NEG),
"optional_nested_message": Message{messageTypeKey: messageTypeOf(&testpb.TestAllTypes_NestedMessage{}), "a": int32(5)},
"optional_nested_message": Message{messageTypeKey: messageMetaOf(&testpb.TestAllTypes_NestedMessage{}), "a": int32(5)},
},
}, {
in: &testpb.TestAllTypes{
Expand All @@ -74,7 +74,7 @@ func TestTransform(t *testing.T) {
},
},
want: Message{
messageTypeKey: messageTypeOf(&testpb.TestAllTypes{}),
messageTypeKey: messageMetaOf(&testpb.TestAllTypes{}),
"repeated_bool": []bool{false, true},
"repeated_int32": []int32{32, -32},
"repeated_int64": []int64{64, -64},
Expand All @@ -89,8 +89,8 @@ func TestTransform(t *testing.T) {
enumOf(testpb.TestAllTypes_BAR),
},
"repeated_nested_message": []Message{
{messageTypeKey: messageTypeOf(&testpb.TestAllTypes_NestedMessage{}), "a": int32(5)},
{messageTypeKey: messageTypeOf(&testpb.TestAllTypes_NestedMessage{}), "a": int32(-5)},
{messageTypeKey: messageMetaOf(&testpb.TestAllTypes_NestedMessage{}), "a": int32(5)},
{messageTypeKey: messageMetaOf(&testpb.TestAllTypes_NestedMessage{}), "a": int32(-5)},
},
},
}, {
Expand All @@ -112,7 +112,7 @@ func TestTransform(t *testing.T) {
},
},
want: Message{
messageTypeKey: messageTypeOf(&testpb.TestAllTypes{}),
messageTypeKey: messageMetaOf(&testpb.TestAllTypes{}),
"map_bool_bool": map[bool]bool{true: false},
"map_int32_int32": map[int32]int32{-32: 32},
"map_int64_int64": map[int64]int64{-64: 64},
Expand All @@ -126,7 +126,7 @@ func TestTransform(t *testing.T) {
"k": enumOf(testpb.TestAllTypes_FOO),
},
"map_string_nested_message": map[string]Message{
"k": {messageTypeKey: messageTypeOf(&testpb.TestAllTypes_NestedMessage{}), "a": int32(5)},
"k": {messageTypeKey: messageMetaOf(&testpb.TestAllTypes_NestedMessage{}), "a": int32(5)},
},
},
}, {
Expand All @@ -146,7 +146,7 @@ func TestTransform(t *testing.T) {
return m
}(),
want: Message{
messageTypeKey: messageTypeOf(&testpb.TestAllExtensions{}),
messageTypeKey: messageMetaOf(&testpb.TestAllExtensions{}),
"[goproto.proto.test.optional_bool]": bool(false),
"[goproto.proto.test.optional_int32]": int32(-32),
"[goproto.proto.test.optional_int64]": int64(-64),
Expand All @@ -157,7 +157,7 @@ func TestTransform(t *testing.T) {
"[goproto.proto.test.optional_string]": string("string"),
"[goproto.proto.test.optional_bytes]": []byte("bytes"),
"[goproto.proto.test.optional_nested_enum]": enumOf(testpb.TestAllTypes_NEG),
"[goproto.proto.test.optional_nested_message]": Message{messageTypeKey: messageTypeOf(&testpb.TestAllExtensions_NestedMessage{}), "a": int32(5)},
"[goproto.proto.test.optional_nested_message]": Message{messageTypeKey: messageMetaOf(&testpb.TestAllExtensions_NestedMessage{}), "a": int32(5)},
},
}, {
in: func() proto.Message {
Expand All @@ -182,7 +182,7 @@ func TestTransform(t *testing.T) {
return m
}(),
want: Message{
messageTypeKey: messageTypeOf(&testpb.TestAllExtensions{}),
messageTypeKey: messageMetaOf(&testpb.TestAllExtensions{}),
"[goproto.proto.test.repeated_bool]": []bool{false, true},
"[goproto.proto.test.repeated_int32]": []int32{32, -32},
"[goproto.proto.test.repeated_int64]": []int64{64, -64},
Expand All @@ -197,8 +197,8 @@ func TestTransform(t *testing.T) {
enumOf(testpb.TestAllTypes_BAR),
},
"[goproto.proto.test.repeated_nested_message]": []Message{
{messageTypeKey: messageTypeOf(&testpb.TestAllExtensions_NestedMessage{}), "a": int32(5)},
{messageTypeKey: messageTypeOf(&testpb.TestAllExtensions_NestedMessage{}), "a": int32(-5)},
{messageTypeKey: messageMetaOf(&testpb.TestAllExtensions_NestedMessage{}), "a": int32(5)},
{messageTypeKey: messageMetaOf(&testpb.TestAllExtensions_NestedMessage{}), "a": int32(-5)},
},
},
}, {
Expand Down Expand Up @@ -229,7 +229,7 @@ func TestTransform(t *testing.T) {
return m
}(),
want: Message{
messageTypeKey: messageTypeOf(&testpb.TestAllTypes{}),
messageTypeKey: messageMetaOf(&testpb.TestAllTypes{}),
"50000": protoreflect.RawFields(protopack.Message{protopack.Tag{Number: 50000, Type: protopack.VarintType}, protopack.Uvarint(100)}.Marshal()),
"50001": protoreflect.RawFields(protopack.Message{protopack.Tag{Number: 50001, Type: protopack.Fixed32Type}, protopack.Uint32(200)}.Marshal()),
"50002": protoreflect.RawFields(protopack.Message{protopack.Tag{Number: 50002, Type: protopack.Fixed64Type}, protopack.Uint64(300)}.Marshal()),
Expand Down Expand Up @@ -258,6 +258,9 @@ func TestTransform(t *testing.T) {
if diff := cmp.Diff(tt.want, got); diff != "" {
t.Errorf("Transform() mismatch (-want +got):\n%v", diff)
}
if got.Unwrap() != tt.in {
t.Errorf("got.Unwrap() = %p, want %p", got.Unwrap(), tt.in)
}
})
}
}
Expand All @@ -266,6 +269,6 @@ func enumOf(e protoreflect.Enum) Enum {
return Enum{e.Number(), e.Descriptor()}
}

func messageTypeOf(m protoreflect.ProtoMessage) messageType {
return messageType{md: m.ProtoReflect().Descriptor()}
func messageMetaOf(m protoreflect.ProtoMessage) messageMeta {
return messageMeta{m: m, md: m.ProtoReflect().Descriptor()}
}

0 comments on commit 5aec41b

Please sign in to comment.