From dfd4412007ac9e6d1b0a14e4d2624f22d40e0a86 Mon Sep 17 00:00:00 2001 From: Zaak Date: Thu, 9 Nov 2023 19:37:03 +0800 Subject: [PATCH] fix: Select the corret schemas that correspond to the messages used fixes #392 --- cmd/protoc-gen-openapi/generator/generator.go | 35 ++++++++++++++----- cmd/protoc-gen-openapi/generator/reflector.go | 25 ++++++++++--- 2 files changed, 48 insertions(+), 12 deletions(-) diff --git a/cmd/protoc-gen-openapi/generator/generator.go b/cmd/protoc-gen-openapi/generator/generator.go index e548ab21..db30343f 100644 --- a/cmd/protoc-gen-openapi/generator/generator.go +++ b/cmd/protoc-gen-openapi/generator/generator.go @@ -133,11 +133,15 @@ func (g *OpenAPIv3Generator) buildDocumentV3() *v3.Document { // While we have required schemas left to generate, go through the files again // looking for the related message and adding them to the document if required. for len(g.reflect.requiredSchemas) > 0 { - count := len(g.reflect.requiredSchemas) for _, file := range g.plugin.Files { g.addSchemasForMessagesToDocumentV3(d, file.Messages) } - g.reflect.requiredSchemas = g.reflect.requiredSchemas[count:len(g.reflect.requiredSchemas)] + // clear the generated schemas + for schema := range g.reflect.requiredSchemas { + if contains(g.generatedSchemas, schema) { + delete(g.reflect.requiredSchemas, schema) + } + } } // If there is only 1 service, then use it's title for the @@ -771,12 +775,14 @@ func (g *OpenAPIv3Generator) addPathsToDocumentV3(d *v3.Document, services []*pr } } -// addSchemaForMessageToDocumentV3 adds the schema to the document if required +// addSchemaToDocumentV3 adds the schema to the document if required func (g *OpenAPIv3Generator) addSchemaToDocumentV3(d *v3.Document, schema *v3.NamedSchemaOrReference) { - if contains(g.generatedSchemas, schema.Name) { - return + // check if schema already exists in Schemas, instead of checking "generated" + for _, prop := range d.Components.Schemas.AdditionalProperties { + if prop.Name == schema.Name { + return + } } - g.generatedSchemas = append(g.generatedSchemas, schema.Name) d.Components.Schemas.AdditionalProperties = append(d.Components.Schemas.AdditionalProperties, schema) } @@ -789,12 +795,25 @@ func (g *OpenAPIv3Generator) addSchemasForMessagesToDocumentV3(d *v3.Document, m } schemaName := g.reflect.formatMessageName(message.Desc) + fqSchemaName := g.reflect.formatPackageMessageName(message.Desc) // Only generate this if we need it and haven't already generated it. - if !contains(g.reflect.requiredSchemas, schemaName) || - contains(g.generatedSchemas, schemaName) { + requiredFQSchema, ok := g.reflect.requiredSchemas[schemaName] + if !ok { + continue + } else if requiredFQSchema != fqSchemaName { + // "schemaName" with same name is required, but it's not the actual + // schema with "fqSchemaName". Try to use the fully-qualified schema. + if _, ok = g.reflect.requiredSchemas[fqSchemaName]; !ok { + continue + } + // use fully-qualified name as schema name if there are same named messages + schemaName = fqSchemaName + } + if contains(g.generatedSchemas, schemaName) { continue } + g.generatedSchemas = append(g.generatedSchemas, schemaName) typeName := g.reflect.fullMessageTypeName(message.Desc) messageDescription := g.filterCommentString(message.Comments.Leading) diff --git a/cmd/protoc-gen-openapi/generator/reflector.go b/cmd/protoc-gen-openapi/generator/reflector.go index 31a0f930..5bce4522 100644 --- a/cmd/protoc-gen-openapi/generator/reflector.go +++ b/cmd/protoc-gen-openapi/generator/reflector.go @@ -33,7 +33,9 @@ const ( type OpenAPIv3Reflector struct { conf Configuration - requiredSchemas []string // Names of schemas which are used through references. + // Names of schemas which are used through references. + // map: schema name will be used actually -> fully-qualified schema name + requiredSchemas map[string]string } // NewOpenAPIv3Reflector creates a new reflector. @@ -41,7 +43,7 @@ func NewOpenAPIv3Reflector(conf Configuration) *OpenAPIv3Reflector { return &OpenAPIv3Reflector{ conf: conf, - requiredSchemas: make([]string, 0), + requiredSchemas: make(map[string]string, 0), } } @@ -86,6 +88,14 @@ func (r *OpenAPIv3Reflector) formatMessageName(message protoreflect.MessageDescr return name } +// formatPackageMessageName returns the fully-qualified name of a message. +func (r *OpenAPIv3Reflector) formatPackageMessageName(message protoreflect.MessageDescriptor) string { + package_name := string(message.ParentFile().Package()) + name := package_name + "." + r.getMessageName(message) + + return name +} + func (r *OpenAPIv3Reflector) formatFieldName(field protoreflect.FieldDescriptor) string { if *r.conf.Naming == "proto" { return string(field.Name()) @@ -116,8 +126,15 @@ func (r *OpenAPIv3Reflector) responseContentForMessage(message protoreflect.Mess func (r *OpenAPIv3Reflector) schemaReferenceForMessage(message protoreflect.MessageDescriptor) string { schemaName := r.formatMessageName(message) - if !contains(r.requiredSchemas, schemaName) { - r.requiredSchemas = append(r.requiredSchemas, schemaName) + fqSchemaName := r.formatPackageMessageName(message) + requiredFQSchema, ok := r.requiredSchemas[schemaName] + if !ok { + // new required, use schemaName + r.requiredSchemas[schemaName] = fqSchemaName + } else if requiredFQSchema != fqSchemaName { + // use the fully-qualified schema name as there are same named messages + schemaName = fqSchemaName + r.requiredSchemas[schemaName] = fqSchemaName } return "#/components/schemas/" + schemaName }