Skip to content

Commit

Permalink
Consumes request bodies (#263)
Browse files Browse the repository at this point in the history
Co-authored-by: Francis Lennon <francis.lennon@whitehatsec.com>
Co-authored-by: Pierre Fenoll <pierrefenoll@gmail.com>
  • Loading branch information
3 people committed Nov 14, 2020
1 parent 738fe87 commit 2e4cbb2
Show file tree
Hide file tree
Showing 7 changed files with 169 additions and 67 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ jobs:
go-version: 1.x
- run: go version

- run: go get ./...
- run: go mod download && go mod verify
- run: go test ./...
- run: go vet ./...
- run: go fmt ./...
- run: git --no-pager diff && [[ $(git --no-pager diff --name-only | wc -l) = 0 ]]
Expand Down
1 change: 1 addition & 0 deletions openapi2/openapi2.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type Swagger struct {
Info openapi3.Info `json:"info"`
ExternalDocs *openapi3.ExternalDocs `json:"externalDocs,omitempty"`
Schemes []string `json:"schemes,omitempty"`
Consumes []string `json:"consumes,omitempty"`
Host string `json:"host,omitempty"`
BasePath string `json:"basePath,omitempty"`
Paths map[string]*PathItem `json:"paths,omitempty"`
Expand Down
2 changes: 1 addition & 1 deletion openapi2conv/issue187_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ paths:
get:
requestBody:
content:
application/json:
'*/*':
schema:
$ref: '#/components/schemas/TestRef'
responses:
Expand Down
170 changes: 108 additions & 62 deletions openapi2conv/openapi2_conv.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"net/url"
"sort"
"strings"

"github.com/getkin/kin-openapi/openapi2"
Expand Down Expand Up @@ -46,7 +47,7 @@ func ToV3Swagger(swagger *openapi2.Swagger) (*openapi3.Swagger, error) {
result.Components.Parameters = make(map[string]*openapi3.ParameterRef)
result.Components.RequestBodies = make(map[string]*openapi3.RequestBodyRef)
for k, parameter := range parameters {
v3Parameter, v3RequestBody, v3SchemaMap, err := ToV3Parameter(&result.Components, parameter)
v3Parameter, v3RequestBody, v3SchemaMap, err := ToV3Parameter(&result.Components, parameter, swagger.Consumes)
switch {
case err != nil:
return nil, err
Expand All @@ -65,7 +66,7 @@ func ToV3Swagger(swagger *openapi2.Swagger) (*openapi3.Swagger, error) {
if paths := swagger.Paths; len(paths) != 0 {
resultPaths := make(map[string]*openapi3.PathItem, len(paths))
for path, pathItem := range paths {
r, err := ToV3PathItem(swagger, &result.Components, pathItem)
r, err := ToV3PathItem(swagger, &result.Components, pathItem, swagger.Consumes)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -111,20 +112,20 @@ func ToV3Swagger(swagger *openapi2.Swagger) (*openapi3.Swagger, error) {
return result, nil
}

func ToV3PathItem(swagger *openapi2.Swagger, components *openapi3.Components, pathItem *openapi2.PathItem) (*openapi3.PathItem, error) {
func ToV3PathItem(swagger *openapi2.Swagger, components *openapi3.Components, pathItem *openapi2.PathItem, consumes []string) (*openapi3.PathItem, error) {
stripNonCustomExtensions(pathItem.Extensions)
result := &openapi3.PathItem{
ExtensionProps: pathItem.ExtensionProps,
}
for method, operation := range pathItem.Operations() {
resultOperation, err := ToV3Operation(swagger, components, pathItem, operation)
resultOperation, err := ToV3Operation(swagger, components, pathItem, operation, consumes)
if err != nil {
return nil, err
}
result.SetOperation(method, resultOperation)
}
for _, parameter := range pathItem.Parameters {
v3Parameter, v3RequestBody, v3Schema, err := ToV3Parameter(components, parameter)
v3Parameter, v3RequestBody, v3Schema, err := ToV3Parameter(components, parameter, consumes)
switch {
case err != nil:
return nil, err
Expand All @@ -139,7 +140,7 @@ func ToV3PathItem(swagger *openapi2.Swagger, components *openapi3.Components, pa
return result, nil
}

func ToV3Operation(swagger *openapi2.Swagger, components *openapi3.Components, pathItem *openapi2.PathItem, operation *openapi2.Operation) (*openapi3.Operation, error) {
func ToV3Operation(swagger *openapi2.Swagger, components *openapi3.Components, pathItem *openapi2.PathItem, operation *openapi2.Operation, consumes []string) (*openapi3.Operation, error) {
if operation == nil {
return nil, nil
}
Expand All @@ -156,10 +157,14 @@ func ToV3Operation(swagger *openapi2.Swagger, components *openapi3.Components, p
result.Security = &resultSecurity
}

if len(operation.Consumes) > 0 {
consumes = operation.Consumes
}

var reqBodies []*openapi3.RequestBodyRef
formDataSchemas := make(map[string]*openapi3.SchemaRef)
for _, parameter := range operation.Parameters {
v3Parameter, v3RequestBody, v3SchemaMap, err := ToV3Parameter(components, parameter)
v3Parameter, v3RequestBody, v3SchemaMap, err := ToV3Parameter(components, parameter, consumes)
switch {
case err != nil:
return nil, err
Expand All @@ -174,7 +179,7 @@ func ToV3Operation(swagger *openapi2.Swagger, components *openapi3.Components, p
}
}
var err error
if result.RequestBody, err = onlyOneReqBodyParam(reqBodies, formDataSchemas, components); err != nil {
if result.RequestBody, err = onlyOneReqBodyParam(reqBodies, formDataSchemas, components, consumes); err != nil {
return nil, err
}

Expand All @@ -199,7 +204,7 @@ func getParameterNameFromOldRef(ref string) string {
return pathSections[0]
}

func ToV3Parameter(components *openapi3.Components, parameter *openapi2.Parameter) (*openapi3.ParameterRef, *openapi3.RequestBodyRef, map[string]*openapi3.SchemaRef, error) {
func ToV3Parameter(components *openapi3.Components, parameter *openapi2.Parameter, consumes []string) (*openapi3.ParameterRef, *openapi3.RequestBodyRef, map[string]*openapi3.SchemaRef, error) {
if ref := parameter.Ref; ref != "" {
if strings.HasPrefix(ref, "#/parameters/") {
name := getParameterNameFromOldRef(ref)
Expand Down Expand Up @@ -233,7 +238,7 @@ func ToV3Parameter(components *openapi3.Components, parameter *openapi2.Paramete

if schemaRef := parameter.Schema; schemaRef != nil {
// Assuming JSON
result.WithJSONSchemaRef(ToV3SchemaRef(schemaRef))
result.WithSchemaRef(ToV3SchemaRef(schemaRef), consumes)
}
return nil, &openapi3.RequestBodyRef{Value: result}, nil, nil

Expand Down Expand Up @@ -314,7 +319,7 @@ func ToV3Parameter(components *openapi3.Components, parameter *openapi2.Paramete
}
}

func formDataBody(bodies map[string]*openapi3.SchemaRef, reqs map[string]bool) *openapi3.RequestBodyRef {
func formDataBody(bodies map[string]*openapi3.SchemaRef, reqs map[string]bool, consumes []string) *openapi3.RequestBodyRef {
if len(bodies) != len(reqs) {
panic(`request bodies and them being required must match`)
}
Expand All @@ -333,7 +338,7 @@ func formDataBody(bodies map[string]*openapi3.SchemaRef, reqs map[string]bool) *
Required: requireds,
}
return &openapi3.RequestBodyRef{
Value: openapi3.NewRequestBody().WithFormDataSchema(schema),
Value: openapi3.NewRequestBody().WithSchema(schema, consumes),
}
}

Expand All @@ -344,7 +349,7 @@ func getParameterNameFromNewRef(ref string) string {
return pathSections[0]
}

func onlyOneReqBodyParam(bodies []*openapi3.RequestBodyRef, formDataSchemas map[string]*openapi3.SchemaRef, components *openapi3.Components) (*openapi3.RequestBodyRef, error) {
func onlyOneReqBodyParam(bodies []*openapi3.RequestBodyRef, formDataSchemas map[string]*openapi3.SchemaRef, components *openapi3.Components, consumes []string) (*openapi3.RequestBodyRef, error) {
if len(bodies) > 1 {
return nil, errors.New("multiple body parameters cannot exist for the same operation")
}
Expand Down Expand Up @@ -386,7 +391,7 @@ func onlyOneReqBodyParam(bodies []*openapi3.RequestBodyRef, formDataSchemas map[
}
}

return formDataBody(formDataParams, formDataReqs), nil
return formDataBody(formDataParams, formDataReqs, consumes), nil
}

return nil, nil
Expand Down Expand Up @@ -601,25 +606,23 @@ func FromV3Swagger(swagger *openapi3.Swagger) (*openapi2.Swagger, error) {
}
}

for name, requestBody := range swagger.Components.RequestBodies {
parameters := FromV3RequestBodyFormData(requestBody)
for _, param := range parameters {
result.Parameters[param.Name] = param
for name, requestBodyRef := range swagger.Components.RequestBodies {
bodyOrRefParameters, formDataParameters, consumes, err := fromV3RequestBodies(name, requestBodyRef, &swagger.Components)
if err != nil {
return nil, err
}

if len(parameters) == 0 {
paramName := name
if requestBody.Value != nil {
if originalName, ok := requestBody.Value.Extensions["x-originalParamName"]; ok {
json.Unmarshal(originalName.(json.RawMessage), &paramName)
}
if len(formDataParameters) != 0 {
for _, param := range formDataParameters {
result.Parameters[param.Name] = param
}

r, err := FromV3RequestBody(swagger, paramName, requestBody)
if err != nil {
return nil, err
} else if len(bodyOrRefParameters) != 0 {
for _, param := range bodyOrRefParameters {
result.Parameters[name] = param
}
result.Parameters[name] = r
}

if len(consumes) != 0 {
result.Consumes = consumesToArray(consumes)
}
}

Expand All @@ -638,6 +641,52 @@ func FromV3Swagger(swagger *openapi3.Swagger) (*openapi2.Swagger, error) {
return result, nil
}

func consumesToArray(consumes map[string]struct{}) []string {
consumesArr := make([]string, 0, len(consumes))
for key := range consumes {
consumesArr = append(consumesArr, key)
}
sort.Strings(consumesArr)
return consumesArr
}

func fromV3RequestBodies(name string, requestBodyRef *openapi3.RequestBodyRef, components *openapi3.Components) (
bodyOrRefParameters openapi2.Parameters,
formParameters openapi2.Parameters,
consumes map[string]struct{},
err error,
) {
if ref := requestBodyRef.Ref; ref != "" {
bodyOrRefParameters = append(bodyOrRefParameters, &openapi2.Parameter{Ref: FromV3Ref(ref)})
return
}

//Only select one formData or request body for an individual requesstBody as swagger 2 does not support multiples
if requestBodyRef.Value != nil {
for contentType, mediaType := range requestBodyRef.Value.Content {
if consumes == nil {
consumes = make(map[string]struct{})
}
consumes[contentType] = struct{}{}
if formParams := FromV3RequestBodyFormData(mediaType); len(formParams) != 0 {
formParameters = formParams
} else {
paramName := name
if originalName, ok := requestBodyRef.Value.Extensions["x-originalParamName"]; ok {
json.Unmarshal(originalName.(json.RawMessage), &paramName)
}

var r *openapi2.Parameter
if r, err = FromV3RequestBody(paramName, requestBodyRef, mediaType, components); err != nil {
return
}
bodyOrRefParameters = append(bodyOrRefParameters, r)
}
}
}
return
}

func FromV3Schemas(schemas map[string]*openapi3.SchemaRef, components *openapi3.Components) (map[string]*openapi3.SchemaRef, map[string]*openapi2.Parameter) {
v2Defs := make(map[string]*openapi3.SchemaRef)
v2Params := make(map[string]*openapi2.Parameter)
Expand Down Expand Up @@ -711,8 +760,13 @@ func FromV3SchemaRef(schema *openapi3.SchemaRef, components *openapi3.Components
if v := schema.Value.Items; v != nil {
schema.Value.Items, _ = FromV3SchemaRef(v, components)
}
for k, v := range schema.Value.Properties {
schema.Value.Properties[k], _ = FromV3SchemaRef(v, components)
keys := make([]string, 0, len(schema.Value.Properties))
for k := range schema.Value.Properties {
keys = append(keys, k)
}
sort.Strings(keys)
for _, key := range keys {
schema.Value.Properties[key], _ = FromV3SchemaRef(schema.Value.Properties[key], components)
}
if v := schema.Value.AdditionalProperties; v != nil {
schema.Value.AdditionalProperties, _ = FromV3SchemaRef(v, components)
Expand Down Expand Up @@ -770,11 +824,7 @@ nameSearch:
return ""
}

func FromV3RequestBodyFormData(requestBodyRef *openapi3.RequestBodyRef) openapi2.Parameters {
mediaType := requestBodyRef.Value.GetMediaType("multipart/form-data")
if mediaType == nil {
return nil
}
func FromV3RequestBodyFormData(mediaType *openapi3.MediaType) openapi2.Parameters {
parameters := openapi2.Parameters{}
for propName, schemaRef := range mediaType.Schema.Value.Properties {
if ref := schemaRef.Ref; ref != "" {
Expand Down Expand Up @@ -848,27 +898,28 @@ func FromV3Operation(swagger *openapi3.Swagger, operation *openapi3.Operation) (
result.Parameters = append(result.Parameters, r)
}
if v := operation.RequestBody; v != nil {
parameters := FromV3RequestBodyFormData(operation.RequestBody)
if len(parameters) > 0 {
result.Parameters = append(result.Parameters, parameters...)
} else {
// Find parameter name that we can use for the body
name := findNameForRequestBody(operation)
if name == "" {
return nil, errors.New("could not find a name for request body")
}
r, err := FromV3RequestBody(swagger, name, v)
if err != nil {
return nil, err
// Find parameter name that we can use for the body
name := findNameForRequestBody(operation)
if name == "" {
return nil, errors.New("could not find a name for request body")
}

bodyOrRefParameters, formDataParameters, consumes, err := fromV3RequestBodies(name, v, &swagger.Components)
if err != nil {
return nil, err
}
if len(formDataParameters) != 0 {
result.Parameters = append(result.Parameters, formDataParameters...)
} else if len(bodyOrRefParameters) != 0 {
for _, param := range bodyOrRefParameters {
result.Parameters = append(result.Parameters, param)
break // add a single request body
}
result.Parameters = append(result.Parameters, r)

}
}

for _, param := range result.Parameters {
if param.Type == "file" {
result.Consumes = append(result.Consumes, "multipart/form-data")
break
if len(consumes) != 0 {
result.Consumes = consumesToArray(consumes)
}
}

Expand All @@ -882,10 +933,7 @@ func FromV3Operation(swagger *openapi3.Swagger, operation *openapi3.Operation) (
return result, nil
}

func FromV3RequestBody(swagger *openapi3.Swagger, name string, requestBodyRef *openapi3.RequestBodyRef) (*openapi2.Parameter, error) {
if ref := requestBodyRef.Ref; ref != "" {
return &openapi2.Parameter{Ref: FromV3Ref(ref)}, nil
}
func FromV3RequestBody(name string, requestBodyRef *openapi3.RequestBodyRef, mediaType *openapi3.MediaType, components *openapi3.Components) (*openapi2.Parameter, error) {
requestBody := requestBodyRef.Value

stripNonCustomExtensions(requestBody.Extensions)
Expand All @@ -897,10 +945,8 @@ func FromV3RequestBody(swagger *openapi3.Swagger, name string, requestBodyRef *o
ExtensionProps: requestBody.ExtensionProps,
}

// Assuming JSON
mediaType := requestBody.GetMediaType("application/json")
if mediaType != nil {
result.Schema, _ = FromV3SchemaRef(mediaType.Schema, &swagger.Components)
result.Schema, _ = FromV3SchemaRef(mediaType.Schema, components)
}
return result, nil
}
Expand Down
Loading

0 comments on commit 2e4cbb2

Please sign in to comment.