Skip to content

Commit

Permalink
chore: refactor code (#1539)
Browse files Browse the repository at this point in the history
* chore: refactor code

* chore: lint

* chore: lint

* chore: use embed for generated doc
  • Loading branch information
ubogdan authored Apr 5, 2023
1 parent 7394a48 commit 677b4c2
Show file tree
Hide file tree
Showing 12 changed files with 215 additions and 433 deletions.
2 changes: 1 addition & 1 deletion cmd/swag/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ func initAction(ctx *cli.Context) error {
Tags: ctx.String(tagsFlag),
PackageName: ctx.String(packageName),
Debugger: logger,
OpenAPIVersion: ctx.Bool(openAPIVersionFlag),
GenerateOpenAPI3Doc: ctx.Bool(openAPIVersionFlag),
CollectionFormat: collectionFormat,
})
}
Expand Down
209 changes: 119 additions & 90 deletions gen/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package gen
import (
"bufio"
"bytes"
"embed"
"encoding/json"
"fmt"
"go/format"
Expand All @@ -17,8 +18,9 @@ import (

jsoniter "github.com/json-iterator/go"

"github.com/go-openapi/spec"
openapi "github.com/sv-tools/openapi/spec"
v2 "github.com/go-openapi/spec"
v3 "github.com/sv-tools/openapi/spec"

"github.com/swaggo/swag"
"sigs.k8s.io/yaml"
)
Expand All @@ -28,18 +30,20 @@ var open = os.Open
// DefaultOverridesFile is the location swagger will look for type overrides.
const DefaultOverridesFile = ".swaggo"

type genTypeWriter func(*Config, *spec.Swagger) error
type genTypeWriter func(*Config, interface{}) error

// Gen presents a generate tool for swag.
type Gen struct {
json func(data interface{}) ([]byte, error)
jsonIndent func(data interface{}) ([]byte, error)
jsonToYAML func(data []byte) ([]byte, error)
outputTypeMap map[string]genTypeWriter
outputTypeMapV3 map[string]openAPITypeWriter
debug Debugger
json func(data interface{}) ([]byte, error)
jsonIndent func(data interface{}) ([]byte, error)
jsonToYAML func(data []byte) ([]byte, error)
outputTypeMap map[string]genTypeWriter
debug Debugger
}

//go:embed src/*.tmpl
var tmpl embed.FS

// Debugger is the interface that wraps the basic Printf method.
type Debugger interface {
Printf(format string, v ...interface{})
Expand All @@ -50,25 +54,17 @@ func New() *Gen {
gen := Gen{
json: json.Marshal,
jsonIndent: func(data interface{}) ([]byte, error) {
var json = jsoniter.ConfigCompatibleWithStandardLibrary
return json.MarshalIndent(&data, "", " ")
return jsoniter.ConfigCompatibleWithStandardLibrary.MarshalIndent(&data, "", " ")
},
jsonToYAML: yaml.JSONToYAML,
debug: log.New(os.Stdout, "", log.LstdFlags),
}

gen.outputTypeMap = map[string]genTypeWriter{
"go": gen.writeDocSwagger,
"json": gen.writeJSONSwagger,
"yaml": gen.writeYAMLSwagger,
"yml": gen.writeYAMLSwagger,
}

gen.outputTypeMapV3 = map[string]openAPITypeWriter{
"go": gen.writeDocOpenAPI,
"json": gen.writeJSONOpenAPI,
"yaml": gen.writeYAMLOpenAPI,
"yml": gen.writeYAMLOpenAPI,
"go": gen.writeDoc,
"json": gen.writeJSON,
"yaml": gen.writeYAML,
"yml": gen.writeYAML,
}

return &gen
Expand Down Expand Up @@ -139,8 +135,9 @@ type Config struct {
// include only tags mentioned when searching, comma separated
Tags string

// if true, OpenAPI V3.1 spec will be generated
OpenAPIVersion bool
// GenerateOpenAPI3Doc if true, OpenAPI V3.1 spec will be generated
GenerateOpenAPI3Doc bool

// PackageName defines package name of generated `docs.go`
PackageName string

Expand Down Expand Up @@ -196,7 +193,7 @@ func (g *Gen) Build(config *Config) error {
swag.SetOverrides(overrides),
swag.ParseUsingGoList(config.ParseGoList),
swag.SetTags(config.Tags),
swag.SetOpenAPIVersion(config.OpenAPIVersion),
swag.GenerateOpenAPI3Doc(config.GenerateOpenAPI3Doc),
swag.SetCollectionFormat(config.CollectionFormat),
)

Expand All @@ -213,45 +210,18 @@ func (g *Gen) Build(config *Config) error {
return err
}

if config.OpenAPIVersion {
openAPI := p.GetOpenAPI()
err := g.writeOpenAPI(config, openAPI)
if err != nil {
return err
}

return nil
}

swagger := p.GetSwagger()
err := g.writeSwagger(config, swagger)
if err != nil {
return err
if config.GenerateOpenAPI3Doc {
return g.writeOpenAPI(config, p.GetOpenAPI())
}

return nil
return g.writeOpenAPI(config, p.GetSwagger())
}

func (g *Gen) writeOpenAPI(config *Config, o *openapi.OpenAPI) error {
for _, outputType := range config.OutputTypes {
outputType = strings.ToLower(strings.TrimSpace(outputType))
if typeWriter, ok := g.outputTypeMapV3[outputType]; ok {
if err := typeWriter(config, o); err != nil {
return err
}
} else {
log.Printf("output type '%s' not supported", outputType)
}
}

return nil
}

func (g *Gen) writeSwagger(config *Config, swagger *spec.Swagger) error {
func (g *Gen) writeOpenAPI(config *Config, doc interface{}) error {
for _, outputType := range config.OutputTypes {
outputType = strings.ToLower(strings.TrimSpace(outputType))
if typeWriter, ok := g.outputTypeMap[outputType]; ok {
if err := typeWriter(config, swagger); err != nil {
if err := typeWriter(config, doc); err != nil {
return err
}
} else {
Expand All @@ -262,7 +232,7 @@ func (g *Gen) writeSwagger(config *Config, swagger *spec.Swagger) error {
return nil
}

func (g *Gen) writeDocSwagger(config *Config, swagger *spec.Swagger) error {
func (g *Gen) writeDoc(config *Config, doc interface{}) error {
var filename = "docs.go"

if config.InstanceName != swag.Name {
Expand Down Expand Up @@ -291,17 +261,25 @@ func (g *Gen) writeDocSwagger(config *Config, swagger *spec.Swagger) error {
defer docs.Close()

// Write doc
err = g.writeGoDoc(packageName, docs, swagger, config)
if err != nil {
return err
}
switch spec := doc.(type) {
case *v2.Swagger:
err = g.writeGoDoc(packageName, docs, spec, config)
if err != nil {
return err

}
case *v3.OpenAPI:
err = g.writeGoDocV3(packageName, docs, spec, config)
if err != nil {
return nil
}
}
g.debug.Printf("create docs.go at %+v", docFileName)

return nil
}

func (g *Gen) writeJSONSwagger(config *Config, swagger *spec.Swagger) error {
func (g *Gen) writeJSON(config *Config, spec interface{}) error {
var filename = "swagger.json"

if config.InstanceName != swag.Name {
Expand All @@ -310,7 +288,7 @@ func (g *Gen) writeJSONSwagger(config *Config, swagger *spec.Swagger) error {

jsonFileName := path.Join(config.OutputDir, filename)

b, err := g.jsonIndent(swagger)
b, err := g.jsonIndent(spec)
if err != nil {
return err
}
Expand All @@ -325,7 +303,7 @@ func (g *Gen) writeJSONSwagger(config *Config, swagger *spec.Swagger) error {
return nil
}

func (g *Gen) writeYAMLSwagger(config *Config, swagger *spec.Swagger) error {
func (g *Gen) writeYAML(config *Config, swagger interface{}) error {
var filename = "swagger.yaml"

if config.InstanceName != swag.Name {
Expand Down Expand Up @@ -421,29 +399,29 @@ func parseOverrides(r io.Reader) (map[string]string, error) {
return overrides, nil
}

func (g *Gen) writeGoDoc(packageName string, output io.Writer, swagger *spec.Swagger, config *Config) error {
generator, err := template.New("swagger_info").Funcs(template.FuncMap{
func (g *Gen) writeGoDoc(packageName string, output io.Writer, swagger *v2.Swagger, config *Config) error {
generator, err := template.New("oas2.tmpl").Funcs(template.FuncMap{
"printDoc": func(v string) string {
// Add schemes
v = "{\n \"schemes\": {{ marshal .Schemes }}," + v[1:]
// Sanitize backticks
return strings.Replace(v, "`", "`+\"`\"+`", -1)
},
}).Parse(packageTemplate)
}).ParseFS(tmpl, "src/*.tmpl")
if err != nil {
return err
}

swaggerSpec := &spec.Swagger{
swaggerSpec := &v2.Swagger{
VendorExtensible: swagger.VendorExtensible,
SwaggerProps: spec.SwaggerProps{
SwaggerProps: v2.SwaggerProps{
ID: swagger.ID,
Consumes: swagger.Consumes,
Produces: swagger.Produces,
Swagger: swagger.Swagger,
Info: &spec.Info{
Info: &v2.Info{
VendorExtensible: swagger.Info.VendorExtensible,
InfoProps: spec.InfoProps{
InfoProps: v2.InfoProps{
Description: "{{escape .Description}}",
Title: "{{.Title}}",
TermsOfService: swagger.Info.TermsOfService,
Expand Down Expand Up @@ -510,27 +488,78 @@ func (g *Gen) writeGoDoc(packageName string, output io.Writer, swagger *spec.Swa
return err
}

var packageTemplate = `// Code generated by swaggo/swag{{ if .GeneratedTime }} at {{ .Timestamp }}{{ end }}. DO NOT EDIT.
func (g *Gen) writeGoDocV3(packageName string, output io.Writer, openAPI *v3.OpenAPI, config *Config) error {
generator, err := template.New("oas3.tmpl").Funcs(template.FuncMap{
"printDoc": func(v string) string {
// Add schemes
v = "{\n \"schemes\": {{ marshal .Schemes }}," + v[1:]
// Sanitize backticks
return strings.Replace(v, "`", "`+\"`\"+`", -1)
},
}).ParseFS(tmpl, "src/*.tmpl")
if err != nil {
return err
}

openAPISpec := v3.OpenAPI{
Components: openAPI.Components,
OpenAPI: openAPI.OpenAPI,
Info: &v3.Extendable[v3.Info]{
Spec: &v3.Info{
Description: "{{escape .Description}}",
Title: "{{.Title}}",
Version: "{{.Version}}",
TermsOfService: openAPI.Info.Spec.TermsOfService,
Contact: openAPI.Info.Spec.Contact,
License: openAPI.Info.Spec.License,
Summary: openAPI.Info.Spec.Summary,
},
Extensions: openAPI.Info.Extensions,
},
ExternalDocs: openAPI.ExternalDocs,
Paths: openAPI.Paths,
WebHooks: openAPI.WebHooks,
JsonSchemaDialect: openAPI.JsonSchemaDialect,
Security: openAPI.Security,
Tags: openAPI.Tags,
Servers: openAPI.Servers,
}

package docs
// crafted docs.json
buf, err := g.jsonIndent(openAPISpec)
if err != nil {
return err
}

import "github.com/swaggo/swag"
buffer := &bytes.Buffer{}

const docTemplate{{ if ne .InstanceName "swagger" }}{{ .InstanceName }} {{- end }} = ` + "`{{ printDoc .Doc}}`" + `
err = generator.Execute(buffer, struct {
Timestamp time.Time
Doc string
PackageName string
Title string
Description string
Version string
InstanceName string
GeneratedTime bool
}{
Timestamp: time.Now(),
GeneratedTime: config.GeneratedTime,
Doc: string(buf),
PackageName: packageName,
Title: openAPI.Info.Spec.Title,
Description: openAPI.Info.Spec.Description,
Version: openAPI.Info.Spec.Version,
InstanceName: config.InstanceName,
})
if err != nil {
return err
}

// SwaggerInfo{{ if ne .InstanceName "swagger" }}{{ .InstanceName }} {{- end }} holds exported Swagger Info so clients can modify it
var SwaggerInfo{{ if ne .InstanceName "swagger" }}{{ .InstanceName }} {{- end }} = &swag.Spec{
Version: {{ printf "%q" .Version}},
Host: {{ printf "%q" .Host}},
BasePath: {{ printf "%q" .BasePath}},
Schemes: []string{ {{ range $index, $schema := .Schemes}}{{if gt $index 0}},{{end}}{{printf "%q" $schema}}{{end}} },
Title: {{ printf "%q" .Title}},
Description: {{ printf "%q" .Description}},
InfoInstanceName: {{ printf "%q" .InstanceName }},
SwaggerTemplate: docTemplate{{ if ne .InstanceName "swagger" }}{{ .InstanceName }} {{- end }},
}
code := g.formatSource(buffer.Bytes())

func init() {
swag.Register(SwaggerInfo{{ if ne .InstanceName "swagger" }}{{ .InstanceName }} {{- end }}.InstanceName(), SwaggerInfo{{ if ne .InstanceName "swagger" }}{{ .InstanceName }} {{- end }})
// write
_, err = output.Write(code)

return err
}
`
Loading

0 comments on commit 677b4c2

Please sign in to comment.