Skip to content

Commit

Permalink
Support mapping of protobuf lists. (#915)
Browse files Browse the repository at this point in the history
* Support nested containers in gNMI->Protobuf unmarshalling.

 * (M) protomap/{proto,proto_test.go}
   - Support nested messages when unmarshalling protobufs, previously
     such messages did not have their contents mapped.
 * (M) protomap/integration_tests/integration_test.go
   - Add a testcase for gRIBI's real protobufs to ensure that
     unmarshalling is covered.
  • Loading branch information
robshakir authored Sep 25, 2023
1 parent 5cf59cd commit 2224a32
Show file tree
Hide file tree
Showing 5 changed files with 993 additions and 508 deletions.
59 changes: 46 additions & 13 deletions protomap/integration_tests/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,9 @@ func TestGRIBIAFTToStruct(t *testing.T) {
}, {
desc: "map next-hop-group",
inPaths: map[*gpb.Path]interface{}{
mustPath("next-hops/next-hop[index=1]/index"): mustValue(t, 1),
mustPath("next-hops/next-hop[index=1]/state/index"): mustValue(t, 1),
mustPath("next-hops/next-hop[index=1]/state/weight"): mustValue(t, 1),
mustPath("next-hops/next-hop[index=1]/index"): mustValue(t, uint64(1)),
mustPath("next-hops/next-hop[index=1]/state/index"): mustValue(t, uint64(1)),
mustPath("next-hops/next-hop[index=1]/state/weight"): mustValue(t, uint64(1)),
},
inProto: &gribi_aft.Afts_NextHopGroup{},
inPrefix: &gpb.Path{
Expand All @@ -172,15 +172,45 @@ func TestGRIBIAFTToStruct(t *testing.T) {
}},
},
wantProto: &gribi_aft.Afts_NextHopGroup{
// Currently this error is ignored for backwards compatibility with other
// messages where there are repeated fields that are not covered.
/* NextHop: []*gribi_aft.Afts_NextHopGroup_NextHopKey{{
Index: 1,
NextHop: &gribi_aft.Afts_NextHopGroup_NextHop{
Weight: &wpb.UintValue{Value: 1},
},
}},
*/
NextHop: []*gribi_aft.Afts_NextHopGroup_NextHopKey{{
Index: 1,
NextHop: &gribi_aft.Afts_NextHopGroup_NextHop{
Weight: &wpb.UintValue{Value: 1},
},
}},
},
}, {
desc: "multiple NHGs",
inPaths: map[*gpb.Path]interface{}{
mustPath("next-hops/next-hop[index=1]/index"): mustValue(t, uint64(1)),
mustPath("next-hops/next-hop[index=1]/state/index"): mustValue(t, uint64(1)),
mustPath("next-hops/next-hop[index=1]/state/weight"): mustValue(t, uint64(1)),
mustPath("next-hops/next-hop[index=2]/index"): mustValue(t, uint64(2)),
mustPath("next-hops/next-hop[index=2]/state/index"): mustValue(t, uint64(2)),
mustPath("next-hops/next-hop[index=2]/state/weight"): mustValue(t, uint64(2)),
},
inProto: &gribi_aft.Afts_NextHopGroup{},
inPrefix: &gpb.Path{
Elem: []*gpb.PathElem{{
Name: "afts",
}, {
Name: "next-hop-groups",
}, {
Name: "next-hop-group",
}},
},
wantProto: &gribi_aft.Afts_NextHopGroup{
NextHop: []*gribi_aft.Afts_NextHopGroup_NextHopKey{{
Index: 1,
NextHop: &gribi_aft.Afts_NextHopGroup_NextHop{
Weight: &wpb.UintValue{Value: 1},
},
}, {
Index: 2,
NextHop: &gribi_aft.Afts_NextHopGroup_NextHop{
Weight: &wpb.UintValue{Value: 2},
},
}},
},
}, {
desc: "embedded field in next-hop",
Expand Down Expand Up @@ -208,7 +238,10 @@ func TestGRIBIAFTToStruct(t *testing.T) {
return
}

if diff := cmp.Diff(tt.inProto, tt.wantProto, protocmp.Transform()); diff != "" {
if diff := cmp.Diff(tt.inProto, tt.wantProto,
protocmp.Transform(),
protocmp.SortRepeatedFields(&gribi_aft.Afts_NextHopGroup{}, "next_hop"),
); diff != "" {
t.Fatalf("did not get expected protobuf, diff(-got,+want):\n%s", diff)
}
})
Expand Down
250 changes: 211 additions & 39 deletions protomap/proto.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ package protomap
import (
"errors"
"fmt"
"reflect"
"strconv"
"strings"

"google.golang.org/protobuf/proto"
Expand Down Expand Up @@ -530,6 +532,7 @@ func ProtoFromPaths(p proto.Message, vals map[*gpb.Path]interface{}, opt ...Unma
if err != nil {
return fmt.Errorf("invalid value prefix supplied, %v", err)
}
valPrefix = schemaPath(valPrefix)

protoPrefix, err := hasProtoMsgPrefix(opt)
if err != nil {
Expand All @@ -539,54 +542,59 @@ func ProtoFromPaths(p proto.Message, vals map[*gpb.Path]interface{}, opt ...Unma
return protoFromPathsInternal(p, vals, valPrefix, protoPrefix, hasIgnoreExtraPaths(opt))
}

func protoFromPathsInternal(p proto.Message, vals map[*gpb.Path]any, valPrefix, protoPrefix *gpb.Path, ignoreExtras bool) error {
schemaPath := func(p *gpb.Path) *gpb.Path {
np := proto.Clone(p).(*gpb.Path)
for _, e := range np.Elem {
e.Key = nil
}
return np
// schemaPath converts the path p into a schema path by removing all of the keys within the path.
func schemaPath(p *gpb.Path) *gpb.Path {
np := proto.Clone(p).(*gpb.Path)
for _, e := range np.Elem {
e.Key = nil
}
return np
}

findChildren := func(vals map[*gpb.Path]any, valPrefix *gpb.Path, protoPrefix *gpb.Path, directOnly, mustBeChildren bool) (map[*gpb.Path]any, error) {
// directCh is a map between the absolute schema path for a particular value, and
// the value specified.
directCh := map[*gpb.Path]interface{}{}
for p, v := range vals {
absPath := &gpb.Path{
Elem: append(append([]*gpb.PathElem{}, schemaPath(valPrefix).Elem...), p.Elem...),
}
// findChildren returns the entries from the vals map that correspond to children of the specified protoPrefix path.
// The valPrefix path is prepended to the paths within the vals map to make these values absolute. If the directOnly bool
// is set to true, then only direct children (not subsequent descendents) are returned. If mustBeChildren is set to true
// then an error is returned if there are any values within the vals map that are not children.
func findChildren(vals map[*gpb.Path]any, valPrefix *gpb.Path, protoPrefix *gpb.Path, directOnly, mustBeChildren bool) (map[*gpb.Path]any, error) {
// directCh is a map between the absolute schema path for a particular value, and
// the value specified.
directCh := map[*gpb.Path]interface{}{}
for p, v := range vals {
absPath := &gpb.Path{
Elem: append(append([]*gpb.PathElem{}, valPrefix.Elem...), p.Elem...),
}

if !util.PathMatchesPathElemPrefix(absPath, protoPrefix) {
if mustBeChildren {
return nil, fmt.Errorf("invalid path provided, absolute paths must be used, %s does not have prefix %s", absPath, protoPrefix)
}
continue
if !util.PathMatchesPathElemPrefix(absPath, protoPrefix) {
if mustBeChildren {
return nil, fmt.Errorf("invalid path provided, absolute paths must be used, %s does not have prefix %s", absPath, protoPrefix)
}
continue
}

// make the path absolute, and a schema path.
pp := util.TrimGNMIPathElemPrefix(absPath, protoPrefix)
// make the path absolute, and a schema path.
pp := util.TrimGNMIPathElemPrefix(absPath, protoPrefix)

switch directOnly {
case true:
if len(pp.GetElem()) == 1 {
switch directOnly {
case true:
if len(pp.GetElem()) == 1 {
directCh[pp] = v
}
// TODO(robjs): it'd be good to have something here that tells us whether we are in
// a compressed schema. Potentially we should add something to the generated protobuf
// as a fileoption that would give us this indication.
if len(pp.Elem) == 2 {
if pp.Elem[len(pp.Elem)-2].Name == "config" || pp.Elem[len(pp.Elem)-2].Name == "state" {
directCh[pp] = v
}
// TODO(robjs): it'd be good to have something here that tells us whether we are in
// a compressed schema. Potentially we should add something to the generated protobuf
// as a fileoption that would give us this indication.
if len(pp.Elem) == 2 {
if pp.Elem[len(pp.Elem)-2].Name == "config" || pp.Elem[len(pp.Elem)-2].Name == "state" {
directCh[pp] = v
}
}
case false:
directCh[pp] = v
}
case false:
directCh[pp] = v
}
return directCh, nil
}
return directCh, nil
}

func protoFromPathsInternal(p proto.Message, vals map[*gpb.Path]any, valPrefix, protoPrefix *gpb.Path, ignoreExtras bool) error {
// It is safe for us to call findChldren setting mustBeChildren to true since we are in one of two cases:
//
// * the first iteration through the function, at which point we expect that vals can only
Expand All @@ -612,11 +620,12 @@ func protoFromPathsInternal(p proto.Message, vals map[*gpb.Path]any, valPrefix,

if len(directCh) != 0 {
for _, ap := range annotatedPath {
if !util.PathMatchesPathElemPrefix(ap, protoPrefix) {
trimmedPrefix := schemaPath(protoPrefix)
if !util.PathMatchesPathElemPrefix(ap, trimmedPrefix) {
rangeErr = fmt.Errorf("annotation %s does not match the supplied prefix %s", ap, protoPrefix)
return false
}
trimmedAP := util.TrimGNMIPathElemPrefix(ap, protoPrefix)
trimmedAP := util.TrimGNMIPathElemPrefix(ap, trimmedPrefix)

// Map the values that we have that a direct children of this message.
for chp, chv := range directCh {
Expand Down Expand Up @@ -657,7 +666,23 @@ func protoFromPathsInternal(p proto.Message, vals map[*gpb.Path]any, valPrefix,
if fd.Kind() == protoreflect.MessageKind {
switch {
case fd.IsList():
// TODO(robjs): Support mapping these fields -- currently we silently drop them for backwards compatibility.
leaflist, leaflistunion, err := annotatedYANGFieldInfo(fd)
if err != nil {
rangeErr = fmt.Errorf("cannot extract field information for %s, %v", fd.FullName(), err)
}
switch {
case leaflist, leaflistunion:
// TODO(robjs): Support these fields, silently dropped for backwards compatibility.
default:
// This is a YANG list field which is a repeated within a protobuf. We need to extract the
// keys from this message and create a list in the entry.
members, err := createListField(p, fd, annotatedPath[0], valPrefix, protoPrefix, vals, ignoreExtras)
if err != nil {
rangeErr = err
return false
}
m.Set(fd, members)
}
case fd.IsMap():
rangeErr = fmt.Errorf("map fields are not supported in mapped protobufs at field %s", fd.FullName())
return false
Expand Down Expand Up @@ -704,6 +729,153 @@ func protoFromPathsInternal(p proto.Message, vals map[*gpb.Path]any, valPrefix,
return nil
}

// createListField creates the entries of the repeated field fd within the protobuf message m, mapping the values within the val map.
// valPrefix specifies the prefix to be applied to the paths within the vals map, protoPrefix specifies the prefix for the protobuf
// message (if it is not the root), and fieldPath specifies the path to the field that is being mapped. ignoreExtras indicates whether
// extra paths that do not exist in the message should be treated as errors.
func createListField(m proto.Message, fd protoreflect.FieldDescriptor, fieldPath, valPrefix, protoPrefix *gpb.Path, vals map[*gpb.Path]any, ignoreExtras bool) (protoreflect.Value, error) {
keys := []map[string]string{}
keyPaths := []*gpb.Path{}

// We need to identify the keys that are within the list, as well as the data tree paths
// that they correspond to. We walk through the supplied values to determine which to process.
for p := range vals {
// Make the paths within the vals map absolute according to the supplied valPrefix.
absPath := &gpb.Path{
Elem: append(append([]*gpb.PathElem{}, valPrefix.Elem...), p.Elem...),
}
// Since the fieldPath is a schema path, then we need to compare just schema paths
// to avoid comparing the keys.
if !util.PathMatchesPathElemPrefix(schemaPath(absPath), schemaPath(fieldPath)) {
continue
}
// The key of the list is in the last element of the absolute path in the values map (the values
// map MUST contain data tree paths, since it is telling us list values to unmarshal).
k := absPath.Elem[len(fieldPath.Elem)-1]
// If the last element doesn't have a key, then we have not correctly found the list.
if len(k.Key) == 0 {
return protoreflect.Value{}, fmt.Errorf("invalid list data field path %s: does not have key values populated", fieldPath)
}

// Find the parts of the path that are not the list -- we assume that this is 2 elements since
// we are in a compressed schema.
// TODO(robjs): Currently, this may report incorrectly in an uncompressed schema, but we don't have
// a signal to indicate this. One needs to be added to the generated protobufs.
keyPath := &gpb.Path{
Elem: append(append([]*gpb.PathElem{}, protoPrefix.Elem...), absPath.Elem[len(protoPrefix.Elem):len(protoPrefix.Elem)+2]...),
}
var alreadySeen bool
for _, ek := range keys {
if reflect.DeepEqual(ek, k.Key) {
alreadySeen = true
break
}
}
if !alreadySeen {
keys = append(keys, k.Key)
keyPaths = append(keyPaths, keyPath)
}
}

le := m.ProtoReflect().NewField(fd).List()
for i, key := range keys {
listElemChildren, err := findChildren(vals, valPrefix, keyPaths[i], false, false)
if err != nil {
return protoreflect.Value{}, fmt.Errorf("logic error, error returned from extracting list member children, %v", err)
}

// We now need to create the "XXXKey" message, and populate the key values, subsequent values are then populated
// into the one protobuf message field.
childMsgEmpty := le.NewElement().Message()
childMsgTarget := le.NewElement().Message()

// Walk through the fields of the XXXKey message that we just created. We use the childMsgEmpty here so that we
// don't change the message whilst iterating which causes us to revisit that field. We set the values in
// the childMsgTarget message.
var retErr error

// Store the key values that we received, to make sure that they are mapped during iteration
// through the protobuf fields.
remainingKeys := map[string]bool{}
for n := range key {
remainingKeys[n] = true
}
unpopRange{childMsgEmpty}.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
// We have one field that is a message in a key message, which is the payload. The remaining fields are the keys.
switch fd.Kind() {
case protoreflect.MessageKind:
m := childMsgTarget.NewField(fd).Message()
// We must ignore extra fields from this point in the recursion, because keys map to fields that
// are not present in the generated protobuf.
if err := protoFromPathsInternal(m.Interface(), listElemChildren, keyPaths[i], keyPaths[i], true); err != nil {
retErr = err
return false
}
childMsgTarget.Set(fd, protoreflect.ValueOfMessage(m))
default:
paths, err := annotatedSchemaPath(fd)
if err != nil {
retErr = err
return false
}

var keyName string
for _, p := range paths {
keyName = p.Elem[len(p.Elem)-1].Name
break
}
if _, ok := key[keyName]; !ok {
retErr = fmt.Errorf("field %s, missing key %s, got keys: %v", fd.FullName(), keyName, key)
return false
}
remainingKeys[keyName] = false
pv, err := listKeyAsProtoValue(fd, key[keyName])
if err != nil {
retErr = fmt.Errorf("field %s, %v", fd.FullName(), err)
return false
}
childMsgTarget.Set(fd, pv)
}

return true
})
if retErr != nil {
return protoreflect.Value{}, fmt.Errorf("field %s, %v", fd.FullName(), retErr)
}

unmappedKeys := []string{}
for k, v := range remainingKeys {
if v {
unmappedKeys = append(unmappedKeys, k)
}
}
if len(unmappedKeys) != 0 {
return protoreflect.Value{}, fmt.Errorf("field %s, received additional keys that are not in the schema, %v", fd.FullName(), unmappedKeys)
}

le.Append(protoreflect.ValueOfMessage(childMsgTarget))
}

return protoreflect.ValueOfList(le), nil
}

// listKeyAsProtoValue converts the value of a list key (represented as a string) into a protoreflect.Value that can be
// used to set a scalar protobuf field.
func listKeyAsProtoValue(fd protoreflect.FieldDescriptor, val string) (protoreflect.Value, error) {
switch fd.Kind() {
case protoreflect.Uint64Kind:
v, err := strconv.ParseUint(val, 10, 64)
if err != nil {
return protoreflect.Value{}, fmt.Errorf("invalid uint64 value %v, err: %v", val, err)
}
return protoreflect.ValueOfUint64(v), nil
case protoreflect.StringKind:
return protoreflect.ValueOfString(val), nil
default:
return protoreflect.Value{}, fmt.Errorf("unsupported or invalid kind %v", fd.Kind())
}
}

// hasIgnoreExtraPaths checks whether the supplied opts slice contains the
// ignoreExtraPaths option.
func hasIgnoreExtraPaths(opts []UnmapOpt) bool {
Expand Down
Loading

0 comments on commit 2224a32

Please sign in to comment.