diff --git a/.github/docs/openapi3filter.txt b/.github/docs/openapi3filter.txt index 094da23ea..ac738e75a 100644 --- a/.github/docs/openapi3filter.txt +++ b/.github/docs/openapi3filter.txt @@ -17,7 +17,6 @@ func ValidateRequest(ctx context.Context, input *RequestValidationInput) (err er func ValidateRequestBody(ctx context.Context, input *RequestValidationInput, ...) error func ValidateResponse(ctx context.Context, input *ResponseValidationInput) error func ValidateSecurityRequirements(ctx context.Context, input *RequestValidationInput, ...) error -func ZipFileBodyDecoder(body io.Reader, header http.Header, schema *openapi3.SchemaRef, ...) (interface{}, error) type AuthenticationFunc func(context.Context, *AuthenticationInput) error type AuthenticationInput struct{ ... } type BodyDecoder func(io.Reader, http.Header, *openapi3.SchemaRef, EncodingFn) (interface{}, error) diff --git a/openapi3filter/csv_file_upload_test.go b/openapi3filter/csv_file_upload_test.go new file mode 100644 index 000000000..89efb96d9 --- /dev/null +++ b/openapi3filter/csv_file_upload_test.go @@ -0,0 +1,127 @@ +package openapi3filter_test + +import ( + "bytes" + "context" + "io" + "mime/multipart" + "net/http" + "net/textproto" + "strings" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/getkin/kin-openapi/openapi3filter" + "github.com/getkin/kin-openapi/routers/gorillamux" +) + +func TestValidateCsvFileUpload(t *testing.T) { + const spec = ` +openapi: 3.0.0 +info: + title: 'Validator' + version: 0.0.1 +paths: + /test: + post: + requestBody: + required: true + content: + multipart/form-data: + schema: + type: object + required: + - file + properties: + file: + type: string + format: string + responses: + '200': + description: Created +` + + loader := openapi3.NewLoader() + doc, err := loader.LoadFromData([]byte(spec)) + require.NoError(t, err) + + err = doc.Validate(loader.Context) + require.NoError(t, err) + + router, err := gorillamux.NewRouter(doc) + require.NoError(t, err) + + tests := []struct { + csvData string + wantErr bool + }{ + { + `foo,bar`, + false, + }, + { + `"foo","bar"`, + false, + }, + { + `foo,bar +baz,qux`, + false, + }, + { + `foo,bar +baz,qux,quux`, + true, + }, + { + `"""`, + true, + }, + } + for _, tt := range tests { + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + + { // Add file data + h := make(textproto.MIMEHeader) + h.Set("Content-Disposition", `form-data; name="file"; filename="hello.csv"`) + h.Set("Content-Type", "text/csv") + + fw, err := writer.CreatePart(h) + require.NoError(t, err) + _, err = io.Copy(fw, strings.NewReader(tt.csvData)) + + require.NoError(t, err) + } + + writer.Close() + + req, err := http.NewRequest(http.MethodPost, "/test", bytes.NewReader(body.Bytes())) + require.NoError(t, err) + + req.Header.Set("Content-Type", writer.FormDataContentType()) + + route, pathParams, err := router.FindRoute(req) + require.NoError(t, err) + + if err = openapi3filter.ValidateRequestBody( + context.Background(), + &openapi3filter.RequestValidationInput{ + Request: req, + PathParams: pathParams, + Route: route, + }, + route.Operation.RequestBody.Value, + ); err != nil { + if !tt.wantErr { + t.Errorf("got %v", err) + } + continue + } + if tt.wantErr { + t.Errorf("want err") + } + } +} diff --git a/openapi3filter/req_resp_decoder.go b/openapi3filter/req_resp_decoder.go index 5853826bd..44abacb0f 100644 --- a/openapi3filter/req_resp_decoder.go +++ b/openapi3filter/req_resp_decoder.go @@ -3,6 +3,7 @@ package openapi3filter import ( "archive/zip" "bytes" + "encoding/csv" "encoding/json" "errors" "fmt" @@ -1013,8 +1014,9 @@ func init() { RegisterBodyDecoder("application/x-www-form-urlencoded", urlencodedBodyDecoder) RegisterBodyDecoder("application/x-yaml", yamlBodyDecoder) RegisterBodyDecoder("application/yaml", yamlBodyDecoder) - RegisterBodyDecoder("application/zip", ZipFileBodyDecoder) + RegisterBodyDecoder("application/zip", zipFileBodyDecoder) RegisterBodyDecoder("multipart/form-data", multipartBodyDecoder) + RegisterBodyDecoder("text/csv", csvBodyDecoder) RegisterBodyDecoder("text/plain", plainBodyDecoder) } @@ -1221,8 +1223,8 @@ func FileBodyDecoder(body io.Reader, header http.Header, schema *openapi3.Schema return string(data), nil } -// ZipFileBodyDecoder is a body decoder that decodes a zip file body to a string. -func ZipFileBodyDecoder(body io.Reader, header http.Header, schema *openapi3.SchemaRef, encFn EncodingFn) (interface{}, error) { +// zipFileBodyDecoder is a body decoder that decodes a zip file body to a string. +func zipFileBodyDecoder(body io.Reader, header http.Header, schema *openapi3.SchemaRef, encFn EncodingFn) (interface{}, error) { buff := bytes.NewBuffer([]byte{}) size, err := io.Copy(buff, body) if err != nil { @@ -1271,3 +1273,23 @@ func ZipFileBodyDecoder(body io.Reader, header http.Header, schema *openapi3.Sch return string(content), nil } + +// csvBodyDecoder is a body decoder that decodes a csv body to a string. +func csvBodyDecoder(body io.Reader, header http.Header, schema *openapi3.SchemaRef, encFn EncodingFn) (interface{}, error) { + r := csv.NewReader(body) + + var content string + for { + record, err := r.Read() + if err == io.EOF { + break + } + if err != nil { + return nil, err + } + + content += strings.Join(record, ",") + "\n" + } + + return content, nil +} diff --git a/openapi3filter/req_resp_decoder_test.go b/openapi3filter/req_resp_decoder_test.go index 709cdc929..1e71e0ac5 100644 --- a/openapi3filter/req_resp_decoder_test.go +++ b/openapi3filter/req_resp_decoder_test.go @@ -1345,7 +1345,7 @@ func TestRegisterAndUnregisterBodyDecoder(t *testing.T) { } return strings.Split(string(data), ","), nil } - contentType := "text/csv" + contentType := "application/csv" h := make(http.Header) h.Set(headerCT, contentType) @@ -1371,7 +1371,7 @@ func TestRegisterAndUnregisterBodyDecoder(t *testing.T) { _, _, err = decodeBody(body, h, schema, encFn) require.Equal(t, &ParseError{ Kind: KindUnsupportedFormat, - Reason: prefixUnsupportedCT + ` "text/csv"`, + Reason: prefixUnsupportedCT + ` "application/csv"`, }, err) }