From 01125f81a10998a3e8470cf2a3e7aa817d2efc6a Mon Sep 17 00:00:00 2001 From: Ofeknielsen Date: Wed, 25 Aug 2021 12:41:37 +0300 Subject: [PATCH 1/4] refactor sql_json view endpoint --- superset/utils/sqllab_execution_context.py | 115 +++++++++++++++++++++ superset/views/core.py | 90 +++++++--------- 2 files changed, 151 insertions(+), 54 deletions(-) create mode 100644 superset/utils/sqllab_execution_context.py diff --git a/superset/utils/sqllab_execution_context.py b/superset/utils/sqllab_execution_context.py new file mode 100644 index 0000000000000..325848d7cb779 --- /dev/null +++ b/superset/utils/sqllab_execution_context.py @@ -0,0 +1,115 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import json +import logging +from dataclasses import dataclass +from typing import Any, cast, Dict, Optional, Type, TYPE_CHECKING + +from flask import g + +from superset import app, is_feature_enabled +from superset.sql_parse import CtasMethod +from superset.utils import core as utils +from superset.utils.dates import now_as_float + +if TYPE_CHECKING: + from superset.models.sql_lab import Query + +QueryStatus = utils.QueryStatus +logger = logging.getLogger(__name__) + +SqlResults = Dict[str, Any] + + +@dataclass +class SqlJsonExecutionContext: + database_id: int + schema: str + sql: str + template_params: Dict[str, Any] + async_flag: bool + limit: int + status: str + select_as_cta: bool + ctas_method: CtasMethod + tmp_table_name: str + client_id: str + client_id_or_short_id: str + sql_editor_id: str + tab_name: str + user_id: Optional[int] + expand_data: bool + + def __init__(self, query_params: Dict[str, Any]): + self._init_from_query_params(query_params) + self.user_id = self._get_user_id() + self.client_id_or_short_id = cast(str, self.client_id or utils.shortid()[:10]) + + def _init_from_query_params(self, query_params: Dict[str, Any]) -> None: + self.database_id = cast(int, query_params.get("database_id")) + self.schema = cast(str, query_params.get("schema")) + self.sql = cast(str, query_params.get("sql")) + self.template_params = self._get_template_params(query_params) + self.async_flag = cast(bool, query_params.get("runAsync")) + self.limit = self._get_limit_param(query_params) + self.status = cast(str, query_params.get("status")) + self.select_as_cta = cast(bool, query_params.get("select_as_cta")) + self.ctas_method = cast( + CtasMethod, query_params.get("ctas_method", CtasMethod.TABLE) + ) + self.tmp_table_name = cast(str, query_params.get("tmp_table_name")) + self.client_id = cast(str, query_params.get("client_id")) + self.sql_editor_id = cast(str, query_params.get("sql_editor_id")) + self.tab_name = cast(str, query_params.get("tab")) + self.expand_data: bool = cast( + bool, + is_feature_enabled("PRESTO_EXPAND_DATA") + and query_params.get("expand_data"), + ) + + @staticmethod + def _get_template_params(query_params: Dict[str, Any]) -> Dict[str, Any]: + try: + template_params = json.loads(query_params.get("templateParams") or "{}") + except json.JSONDecodeError: + logger.warning( + "Invalid template parameter %s" " specified. Defaulting to empty dict", + str(query_params.get("templateParams")), + ) + template_params = {} + return template_params + + @staticmethod + def _get_limit_param(query_params: Dict[str, Any]) -> int: + limit: int = query_params.get("queryLimit") or app.config["SQL_MAX_ROW"] + if limit < 0: + logger.warning( + "Invalid limit of %i specified. Defaulting to max limit.", limit + ) + limit = 0 + return limit + + def _get_user_id(self) -> Optional[int]: + try: + return g.user.get_id() if g.user else None + except RuntimeError: + return None + + def is_should_run_asynchronous(self) -> bool: + return self.async_flag diff --git a/superset/views/core.py b/superset/views/core.py index f7c0281aef42c..e8070d347b37f 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -110,6 +110,7 @@ from superset.utils.core import ReservedUrlParameters from superset.utils.dates import now_as_float from superset.utils.decorators import check_dashboard_access +from superset.utils.sqllab_execution_context import SqlJsonExecutionContext from superset.views.base import ( api, BaseSupersetView, @@ -2577,42 +2578,16 @@ def sql_json(self) -> FlaskResponse: log_params = { "user_agent": cast(Optional[str], request.headers.get("USER_AGENT")) } - return self.sql_json_exec(request.json, log_params) + execution_context = SqlJsonExecutionContext(request.json) + return self.sql_json_exec(execution_context, request.json, log_params) def sql_json_exec( # pylint: disable=too-many-statements,too-many-locals - self, query_params: Dict[str, Any], log_params: Optional[Dict[str, Any]] = None + self, + execution_context: SqlJsonExecutionContext, + query_params: Dict[str, Any], + log_params: Optional[Dict[str, Any]] = None, ) -> FlaskResponse: """Runs arbitrary sql and returns data as json""" - # Collect Values - database_id: int = cast(int, query_params.get("database_id")) - schema: str = cast(str, query_params.get("schema")) - sql: str = cast(str, query_params.get("sql")) - try: - template_params = json.loads(query_params.get("templateParams") or "{}") - except json.JSONDecodeError: - logger.warning( - "Invalid template parameter %s" " specified. Defaulting to empty dict", - str(query_params.get("templateParams")), - ) - template_params = {} - limit: int = query_params.get("queryLimit") or app.config["SQL_MAX_ROW"] - async_flag: bool = cast(bool, query_params.get("runAsync")) - if limit < 0: - logger.warning( - "Invalid limit of %i specified. Defaulting to max limit.", limit - ) - limit = 0 - select_as_cta: bool = cast(bool, query_params.get("select_as_cta")) - ctas_method: CtasMethod = cast( - CtasMethod, query_params.get("ctas_method", CtasMethod.TABLE) - ) - tmp_table_name: str = cast(str, query_params.get("tmp_table_name")) - client_id: str = cast(str, query_params.get("client_id")) - client_id_or_short_id: str = cast(str, client_id or utils.shortid()[:10]) - sql_editor_id: str = cast(str, query_params.get("sql_editor_id")) - tab_name: str = cast(str, query_params.get("tab")) - status: str = QueryStatus.PENDING if async_flag else QueryStatus.RUNNING - user_id: int = g.user.get_id() if g.user else None session = db.session() @@ -2620,7 +2595,9 @@ def sql_json_exec( # pylint: disable=too-many-statements,too-many-locals query = ( session.query(Query) .filter_by( - client_id=client_id, user_id=user_id, sql_editor_id=sql_editor_id + client_id=execution_context.client_id, + user_id=execution_context.user_id, + sql_editor_id=execution_context.sql_editor_id, ) .one_or_none() ) @@ -2635,7 +2612,7 @@ def sql_json_exec( # pylint: disable=too-many-statements,too-many-locals ) return json_success(payload) - mydb = session.query(Database).get(database_id) + mydb = session.query(Database).get(execution_context.database_id) if not mydb: raise SupersetGenericErrorException( __( @@ -2648,27 +2625,29 @@ def sql_json_exec( # pylint: disable=too-many-statements,too-many-locals # TODO(bkyryliuk): consider parsing, splitting tmp_schema_name from # tmp_table_name if user enters # . - tmp_schema_name: Optional[str] = schema - if select_as_cta and mydb.force_ctas_schema: + tmp_schema_name: Optional[str] = execution_context.schema + if execution_context.select_as_cta and mydb.force_ctas_schema: tmp_schema_name = mydb.force_ctas_schema - elif select_as_cta: - tmp_schema_name = get_cta_schema_name(mydb, g.user, schema, sql) + elif execution_context.select_as_cta: + tmp_schema_name = get_cta_schema_name( + mydb, g.user, execution_context.schema, execution_context.sql + ) # Save current query query = Query( - database_id=database_id, - sql=sql, - schema=schema, - select_as_cta=select_as_cta, - ctas_method=ctas_method, + database_id=execution_context.database_id, + sql=execution_context.sql, + schema=execution_context.schema, + select_as_cta=execution_context.select_as_cta, + ctas_method=execution_context.ctas_method, start_time=now_as_float(), - tab_name=tab_name, - status=status, - sql_editor_id=sql_editor_id, - tmp_table_name=tmp_table_name, + tab_name=execution_context.tab_name, + status=execution_context.status, + sql_editor_id=execution_context.sql_editor_id, + tmp_table_name=execution_context.tmp_table_name, tmp_schema_name=tmp_schema_name, - user_id=user_id, - client_id=client_id_or_short_id, + user_id=execution_context.user_id, + client_id=execution_context.client_id_or_short_id, ) try: session.add(query) @@ -2703,7 +2682,7 @@ def sql_json_exec( # pylint: disable=too-many-statements,too-many-locals database=query.database, query=query ) rendered_query = template_processor.process_template( - query.sql, **template_params + query.sql, **execution_context.template_params ) except TemplateError as ex: query.status = QueryStatus.FAILED @@ -2734,15 +2713,18 @@ def sql_json_exec( # pylint: disable=too-many-statements,too-many-locals error=SupersetErrorType.MISSING_TEMPLATE_PARAMS_ERROR, extra={ "undefined_parameters": list(undefined_parameters), - "template_parameters": template_params, + "template_parameters": execution_context.template_params, }, ) # Limit is not applied to the CTA queries if SQLLAB_CTAS_NO_LIMIT flag is set # to True. - if not (config.get("SQLLAB_CTAS_NO_LIMIT") and select_as_cta): + if not (config.get("SQLLAB_CTAS_NO_LIMIT") and execution_context.select_as_cta): # set LIMIT after template processing - limits = [mydb.db_engine_spec.get_limit_from_sql(rendered_query), limit] + limits = [ + mydb.db_engine_spec.get_limit_from_sql(rendered_query), + execution_context.limit, + ] if limits[0] is None or limits[0] > limits[1]: query.limiting_factor = LimitingFactor.DROPDOWN elif limits[1] > limits[0]: @@ -2760,7 +2742,7 @@ def sql_json_exec( # pylint: disable=too-many-statements,too-many-locals ) # Async request. - if async_flag: + if execution_context.is_should_run_asynchronous(): return self._sql_json_async( session, rendered_query, query, expand_data, log_params ) From 2da5a43dd3e1f8c4e2b029d450c1c23a7ebecfd0 Mon Sep 17 00:00:00 2001 From: Ofeknielsen Date: Wed, 25 Aug 2021 13:02:01 +0300 Subject: [PATCH 2/4] fix pylint --- superset/utils/sqllab_execution_context.py | 10 +++------- superset/views/core.py | 4 ++-- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/superset/utils/sqllab_execution_context.py b/superset/utils/sqllab_execution_context.py index 325848d7cb779..0fd5fa8437c93 100644 --- a/superset/utils/sqllab_execution_context.py +++ b/superset/utils/sqllab_execution_context.py @@ -19,17 +19,13 @@ import json import logging from dataclasses import dataclass -from typing import Any, cast, Dict, Optional, Type, TYPE_CHECKING +from typing import Any, cast, Dict, Optional from flask import g from superset import app, is_feature_enabled from superset.sql_parse import CtasMethod from superset.utils import core as utils -from superset.utils.dates import now_as_float - -if TYPE_CHECKING: - from superset.models.sql_lab import Query QueryStatus = utils.QueryStatus logger = logging.getLogger(__name__) @@ -37,7 +33,7 @@ SqlResults = Dict[str, Any] -@dataclass +@dataclass # pylint: disable=R0902 class SqlJsonExecutionContext: database_id: int schema: str @@ -105,7 +101,7 @@ def _get_limit_param(query_params: Dict[str, Any]) -> int: limit = 0 return limit - def _get_user_id(self) -> Optional[int]: + def _get_user_id(self) -> Optional[int]: # pylint: disable=R0201 try: return g.user.get_id() if g.user else None except RuntimeError: diff --git a/superset/views/core.py b/superset/views/core.py index e8070d347b37f..6ebe83f6ffc52 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=comparison-with-callable, line-too-long, too-many-branches +# pylint: disable=comparison-with-callable, line-too-long import dataclasses import logging import re @@ -100,7 +100,7 @@ from superset.models.user_attributes import UserAttribute from superset.queries.dao import QueryDAO from superset.security.analytics_db_safety import check_sqlalchemy_uri -from superset.sql_parse import CtasMethod, ParsedQuery, Table +from superset.sql_parse import ParsedQuery, Table from superset.sql_validators import get_validator_by_name from superset.tasks.async_queries import load_explore_json_into_cache from superset.typing import FlaskResponse From 1c9a603a29807e3ebcf9116eb8e6861d7d3985c9 Mon Sep 17 00:00:00 2001 From: Amit Miran <47772523+amitmiran137@users.noreply.github.com> Date: Wed, 25 Aug 2021 14:40:09 +0300 Subject: [PATCH 3/4] renaming --- superset/views/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/superset/views/core.py b/superset/views/core.py index 6ebe83f6ffc52..c5381129376d9 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -2742,7 +2742,7 @@ def sql_json_exec( # pylint: disable=too-many-statements,too-many-locals ) # Async request. - if execution_context.is_should_run_asynchronous(): + if execution_context.is_run_asynchronous(): return self._sql_json_async( session, rendered_query, query, expand_data, log_params ) From 4b85e296d06c26a35ea819257978dc016ce5065d Mon Sep 17 00:00:00 2001 From: Amit Miran <47772523+amitmiran137@users.noreply.github.com> Date: Wed, 25 Aug 2021 14:40:19 +0300 Subject: [PATCH 4/4] renaming --- superset/utils/sqllab_execution_context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/superset/utils/sqllab_execution_context.py b/superset/utils/sqllab_execution_context.py index 0fd5fa8437c93..fc4655201347a 100644 --- a/superset/utils/sqllab_execution_context.py +++ b/superset/utils/sqllab_execution_context.py @@ -107,5 +107,5 @@ def _get_user_id(self) -> Optional[int]: # pylint: disable=R0201 except RuntimeError: return None - def is_should_run_asynchronous(self) -> bool: + def is_run_asynchronous(self) -> bool: return self.async_flag