From 2e4cbb269bfc6dfb39a79ffcaf671c7aa2d8951a Mon Sep 17 00:00:00 2001 From: FrancisLennon17 Date: Sat, 14 Nov 2020 14:12:13 +0000 Subject: [PATCH] Consumes request bodies (#263) Co-authored-by: Francis Lennon Co-authored-by: Pierre Fenoll --- .github/workflows/go.yml | 3 +- openapi2/openapi2.go | 1 + openapi2conv/issue187_test.go | 2 +- openapi2conv/openapi2_conv.go | 170 ++++++++++++++++++----------- openapi2conv/openapi2_conv_test.go | 24 +++- openapi3/content.go | 26 +++++ openapi3/request_body.go | 10 ++ 7 files changed, 169 insertions(+), 67 deletions(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 1d3c890e6..f8f3669fc 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -25,7 +25,8 @@ jobs: go-version: 1.x - run: go version - - run: go get ./... + - run: go mod download && go mod verify + - run: go test ./... - run: go vet ./... - run: go fmt ./... - run: git --no-pager diff && [[ $(git --no-pager diff --name-only | wc -l) = 0 ]] diff --git a/openapi2/openapi2.go b/openapi2/openapi2.go index 247775257..5e4877b96 100644 --- a/openapi2/openapi2.go +++ b/openapi2/openapi2.go @@ -21,6 +21,7 @@ type Swagger struct { Info openapi3.Info `json:"info"` ExternalDocs *openapi3.ExternalDocs `json:"externalDocs,omitempty"` Schemes []string `json:"schemes,omitempty"` + Consumes []string `json:"consumes,omitempty"` Host string `json:"host,omitempty"` BasePath string `json:"basePath,omitempty"` Paths map[string]*PathItem `json:"paths,omitempty"` diff --git a/openapi2conv/issue187_test.go b/openapi2conv/issue187_test.go index 979866c34..16a6d1a1c 100644 --- a/openapi2conv/issue187_test.go +++ b/openapi2conv/issue187_test.go @@ -155,7 +155,7 @@ paths: get: requestBody: content: - application/json: + '*/*': schema: $ref: '#/components/schemas/TestRef' responses: diff --git a/openapi2conv/openapi2_conv.go b/openapi2conv/openapi2_conv.go index fbf2b099c..42c4f6f75 100644 --- a/openapi2conv/openapi2_conv.go +++ b/openapi2conv/openapi2_conv.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "net/url" + "sort" "strings" "github.com/getkin/kin-openapi/openapi2" @@ -46,7 +47,7 @@ func ToV3Swagger(swagger *openapi2.Swagger) (*openapi3.Swagger, error) { result.Components.Parameters = make(map[string]*openapi3.ParameterRef) result.Components.RequestBodies = make(map[string]*openapi3.RequestBodyRef) for k, parameter := range parameters { - v3Parameter, v3RequestBody, v3SchemaMap, err := ToV3Parameter(&result.Components, parameter) + v3Parameter, v3RequestBody, v3SchemaMap, err := ToV3Parameter(&result.Components, parameter, swagger.Consumes) switch { case err != nil: return nil, err @@ -65,7 +66,7 @@ func ToV3Swagger(swagger *openapi2.Swagger) (*openapi3.Swagger, error) { if paths := swagger.Paths; len(paths) != 0 { resultPaths := make(map[string]*openapi3.PathItem, len(paths)) for path, pathItem := range paths { - r, err := ToV3PathItem(swagger, &result.Components, pathItem) + r, err := ToV3PathItem(swagger, &result.Components, pathItem, swagger.Consumes) if err != nil { return nil, err } @@ -111,20 +112,20 @@ func ToV3Swagger(swagger *openapi2.Swagger) (*openapi3.Swagger, error) { return result, nil } -func ToV3PathItem(swagger *openapi2.Swagger, components *openapi3.Components, pathItem *openapi2.PathItem) (*openapi3.PathItem, error) { +func ToV3PathItem(swagger *openapi2.Swagger, components *openapi3.Components, pathItem *openapi2.PathItem, consumes []string) (*openapi3.PathItem, error) { stripNonCustomExtensions(pathItem.Extensions) result := &openapi3.PathItem{ ExtensionProps: pathItem.ExtensionProps, } for method, operation := range pathItem.Operations() { - resultOperation, err := ToV3Operation(swagger, components, pathItem, operation) + resultOperation, err := ToV3Operation(swagger, components, pathItem, operation, consumes) if err != nil { return nil, err } result.SetOperation(method, resultOperation) } for _, parameter := range pathItem.Parameters { - v3Parameter, v3RequestBody, v3Schema, err := ToV3Parameter(components, parameter) + v3Parameter, v3RequestBody, v3Schema, err := ToV3Parameter(components, parameter, consumes) switch { case err != nil: return nil, err @@ -139,7 +140,7 @@ func ToV3PathItem(swagger *openapi2.Swagger, components *openapi3.Components, pa return result, nil } -func ToV3Operation(swagger *openapi2.Swagger, components *openapi3.Components, pathItem *openapi2.PathItem, operation *openapi2.Operation) (*openapi3.Operation, error) { +func ToV3Operation(swagger *openapi2.Swagger, components *openapi3.Components, pathItem *openapi2.PathItem, operation *openapi2.Operation, consumes []string) (*openapi3.Operation, error) { if operation == nil { return nil, nil } @@ -156,10 +157,14 @@ func ToV3Operation(swagger *openapi2.Swagger, components *openapi3.Components, p result.Security = &resultSecurity } + if len(operation.Consumes) > 0 { + consumes = operation.Consumes + } + var reqBodies []*openapi3.RequestBodyRef formDataSchemas := make(map[string]*openapi3.SchemaRef) for _, parameter := range operation.Parameters { - v3Parameter, v3RequestBody, v3SchemaMap, err := ToV3Parameter(components, parameter) + v3Parameter, v3RequestBody, v3SchemaMap, err := ToV3Parameter(components, parameter, consumes) switch { case err != nil: return nil, err @@ -174,7 +179,7 @@ func ToV3Operation(swagger *openapi2.Swagger, components *openapi3.Components, p } } var err error - if result.RequestBody, err = onlyOneReqBodyParam(reqBodies, formDataSchemas, components); err != nil { + if result.RequestBody, err = onlyOneReqBodyParam(reqBodies, formDataSchemas, components, consumes); err != nil { return nil, err } @@ -199,7 +204,7 @@ func getParameterNameFromOldRef(ref string) string { return pathSections[0] } -func ToV3Parameter(components *openapi3.Components, parameter *openapi2.Parameter) (*openapi3.ParameterRef, *openapi3.RequestBodyRef, map[string]*openapi3.SchemaRef, error) { +func ToV3Parameter(components *openapi3.Components, parameter *openapi2.Parameter, consumes []string) (*openapi3.ParameterRef, *openapi3.RequestBodyRef, map[string]*openapi3.SchemaRef, error) { if ref := parameter.Ref; ref != "" { if strings.HasPrefix(ref, "#/parameters/") { name := getParameterNameFromOldRef(ref) @@ -233,7 +238,7 @@ func ToV3Parameter(components *openapi3.Components, parameter *openapi2.Paramete if schemaRef := parameter.Schema; schemaRef != nil { // Assuming JSON - result.WithJSONSchemaRef(ToV3SchemaRef(schemaRef)) + result.WithSchemaRef(ToV3SchemaRef(schemaRef), consumes) } return nil, &openapi3.RequestBodyRef{Value: result}, nil, nil @@ -314,7 +319,7 @@ func ToV3Parameter(components *openapi3.Components, parameter *openapi2.Paramete } } -func formDataBody(bodies map[string]*openapi3.SchemaRef, reqs map[string]bool) *openapi3.RequestBodyRef { +func formDataBody(bodies map[string]*openapi3.SchemaRef, reqs map[string]bool, consumes []string) *openapi3.RequestBodyRef { if len(bodies) != len(reqs) { panic(`request bodies and them being required must match`) } @@ -333,7 +338,7 @@ func formDataBody(bodies map[string]*openapi3.SchemaRef, reqs map[string]bool) * Required: requireds, } return &openapi3.RequestBodyRef{ - Value: openapi3.NewRequestBody().WithFormDataSchema(schema), + Value: openapi3.NewRequestBody().WithSchema(schema, consumes), } } @@ -344,7 +349,7 @@ func getParameterNameFromNewRef(ref string) string { return pathSections[0] } -func onlyOneReqBodyParam(bodies []*openapi3.RequestBodyRef, formDataSchemas map[string]*openapi3.SchemaRef, components *openapi3.Components) (*openapi3.RequestBodyRef, error) { +func onlyOneReqBodyParam(bodies []*openapi3.RequestBodyRef, formDataSchemas map[string]*openapi3.SchemaRef, components *openapi3.Components, consumes []string) (*openapi3.RequestBodyRef, error) { if len(bodies) > 1 { return nil, errors.New("multiple body parameters cannot exist for the same operation") } @@ -386,7 +391,7 @@ func onlyOneReqBodyParam(bodies []*openapi3.RequestBodyRef, formDataSchemas map[ } } - return formDataBody(formDataParams, formDataReqs), nil + return formDataBody(formDataParams, formDataReqs, consumes), nil } return nil, nil @@ -601,25 +606,23 @@ func FromV3Swagger(swagger *openapi3.Swagger) (*openapi2.Swagger, error) { } } - for name, requestBody := range swagger.Components.RequestBodies { - parameters := FromV3RequestBodyFormData(requestBody) - for _, param := range parameters { - result.Parameters[param.Name] = param + for name, requestBodyRef := range swagger.Components.RequestBodies { + bodyOrRefParameters, formDataParameters, consumes, err := fromV3RequestBodies(name, requestBodyRef, &swagger.Components) + if err != nil { + return nil, err } - - if len(parameters) == 0 { - paramName := name - if requestBody.Value != nil { - if originalName, ok := requestBody.Value.Extensions["x-originalParamName"]; ok { - json.Unmarshal(originalName.(json.RawMessage), ¶mName) - } + if len(formDataParameters) != 0 { + for _, param := range formDataParameters { + result.Parameters[param.Name] = param } - - r, err := FromV3RequestBody(swagger, paramName, requestBody) - if err != nil { - return nil, err + } else if len(bodyOrRefParameters) != 0 { + for _, param := range bodyOrRefParameters { + result.Parameters[name] = param } - result.Parameters[name] = r + } + + if len(consumes) != 0 { + result.Consumes = consumesToArray(consumes) } } @@ -638,6 +641,52 @@ func FromV3Swagger(swagger *openapi3.Swagger) (*openapi2.Swagger, error) { return result, nil } +func consumesToArray(consumes map[string]struct{}) []string { + consumesArr := make([]string, 0, len(consumes)) + for key := range consumes { + consumesArr = append(consumesArr, key) + } + sort.Strings(consumesArr) + return consumesArr +} + +func fromV3RequestBodies(name string, requestBodyRef *openapi3.RequestBodyRef, components *openapi3.Components) ( + bodyOrRefParameters openapi2.Parameters, + formParameters openapi2.Parameters, + consumes map[string]struct{}, + err error, +) { + if ref := requestBodyRef.Ref; ref != "" { + bodyOrRefParameters = append(bodyOrRefParameters, &openapi2.Parameter{Ref: FromV3Ref(ref)}) + return + } + + //Only select one formData or request body for an individual requesstBody as swagger 2 does not support multiples + if requestBodyRef.Value != nil { + for contentType, mediaType := range requestBodyRef.Value.Content { + if consumes == nil { + consumes = make(map[string]struct{}) + } + consumes[contentType] = struct{}{} + if formParams := FromV3RequestBodyFormData(mediaType); len(formParams) != 0 { + formParameters = formParams + } else { + paramName := name + if originalName, ok := requestBodyRef.Value.Extensions["x-originalParamName"]; ok { + json.Unmarshal(originalName.(json.RawMessage), ¶mName) + } + + var r *openapi2.Parameter + if r, err = FromV3RequestBody(paramName, requestBodyRef, mediaType, components); err != nil { + return + } + bodyOrRefParameters = append(bodyOrRefParameters, r) + } + } + } + return +} + func FromV3Schemas(schemas map[string]*openapi3.SchemaRef, components *openapi3.Components) (map[string]*openapi3.SchemaRef, map[string]*openapi2.Parameter) { v2Defs := make(map[string]*openapi3.SchemaRef) v2Params := make(map[string]*openapi2.Parameter) @@ -711,8 +760,13 @@ func FromV3SchemaRef(schema *openapi3.SchemaRef, components *openapi3.Components if v := schema.Value.Items; v != nil { schema.Value.Items, _ = FromV3SchemaRef(v, components) } - for k, v := range schema.Value.Properties { - schema.Value.Properties[k], _ = FromV3SchemaRef(v, components) + keys := make([]string, 0, len(schema.Value.Properties)) + for k := range schema.Value.Properties { + keys = append(keys, k) + } + sort.Strings(keys) + for _, key := range keys { + schema.Value.Properties[key], _ = FromV3SchemaRef(schema.Value.Properties[key], components) } if v := schema.Value.AdditionalProperties; v != nil { schema.Value.AdditionalProperties, _ = FromV3SchemaRef(v, components) @@ -770,11 +824,7 @@ nameSearch: return "" } -func FromV3RequestBodyFormData(requestBodyRef *openapi3.RequestBodyRef) openapi2.Parameters { - mediaType := requestBodyRef.Value.GetMediaType("multipart/form-data") - if mediaType == nil { - return nil - } +func FromV3RequestBodyFormData(mediaType *openapi3.MediaType) openapi2.Parameters { parameters := openapi2.Parameters{} for propName, schemaRef := range mediaType.Schema.Value.Properties { if ref := schemaRef.Ref; ref != "" { @@ -848,27 +898,28 @@ func FromV3Operation(swagger *openapi3.Swagger, operation *openapi3.Operation) ( result.Parameters = append(result.Parameters, r) } if v := operation.RequestBody; v != nil { - parameters := FromV3RequestBodyFormData(operation.RequestBody) - if len(parameters) > 0 { - result.Parameters = append(result.Parameters, parameters...) - } else { - // Find parameter name that we can use for the body - name := findNameForRequestBody(operation) - if name == "" { - return nil, errors.New("could not find a name for request body") - } - r, err := FromV3RequestBody(swagger, name, v) - if err != nil { - return nil, err + // Find parameter name that we can use for the body + name := findNameForRequestBody(operation) + if name == "" { + return nil, errors.New("could not find a name for request body") + } + + bodyOrRefParameters, formDataParameters, consumes, err := fromV3RequestBodies(name, v, &swagger.Components) + if err != nil { + return nil, err + } + if len(formDataParameters) != 0 { + result.Parameters = append(result.Parameters, formDataParameters...) + } else if len(bodyOrRefParameters) != 0 { + for _, param := range bodyOrRefParameters { + result.Parameters = append(result.Parameters, param) + break // add a single request body } - result.Parameters = append(result.Parameters, r) + } - } - for _, param := range result.Parameters { - if param.Type == "file" { - result.Consumes = append(result.Consumes, "multipart/form-data") - break + if len(consumes) != 0 { + result.Consumes = consumesToArray(consumes) } } @@ -882,10 +933,7 @@ func FromV3Operation(swagger *openapi3.Swagger, operation *openapi3.Operation) ( return result, nil } -func FromV3RequestBody(swagger *openapi3.Swagger, name string, requestBodyRef *openapi3.RequestBodyRef) (*openapi2.Parameter, error) { - if ref := requestBodyRef.Ref; ref != "" { - return &openapi2.Parameter{Ref: FromV3Ref(ref)}, nil - } +func FromV3RequestBody(name string, requestBodyRef *openapi3.RequestBodyRef, mediaType *openapi3.MediaType, components *openapi3.Components) (*openapi2.Parameter, error) { requestBody := requestBodyRef.Value stripNonCustomExtensions(requestBody.Extensions) @@ -897,10 +945,8 @@ func FromV3RequestBody(swagger *openapi3.Swagger, name string, requestBodyRef *o ExtensionProps: requestBody.ExtensionProps, } - // Assuming JSON - mediaType := requestBody.GetMediaType("application/json") if mediaType != nil { - result.Schema, _ = FromV3SchemaRef(mediaType.Schema, &swagger.Components) + result.Schema, _ = FromV3SchemaRef(mediaType.Schema, components) } return result, nil } diff --git a/openapi2conv/openapi2_conv_test.go b/openapi2conv/openapi2_conv_test.go index c5379ae06..ad3a3e8b7 100644 --- a/openapi2conv/openapi2_conv_test.go +++ b/openapi2conv/openapi2_conv_test.go @@ -88,6 +88,10 @@ const exampleV2 = ` "version": "0.1", "x-info": "info extension" }, + "consumes": [ + "application/json", + "application/xml" + ], "parameters": { "banana": { "in": "path", @@ -224,15 +228,19 @@ const exampleV2 = ` } }, "patch": { + "consumes": [ + "application/json", + "application/xml" + ], "description": "example patch", "parameters": [ { "in": "body", - "name": "body", + "name": "patch_body", "schema": { "allOf": [{"$ref": "#/definitions/Item"}] }, - "x-originalParamName":"body", + "x-originalParamName":"patch_body", "x-requestBody": "requestbody extension 1" } ], @@ -345,6 +353,11 @@ const exampleV3 = ` "schema": { "type": "string" } + }, + "application/xml": { + "schema": { + "type": "string" + } } }, "required": true, @@ -539,9 +552,14 @@ const exampleV3 = ` "schema": { "allOf": [{"$ref": "#/components/schemas/Item"}] } + }, + "application/xml": { + "schema": { + "allOf": [{"$ref": "#/components/schemas/Item"}] + } } }, - "x-originalParamName":"body", + "x-originalParamName":"patch_body", "x-requestBody": "requestbody extension 1" }, "responses": { diff --git a/openapi3/content.go b/openapi3/content.go index f28912c66..abe376e3e 100644 --- a/openapi3/content.go +++ b/openapi3/content.go @@ -12,6 +12,32 @@ func NewContent() Content { return make(map[string]*MediaType, 4) } +func NewContentWithSchema(schema *Schema, consumes []string) Content { + if len(consumes) == 0 { + return Content{ + "*/*": NewMediaType().WithSchema(schema), + } + } + content := make(map[string]*MediaType, len(consumes)) + for _, mediaType := range consumes { + content[mediaType] = NewMediaType().WithSchema(schema) + } + return content +} + +func NewContentWithSchemaRef(schema *SchemaRef, consumes []string) Content { + if len(consumes) == 0 { + return Content{ + "*/*": NewMediaType().WithSchemaRef(schema), + } + } + content := make(map[string]*MediaType, len(consumes)) + for _, mediaType := range consumes { + content[mediaType] = NewMediaType().WithSchemaRef(schema) + } + return content +} + func NewContentWithJSONSchema(schema *Schema) Content { return Content{ "application/json": NewMediaType().WithSchema(schema), diff --git a/openapi3/request_body.go b/openapi3/request_body.go index 56a055ba2..6d649ca4b 100644 --- a/openapi3/request_body.go +++ b/openapi3/request_body.go @@ -33,6 +33,16 @@ func (requestBody *RequestBody) WithContent(content Content) *RequestBody { return requestBody } +func (requestBody *RequestBody) WithSchemaRef(value *SchemaRef, consumes []string) *RequestBody { + requestBody.Content = NewContentWithSchemaRef(value, consumes) + return requestBody +} + +func (requestBody *RequestBody) WithSchema(value *Schema, consumes []string) *RequestBody { + requestBody.Content = NewContentWithSchema(value, consumes) + return requestBody +} + func (requestBody *RequestBody) WithJSONSchemaRef(value *SchemaRef) *RequestBody { requestBody.Content = NewContentWithJSONSchemaRef(value) return requestBody