From 5a512899b4f002540b1584c551ca6626c771ccfe Mon Sep 17 00:00:00 2001 From: Adam Dobrawy Date: Mon, 31 Jan 2022 12:53:48 +0100 Subject: [PATCH] refactor: extract json_required view decorator (#18170) * refactor: extract json_required view decorator * chore: rename json_required to requires_json * refactor: add requires_form_data decorator and use it * fix: fix lint issue, raise InvalidPayloadFormatError for invalid payload --- superset/annotation_layers/annotations/api.py | 12 ++++--- superset/annotation_layers/api.py | 12 ++++--- superset/charts/api.py | 9 +++--- superset/dashboards/api.py | 9 +++--- superset/dashboards/filter_sets/api.py | 12 ++++--- superset/databases/api.py | 23 +++++++------- superset/datasets/api.py | 9 +++--- superset/explore/form_data/api.py | 8 ++--- superset/key_value/api.py | 8 ++--- superset/queries/saved_queries/api.py | 7 ++++- superset/reports/api.py | 7 ++--- superset/views/base_api.py | 31 ++++++++++++++++++- tests/integration_tests/base_api_tests.py | 15 ++++++++- 13 files changed, 106 insertions(+), 56 deletions(-) diff --git a/superset/annotation_layers/annotations/api.py b/superset/annotation_layers/annotations/api.py index 49172a2bc2eb3..291c074fa358a 100644 --- a/superset/annotation_layers/annotations/api.py +++ b/superset/annotation_layers/annotations/api.py @@ -54,7 +54,11 @@ from superset.annotation_layers.commands.exceptions import AnnotationLayerNotFoundError from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod from superset.models.annotations import Annotation -from superset.views.base_api import BaseSupersetModelRestApi, statsd_metrics +from superset.views.base_api import ( + BaseSupersetModelRestApi, + requires_json, + statsd_metrics, +) logger = logging.getLogger(__name__) @@ -254,6 +258,7 @@ def get( # pylint: disable=arguments-differ @safe @statsd_metrics @permission_name("post") + @requires_json def post(self, pk: int) -> Response: # pylint: disable=arguments-differ """Creates a new Annotation --- @@ -294,8 +299,6 @@ def post(self, pk: int) -> Response: # pylint: disable=arguments-differ 500: $ref: '#/components/responses/500' """ - if not request.is_json: - return self.response_400(message="Request is not JSON") try: item = self.add_model_schema.load(request.json) item["layer"] = pk @@ -323,6 +326,7 @@ def post(self, pk: int) -> Response: # pylint: disable=arguments-differ @safe @statsd_metrics @permission_name("put") + @requires_json def put( # pylint: disable=arguments-differ self, pk: int, annotation_id: int ) -> Response: @@ -370,8 +374,6 @@ def put( # pylint: disable=arguments-differ 500: $ref: '#/components/responses/500' """ - if not request.is_json: - return self.response_400(message="Request is not JSON") try: item = self.edit_model_schema.load(request.json) item["layer"] = pk diff --git a/superset/annotation_layers/api.py b/superset/annotation_layers/api.py index bba34e2fa8d79..db3979f663607 100644 --- a/superset/annotation_layers/api.py +++ b/superset/annotation_layers/api.py @@ -49,7 +49,11 @@ from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod from superset.extensions import event_logger from superset.models.annotations import AnnotationLayer -from superset.views.base_api import BaseSupersetModelRestApi, statsd_metrics +from superset.views.base_api import ( + BaseSupersetModelRestApi, + requires_json, + statsd_metrics, +) logger = logging.getLogger(__name__) @@ -171,6 +175,7 @@ def delete(self, pk: int) -> Response: action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.post", log_to_statsd=False, ) + @requires_json def post(self) -> Response: """Creates a new Annotation Layer --- @@ -205,8 +210,6 @@ def post(self) -> Response: 500: $ref: '#/components/responses/500' """ - if not request.is_json: - return self.response_400(message="Request is not JSON") try: item = self.add_model_schema.load(request.json) # This validates custom Schema with custom validations @@ -237,6 +240,7 @@ def post(self) -> Response: action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.put", log_to_statsd=False, ) + @requires_json def put(self, pk: int) -> Response: """Updates an Annotation Layer --- @@ -277,8 +281,6 @@ def put(self, pk: int) -> Response: 500: $ref: '#/components/responses/500' """ - if not request.is_json: - return self.response_400(message="Request is not JSON") try: item = self.edit_model_schema.load(request.json) item["layer"] = pk diff --git a/superset/charts/api.py b/superset/charts/api.py index f2e7741372fdd..6b45900edba2c 100644 --- a/superset/charts/api.py +++ b/superset/charts/api.py @@ -75,6 +75,8 @@ from superset.views.base_api import ( BaseSupersetModelRestApi, RelatedFieldFilter, + requires_form_data, + requires_json, statsd_metrics, ) from superset.views.filters import FilterRelatedOwners @@ -239,6 +241,7 @@ def ensure_thumbnails_enabled(self) -> Optional[Response]: action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.post", log_to_statsd=False, ) + @requires_json def post(self) -> Response: """Creates a new Chart --- @@ -273,8 +276,6 @@ def post(self) -> Response: 500: $ref: '#/components/responses/500' """ - if not request.is_json: - return self.response_400(message="Request is not JSON") try: item = self.add_model_schema.load(request.json) # This validates custom Schema with custom validations @@ -302,6 +303,7 @@ def post(self) -> Response: action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.put", log_to_statsd=False, ) + @requires_json def put(self, pk: int) -> Response: """Changes a Chart --- @@ -345,8 +347,6 @@ def put(self, pk: int) -> Response: 500: $ref: '#/components/responses/500' """ - if not request.is_json: - return self.response_400(message="Request is not JSON") try: item = self.edit_model_schema.load(request.json) # This validates custom Schema with custom validations @@ -820,6 +820,7 @@ def favorite_status(self, **kwargs: Any) -> Response: action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.import_", log_to_statsd=False, ) + @requires_form_data def import_(self) -> Response: """Import chart(s) with associated datasets and databases --- diff --git a/superset/dashboards/api.py b/superset/dashboards/api.py index de7f8cca69ff0..ea05ce9cacdd0 100644 --- a/superset/dashboards/api.py +++ b/superset/dashboards/api.py @@ -81,6 +81,8 @@ from superset.views.base_api import ( BaseSupersetModelRestApi, RelatedFieldFilter, + requires_form_data, + requires_json, statsd_metrics, ) from superset.views.filters import FilterRelatedOwners @@ -430,6 +432,7 @@ def get_charts(self, id_or_slug: str) -> Response: action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.post", log_to_statsd=False, ) + @requires_json def post(self) -> Response: """Creates a new Dashboard --- @@ -466,8 +469,6 @@ def post(self) -> Response: 500: $ref: '#/components/responses/500' """ - if not request.is_json: - return self.response_400(message="Request is not JSON") try: item = self.add_model_schema.load(request.json) # This validates custom Schema with custom validations @@ -495,6 +496,7 @@ def post(self) -> Response: action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.put", log_to_statsd=False, ) + @requires_json def put(self, pk: int) -> Response: """Changes a Dashboard --- @@ -540,8 +542,6 @@ def put(self, pk: int) -> Response: 500: $ref: '#/components/responses/500' """ - if not request.is_json: - return self.response_400(message="Request is not JSON") try: item = self.edit_model_schema.load(request.json) # This validates custom Schema with custom validations @@ -927,6 +927,7 @@ def favorite_status(self, **kwargs: Any) -> Response: action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.import_", log_to_statsd=False, ) + @requires_form_data def import_(self) -> Response: """Import dashboard(s) with associated charts/datasets/databases --- diff --git a/superset/dashboards/filter_sets/api.py b/superset/dashboards/filter_sets/api.py index d4efd40fa41c8..3dc2a28de260c 100644 --- a/superset/dashboards/filter_sets/api.py +++ b/superset/dashboards/filter_sets/api.py @@ -62,7 +62,11 @@ ) from superset.extensions import event_logger from superset.models.filter_set import FilterSet -from superset.views.base_api import BaseSupersetModelRestApi, statsd_metrics +from superset.views.base_api import ( + BaseSupersetModelRestApi, + requires_json, + statsd_metrics, +) logger = logging.getLogger(__name__) @@ -193,6 +197,7 @@ def get_list(self, dashboard_id: int, **kwargs: Any) -> Response: action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.post", log_to_statsd=False, ) + @requires_json def post(self, dashboard_id: int) -> Response: """ Creates a new Dashboard's Filter Set @@ -236,8 +241,6 @@ def post(self, dashboard_id: int) -> Response: 500: $ref: '#/components/responses/500' """ - if not request.is_json: - return self.response_400(message="Request is not JSON") try: item = self.add_model_schema.load(request.json) new_model = CreateFilterSetCommand(g.user, dashboard_id, item).run() @@ -261,6 +264,7 @@ def post(self, dashboard_id: int) -> Response: action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.put", log_to_statsd=False, ) + @requires_json def put(self, dashboard_id: int, pk: int) -> Response: """Changes a Dashboard's Filter set --- @@ -308,8 +312,6 @@ def put(self, dashboard_id: int, pk: int) -> Response: 500: $ref: '#/components/responses/500' """ - if not request.is_json: - return self.response_400(message="Request is not JSON") try: item = self.edit_model_schema.load(request.json) changed_model = UpdateFilterSetCommand(g.user, dashboard_id, pk, item).run() diff --git a/superset/databases/api.py b/superset/databases/api.py index b347503385a1b..79c4ff7460b41 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -68,12 +68,16 @@ from superset.databases.utils import get_table_metadata from superset.db_engine_specs import get_available_engine_specs from superset.errors import ErrorLevel, SupersetError, SupersetErrorType -from superset.exceptions import InvalidPayloadFormatError from superset.extensions import security_manager from superset.models.core import Database from superset.typing import FlaskResponse from superset.utils.core import error_msg_from_exception -from superset.views.base_api import BaseSupersetModelRestApi, statsd_metrics +from superset.views.base_api import ( + BaseSupersetModelRestApi, + requires_form_data, + requires_json, + statsd_metrics, +) logger = logging.getLogger(__name__) @@ -201,6 +205,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi): action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.post", log_to_statsd=False, ) + @requires_json def post(self) -> Response: """Creates a new Database --- @@ -237,9 +242,6 @@ def post(self) -> Response: 500: $ref: '#/components/responses/500' """ - - if not request.is_json: - return self.response_400(message="Request is not JSON") try: item = self.add_model_schema.load(request.json) # This validates custom Schema with custom validations @@ -277,6 +279,7 @@ def post(self) -> Response: action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.put", log_to_statsd=False, ) + @requires_json def put(self, pk: int) -> Response: """Changes a Database --- @@ -320,8 +323,6 @@ def put(self, pk: int) -> Response: 500: $ref: '#/components/responses/500' """ - if not request.is_json: - return self.response_400(message="Request is not JSON") try: item = self.edit_model_schema.load(request.json) # This validates custom Schema with custom validations @@ -593,6 +594,7 @@ def select_star( f".test_connection", log_to_statsd=False, ) + @requires_json def test_connection(self) -> FlaskResponse: """Tests a database connection --- @@ -623,8 +625,6 @@ def test_connection(self) -> FlaskResponse: 500: $ref: '#/components/responses/500' """ - if not request.is_json: - return self.response_400(message="Request is not JSON") try: item = DatabaseTestConnectionSchema().load(request.json) # This validates custom Schema with custom validations @@ -774,6 +774,7 @@ def export(self, **kwargs: Any) -> Response: action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.import_", log_to_statsd=False, ) + @requires_form_data def import_(self) -> Response: """Import database(s) with associated datasets --- @@ -985,6 +986,7 @@ def available(self) -> Response: f".validate_parameters", log_to_statsd=False, ) + @requires_json def validate_parameters(self) -> FlaskResponse: """validates database connection parameters --- @@ -1015,9 +1017,6 @@ def validate_parameters(self) -> FlaskResponse: 500: $ref: '#/components/responses/500' """ - if not request.is_json: - raise InvalidPayloadFormatError("Request is not JSON") - try: payload = DatabaseValidateParametersSchema().load(request.json) except ValidationError as ex: diff --git a/superset/datasets/api.py b/superset/datasets/api.py index 8a9d9051a8aa7..ce8ce55cc25b4 100644 --- a/superset/datasets/api.py +++ b/superset/datasets/api.py @@ -65,6 +65,8 @@ from superset.views.base_api import ( BaseSupersetModelRestApi, RelatedFieldFilter, + requires_form_data, + requires_json, statsd_metrics, ) from superset.views.filters import FilterRelatedOwners @@ -206,6 +208,7 @@ class DatasetRestApi(BaseSupersetModelRestApi): action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.post", log_to_statsd=False, ) + @requires_json def post(self) -> Response: """Creates a new Dataset --- @@ -240,8 +243,6 @@ def post(self) -> Response: 500: $ref: '#/components/responses/500' """ - if not request.is_json: - return self.response_400(message="Request is not JSON") try: item = self.add_model_schema.load(request.json) # This validates custom Schema with custom validations @@ -270,6 +271,7 @@ def post(self) -> Response: action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.put", log_to_statsd=False, ) + @requires_json def put(self, pk: int) -> Response: """Changes a Dataset --- @@ -322,8 +324,6 @@ def put(self, pk: int) -> Response: if "override_columns" in request.args else False ) - if not request.is_json: - return self.response_400(message="Request is not JSON") try: item = self.edit_model_schema.load(request.json) # This validates custom Schema with custom validations @@ -680,6 +680,7 @@ def bulk_delete(self, **kwargs: Any) -> Response: action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.import_", log_to_statsd=False, ) + @requires_form_data def import_(self) -> Response: """Import dataset(s) with associated databases --- diff --git a/superset/explore/form_data/api.py b/superset/explore/form_data/api.py index 2ed7096768932..fe73ed0d3af92 100644 --- a/superset/explore/form_data/api.py +++ b/superset/explore/form_data/api.py @@ -30,7 +30,6 @@ DatasetAccessDeniedError, DatasetNotFoundError, ) -from superset.exceptions import InvalidPayloadFormatError from superset.explore.form_data.commands.create import CreateFormDataCommand from superset.explore.form_data.commands.delete import DeleteFormDataCommand from superset.explore.form_data.commands.get import GetFormDataCommand @@ -39,6 +38,7 @@ from superset.explore.form_data.schemas import FormDataPostSchema, FormDataPutSchema from superset.extensions import event_logger from superset.key_value.commands.exceptions import KeyValueAccessDeniedError +from superset.views.base_api import requires_json logger = logging.getLogger(__name__) @@ -66,6 +66,7 @@ class ExploreFormDataRestApi(BaseApi, ABC): action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.post", log_to_statsd=False, ) + @requires_json def post(self) -> Response: """Stores a new form_data. --- @@ -98,8 +99,6 @@ def post(self) -> Response: 500: $ref: '#/components/responses/500' """ - if not request.is_json: - raise InvalidPayloadFormatError("Request is not JSON") try: item = self.add_model_schema.load(request.json) args = CommandParameters( @@ -128,6 +127,7 @@ def post(self) -> Response: action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.put", log_to_statsd=False, ) + @requires_json def put(self, key: str) -> Response: """Updates an existing form_data. --- @@ -167,8 +167,6 @@ def put(self, key: str) -> Response: 500: $ref: '#/components/responses/500' """ - if not request.is_json: - raise InvalidPayloadFormatError("Request is not JSON") try: item = self.edit_model_schema.load(request.json) args = CommandParameters( diff --git a/superset/key_value/api.py b/superset/key_value/api.py index 85aa6e3746421..8f6a0d2d06364 100644 --- a/superset/key_value/api.py +++ b/superset/key_value/api.py @@ -37,10 +37,10 @@ DatasetAccessDeniedError, DatasetNotFoundError, ) -from superset.exceptions import InvalidPayloadFormatError from superset.key_value.commands.exceptions import KeyValueAccessDeniedError from superset.key_value.commands.parameters import CommandParameters from superset.key_value.schemas import KeyValuePostSchema, KeyValuePutSchema +from superset.views.base_api import requires_json logger = logging.getLogger(__name__) @@ -69,9 +69,8 @@ def add_apispec_components(self, api_spec: APISpec) -> None: pass super().add_apispec_components(api_spec) + @requires_json def post(self, pk: int) -> Response: - if not request.is_json: - raise InvalidPayloadFormatError("Request is not JSON") try: item = self.add_model_schema.load(request.json) args = CommandParameters( @@ -94,9 +93,8 @@ def post(self, pk: int) -> Response: except (ChartNotFoundError, DashboardNotFoundError, DatasetNotFoundError) as ex: return self.response(404, message=str(ex)) + @requires_json def put(self, pk: int, key: str) -> Response: - if not request.is_json: - raise InvalidPayloadFormatError("Request is not JSON") try: item = self.edit_model_schema.load(request.json) args = CommandParameters( diff --git a/superset/queries/saved_queries/api.py b/superset/queries/saved_queries/api.py index eb484cc94ccd9..a4d74cc5c0872 100644 --- a/superset/queries/saved_queries/api.py +++ b/superset/queries/saved_queries/api.py @@ -53,7 +53,11 @@ get_export_ids_schema, openapi_spec_methods_override, ) -from superset.views.base_api import BaseSupersetModelRestApi, statsd_metrics +from superset.views.base_api import ( + BaseSupersetModelRestApi, + requires_form_data, + statsd_metrics, +) logger = logging.getLogger(__name__) @@ -272,6 +276,7 @@ def export(self, **kwargs: Any) -> Response: action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.import_", log_to_statsd=False, ) + @requires_form_data def import_(self) -> Response: """Import Saved Queries with associated databases --- diff --git a/superset/reports/api.py b/superset/reports/api.py index 7d8d548eb1383..33559312e02e0 100644 --- a/superset/reports/api.py +++ b/superset/reports/api.py @@ -53,6 +53,7 @@ from superset.views.base_api import ( BaseSupersetModelRestApi, RelatedFieldFilter, + requires_json, statsd_metrics, ) from superset.views.filters import FilterRelatedOwners @@ -275,6 +276,7 @@ def delete(self, pk: int) -> Response: @protect() @statsd_metrics @permission_name("post") + @requires_json def post(self) -> Response: """Creates a new Report Schedule --- @@ -309,8 +311,6 @@ def post(self) -> Response: 500: $ref: '#/components/responses/500' """ - if not request.is_json: - return self.response_400(message="Request is not JSON") try: item = self.add_model_schema.load(request.json) # This validates custom Schema with custom validations @@ -337,6 +337,7 @@ def post(self) -> Response: @safe @statsd_metrics @permission_name("put") + @requires_json def put(self, pk: int) -> Response: """Updates an Report Schedule --- @@ -379,8 +380,6 @@ def put(self, pk: int) -> Response: 500: $ref: '#/components/responses/500' """ - if not request.is_json: - return self.response_400(message="Request is not JSON") try: item = self.edit_model_schema.load(request.json) # This validates custom Schema with custom validations diff --git a/superset/views/base_api.py b/superset/views/base_api.py index 8193a283b7dbb..87e99e7c74a7b 100644 --- a/superset/views/base_api.py +++ b/superset/views/base_api.py @@ -18,7 +18,7 @@ import logging from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type, Union -from flask import Blueprint, g, Response +from flask import Blueprint, g, request, Response from flask_appbuilder import AppBuilder, Model, ModelRestApi from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.models.filters import BaseFilter, Filters @@ -29,6 +29,7 @@ from sqlalchemy import and_, distinct, func from sqlalchemy.orm.query import Query +from superset.exceptions import InvalidPayloadFormatError from superset.extensions import db, event_logger, security_manager from superset.models.core import FavStar from superset.models.dashboard import Dashboard @@ -70,6 +71,34 @@ class DistincResponseSchema(Schema): result = fields.List(fields.Nested(DistinctResultResponseSchema)) +def requires_json(f: Callable[..., Any]) -> Callable[..., Any]: + """ + Require JSON-like formatted request to the REST API + """ + + def wraps(self: "BaseSupersetModelRestApi", *args: Any, **kwargs: Any) -> Response: + if not request.is_json: + raise InvalidPayloadFormatError(message="Request is not JSON") + return f(self, *args, **kwargs) + + return functools.update_wrapper(wraps, f) + + +def requires_form_data(f: Callable[..., Any]) -> Callable[..., Any]: + """ + Require 'multipart/form-data' as request MIME type + """ + + def wraps(self: "BaseSupersetModelRestApi", *args: Any, **kwargs: Any) -> Response: + if not request.mimetype == "multipart/form-data": + raise InvalidPayloadFormatError( + message="Request MIME type is not 'multipart/form-data'" + ) + return f(self, *args, **kwargs) + + return functools.update_wrapper(wraps, f) + + def statsd_metrics(f: Callable[..., Any]) -> Callable[..., Any]: """ Handle sending all statsd metrics from the REST API diff --git a/tests/integration_tests/base_api_tests.py b/tests/integration_tests/base_api_tests.py index 09a754e392382..8dbbb6862e3bf 100644 --- a/tests/integration_tests/base_api_tests.py +++ b/tests/integration_tests/base_api_tests.py @@ -29,7 +29,7 @@ from superset import db, security_manager from superset.extensions import appbuilder from superset.models.dashboard import Dashboard -from superset.views.base_api import BaseSupersetModelRestApi +from superset.views.base_api import BaseSupersetModelRestApi, requires_json from .base_tests import SupersetTestCase @@ -154,6 +154,19 @@ def test_default_missing_declaration_post(self): } self.assertEqual(response, expected_response) + def test_refuse_invalid_format_request(self): + """ + API: Test invalid format of request + + We want to make sure that non-JSON request are refused + """ + self.login(username="admin") + uri = "api/v1/report/" # endpoint decorated with @requires_json + rv = self.client.post( + uri, data="a: value\nb: 1\n", content_type="application/yaml" + ) + self.assertEqual(rv.status_code, 400) + @pytest.mark.usefixtures("load_world_bank_dashboard_with_slices") def test_default_missing_declaration_put(self): """