Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented optional (presented) fields for proto3 #770

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,6 @@
*.js.map

# Conformance test output and transient files.
conformance/failing_tests.txt
conformance/failing_tests.txt

.idea
17 changes: 10 additions & 7 deletions gogoproto/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,18 @@

package gogoproto

import google_protobuf "github.com/gogo/protobuf/protoc-gen-gogo/descriptor"
import proto "github.com/gogo/protobuf/proto"
import (
proto "github.com/gogo/protobuf/proto"
"github.com/gogo/protobuf/proto3optional"
google_protobuf "github.com/gogo/protobuf/protoc-gen-gogo/descriptor"
)

func IsEmbed(field *google_protobuf.FieldDescriptorProto) bool {
return proto.GetBoolExtension(field.Options, E_Embed, false)
}

func IsNullable(field *google_protobuf.FieldDescriptorProto) bool {
return proto.GetBoolExtension(field.Options, E_Nullable, true)
func IsNullable(field *google_protobuf.FieldDescriptorProto, proto3Resolver *proto3optional.Resolver) bool {
return proto.GetBoolExtension(field.Options, E_Nullable, true) || proto3Resolver.IsFakeOneOf(field)
}

func IsStdTime(field *google_protobuf.FieldDescriptorProto) bool {
Expand Down Expand Up @@ -96,12 +99,12 @@ func IsWktPtr(field *google_protobuf.FieldDescriptorProto) bool {
return proto.GetBoolExtension(field.Options, E_Wktpointer, false)
}

func NeedsNilCheck(proto3 bool, field *google_protobuf.FieldDescriptorProto) bool {
nullable := IsNullable(field)
func NeedsNilCheck(field *google_protobuf.FieldDescriptorProto, proto3Resolver *proto3optional.Resolver) bool {
nullable := IsNullable(field, proto3Resolver)
if field.IsMessage() || IsCustomType(field) {
return nullable
}
if proto3 {
if proto3Resolver.IsProto3WithoutOptional(field) {
return false
}
return nullable || *field.Type == google_protobuf.FieldDescriptorProto_TYPE_BYTES
Expand Down
41 changes: 22 additions & 19 deletions plugin/compare/compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ package compare
import (
"github.com/gogo/protobuf/gogoproto"
"github.com/gogo/protobuf/proto"
"github.com/gogo/protobuf/proto3optional"
descriptor "github.com/gogo/protobuf/protoc-gen-gogo/descriptor"
"github.com/gogo/protobuf/protoc-gen-gogo/generator"
"github.com/gogo/protobuf/vanity"
Expand Down Expand Up @@ -142,12 +143,11 @@ func (p *plugin) generateMsgNullAndTypeCheck(ccTypeName string) {
p.P(`}`)
}

func (p *plugin) generateField(file *generator.FileDescriptor, message *generator.Descriptor, field *descriptor.FieldDescriptorProto) {
proto3 := gogoproto.IsProto3(file.FileDescriptorProto)
fieldname := p.GetOneOfFieldName(message, field)
func (p *plugin) generateField(file *generator.FileDescriptor, message *generator.Descriptor, field *descriptor.FieldDescriptorProto, proto3Resolver *proto3optional.Resolver) {
fieldname := p.GetOneOfFieldName(message, field, proto3Resolver)
repeated := field.IsRepeated()
ctype := gogoproto.IsCustomType(field)
nullable := gogoproto.IsNullable(field)
nullable := gogoproto.IsNullable(field, proto3Resolver)
// oneof := field.OneofIndex != nil
if !repeated {
if ctype {
Expand Down Expand Up @@ -190,7 +190,7 @@ func (p *plugin) generateField(file *generator.FileDescriptor, message *generato
p.Out()
p.P(`}`)
} else if field.IsString() {
if nullable && !proto3 {
if nullable && !proto3Resolver.IsProto3WithoutOptional(field) {
p.generateNullableField(fieldname)
} else {
p.P(`if this.`, fieldname, ` != that1.`, fieldname, `{`)
Expand All @@ -205,7 +205,7 @@ func (p *plugin) generateField(file *generator.FileDescriptor, message *generato
p.P(`}`)
}
} else if field.IsBool() {
if nullable && !proto3 {
if nullable && !proto3Resolver.IsProto3WithoutOptional(field) {
p.P(`if this.`, fieldname, ` != nil && that1.`, fieldname, ` != nil {`)
p.In()
p.P(`if *this.`, fieldname, ` != *that1.`, fieldname, `{`)
Expand Down Expand Up @@ -241,7 +241,7 @@ func (p *plugin) generateField(file *generator.FileDescriptor, message *generato
p.P(`}`)
}
} else {
if nullable && !proto3 {
if nullable && !proto3Resolver.IsProto3WithoutOptional(field) {
p.generateNullableField(fieldname)
} else {
p.P(`if this.`, fieldname, ` != that1.`, fieldname, `{`)
Expand Down Expand Up @@ -278,10 +278,10 @@ func (p *plugin) generateField(file *generator.FileDescriptor, message *generato
p.P(`}`)
} else {
if p.IsMap(field) {
m := p.GoMapType(nil, field)
valuegoTyp, _ := p.GoType(nil, m.ValueField)
valuegoAliasTyp, _ := p.GoType(nil, m.ValueAliasField)
nullable, valuegoTyp, valuegoAliasTyp = generator.GoMapValueTypes(field, m.ValueField, valuegoTyp, valuegoAliasTyp)
m := p.GoMapType(nil, field, proto3Resolver)
valuegoTyp, _ := p.GoType(nil, m.ValueField, proto3Resolver)
valuegoAliasTyp, _ := p.GoType(nil, m.ValueAliasField, proto3Resolver)
nullable, valuegoTyp, valuegoAliasTyp = generator.GoMapValueTypes(field, m.ValueField, valuegoTyp, valuegoAliasTyp, proto3Resolver)

mapValue := m.ValueAliasField
if mapValue.IsMessage() || p.IsGroup(mapValue) {
Expand Down Expand Up @@ -402,12 +402,15 @@ func (p *plugin) generateMessage(file *generator.FileDescriptor, message *genera
p.P(`func (this *`, ccTypeName, `) Compare(that interface{}) int {`)
p.In()
p.generateMsgNullAndTypeCheck(ccTypeName)

proto3Resolver := proto3optional.NewResolver(gogoproto.IsProto3(file.FileDescriptorProto), message.Field)

oneofs := make(map[string]struct{})

for _, field := range message.Field {
oneof := field.OneofIndex != nil
oneof := proto3Resolver.IsRealOneOf(field)
if oneof {
fieldname := p.GetFieldName(message, field)
fieldname := p.GetFieldName(message, field, proto3Resolver)
if _, ok := oneofs[fieldname]; ok {
continue
} else {
Expand Down Expand Up @@ -435,7 +438,7 @@ func (p *plugin) generateMessage(file *generator.FileDescriptor, message *genera
p.P(`switch this.`, fieldname, `.(type) {`)
for i, subfield := range message.Field {
if *subfield.OneofIndex == *field.OneofIndex {
ccTypeName := p.OneOfTypeName(message, subfield)
ccTypeName := p.OneOfTypeName(message, subfield, proto3Resolver)
p.P(`case *`, ccTypeName, `:`)
p.In()
p.P(`thisType = `, i)
Expand All @@ -452,7 +455,7 @@ func (p *plugin) generateMessage(file *generator.FileDescriptor, message *genera
p.P(`switch that1.`, fieldname, `.(type) {`)
for i, subfield := range message.Field {
if *subfield.OneofIndex == *field.OneofIndex {
ccTypeName := p.OneOfTypeName(message, subfield)
ccTypeName := p.OneOfTypeName(message, subfield, proto3Resolver)
p.P(`case *`, ccTypeName, `:`)
p.In()
p.P(`that1Type = `, i)
Expand Down Expand Up @@ -485,7 +488,7 @@ func (p *plugin) generateMessage(file *generator.FileDescriptor, message *genera
p.Out()
p.P(`}`)
} else {
p.generateField(file, message, field)
p.generateField(file, message, field, proto3Resolver)
}
}
if message.DescriptorProto.HasExtension() {
Expand Down Expand Up @@ -557,17 +560,17 @@ func (p *plugin) generateMessage(file *generator.FileDescriptor, message *genera
//Generate Compare methods for oneof fields
m := proto.Clone(message.DescriptorProto).(*descriptor.DescriptorProto)
for _, field := range m.Field {
oneof := field.OneofIndex != nil
oneof := proto3Resolver.IsRealOneOf(field)
if !oneof {
continue
}
ccTypeName := p.OneOfTypeName(message, field)
ccTypeName := p.OneOfTypeName(message, field, proto3Resolver)
p.P(`func (this *`, ccTypeName, `) Compare(that interface{}) int {`)
p.In()

p.generateMsgNullAndTypeCheck(ccTypeName)
vanity.TurnOffNullableForNativeTypes(field)
p.generateField(file, message, field)
p.generateField(file, message, field, proto3Resolver)

p.P(`return 0`)
p.Out()
Expand Down
15 changes: 9 additions & 6 deletions plugin/defaultcheck/defaultcheck.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,14 @@ It is enabled by the following extensions:

For incorrect usage of nullable with tests see:

github.com/gogo/protobuf/test/nullableconflict

github.com/gogo/protobuf/test/nullableconflict
*/
package defaultcheck

import (
"fmt"
"github.com/gogo/protobuf/gogoproto"
"github.com/gogo/protobuf/proto3optional"
"github.com/gogo/protobuf/protoc-gen-gogo/generator"
"os"
)
Expand All @@ -75,10 +75,11 @@ func (p *plugin) Init(g *generator.Generator) {
}

func (p *plugin) Generate(file *generator.FileDescriptor) {
proto3 := gogoproto.IsProto3(file.FileDescriptorProto)
for _, msg := range file.Messages() {
getters := gogoproto.HasGoGetters(file.FileDescriptorProto, msg.DescriptorProto)
face := gogoproto.IsFace(file.FileDescriptorProto, msg.DescriptorProto)
proto3Resolver := proto3optional.NewResolver(gogoproto.IsProto3(file.FileDescriptorProto), msg.Field)

for _, field := range msg.GetField() {
if len(field.GetDefaultValue()) > 0 {
if !getters {
Expand All @@ -90,7 +91,7 @@ func (p *plugin) Generate(file *generator.FileDescriptor) {
os.Exit(1)
}
}
if gogoproto.IsNullable(field) {
if gogoproto.IsNullable(field, proto3Resolver) {
continue
}
if len(field.GetDefaultValue()) > 0 {
Expand All @@ -100,7 +101,7 @@ func (p *plugin) Generate(file *generator.FileDescriptor) {
if !field.IsMessage() && !gogoproto.IsCustomType(field) {
if field.IsRepeated() {
fmt.Fprintf(os.Stderr, "WARNING: field %v.%v is a repeated non-nullable native type, nullable=false has no effect\n", generator.CamelCase(*msg.Name), generator.CamelCase(*field.Name))
} else if proto3 {
} else if proto3Resolver.IsProto3WithoutOptional(field) {
fmt.Fprintf(os.Stderr, "ERROR: field %v.%v is a native type and in proto3 syntax with nullable=false there exists conflicting implementations when encoding zero values", generator.CamelCase(*msg.Name), generator.CamelCase(*field.Name))
os.Exit(1)
}
Expand All @@ -118,8 +119,10 @@ func (p *plugin) Generate(file *generator.FileDescriptor) {
}
}
}

proto3Resolver := proto3optional.NewResolver(gogoproto.IsProto3(file.FileDescriptorProto), file.GetExtension())
for _, e := range file.GetExtension() {
if !gogoproto.IsNullable(e) {
if !gogoproto.IsNullable(e, proto3Resolver) {
fmt.Fprintf(os.Stderr, "ERROR: extended field %v cannot be nullable %v", generator.CamelCase(e.GetName()), generator.CamelCase(*e.Name))
os.Exit(1)
}
Expand Down
9 changes: 4 additions & 5 deletions plugin/embedcheck/embedcheck.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ It is enabled by the following extensions:

For incorrect usage of embed with tests see:

github.com/gogo/protobuf/test/embedconflict

github.com/gogo/protobuf/test/embedconflict
*/
package embedcheck

Expand Down Expand Up @@ -109,7 +108,7 @@ func (p *plugin) Generate(file *generator.FileDescriptor) {
p.checkOverwrite(msg, os)
}
}
p.checkNameSpace(msg)
p.checkNameSpace(msg, gogoproto.IsProto3(file.FileDescriptorProto))
for _, field := range msg.GetField() {
if gogoproto.IsEmbed(field) && gogoproto.IsCustomName(field) {
fmt.Fprintf(os.Stderr, "ERROR: field %v with custom name %v cannot be embedded", *field.Name, gogoproto.GetCustomName(field))
Expand All @@ -126,14 +125,14 @@ func (p *plugin) Generate(file *generator.FileDescriptor) {
}
}

func (p *plugin) checkNameSpace(message *generator.Descriptor) map[string]bool {
func (p *plugin) checkNameSpace(message *generator.Descriptor, proto3 bool) map[string]bool {
ccTypeName := generator.CamelCaseSlice(message.TypeName())
names := make(map[string]bool)
for _, field := range message.Field {
fieldname := generator.CamelCase(*field.Name)
if field.IsMessage() && gogoproto.IsEmbed(field) {
desc := p.ObjectNamed(field.GetTypeName())
moreNames := p.checkNameSpace(desc.(*generator.Descriptor))
moreNames := p.checkNameSpace(desc.(*generator.Descriptor), proto3)
for another := range moreNames {
if names[another] {
fmt.Fprintf(os.Stderr, "ERROR: duplicate embedded fieldname %v in type %v\n", fieldname, ccTypeName)
Expand Down
Loading