From 3a42071e0ff181a7a0f1b55a69e39440e2570018 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt <33317356+villebro@users.noreply.github.com> Date: Thu, 9 Dec 2021 17:49:32 +0200 Subject: [PATCH] chore(sql): clean up invalid filter clause exception types (#17702) * chore(sql): clean up invalid filter clause exception types * fix lint * rename exception --- superset/common/query_object.py | 16 +++++- superset/exceptions.py | 4 ++ superset/sql_parse.py | 22 +++++++ superset/viz.py | 11 ++++ .../charts/data/api_tests.py | 22 +++++++ tests/unit_tests/sql_parse_tests.py | 57 ++++++++++++++++++- 6 files changed, 130 insertions(+), 2 deletions(-) diff --git a/superset/common/query_object.py b/superset/common/query_object.py index ff1ad710ee40e..03ee9cb1d3059 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -25,7 +25,11 @@ from pandas import DataFrame from superset.common.chart_data import ChartDataResultType -from superset.exceptions import QueryObjectValidationError +from superset.exceptions import ( + QueryClauseValidationException, + QueryObjectValidationError, +) +from superset.sql_parse import validate_filter_clause from superset.typing import Column, Metric, OrderBy from superset.utils import pandas_postprocessing from superset.utils.core import ( @@ -267,6 +271,7 @@ def validate( try: self._validate_there_are_no_missing_series() self._validate_no_have_duplicate_labels() + self._validate_filters() return None except QueryObjectValidationError as ex: if raise_exceptions: @@ -285,6 +290,15 @@ def _validate_no_have_duplicate_labels(self) -> None: ) ) + def _validate_filters(self) -> None: + for param in ("where", "having"): + clause = self.extras.get(param) + if clause: + try: + validate_filter_clause(clause) + except QueryClauseValidationException as ex: + raise QueryObjectValidationError(ex.message) from ex + def _validate_there_are_no_missing_series(self) -> None: missing_series = [col for col in self.series_columns if col not in self.columns] if missing_series: diff --git a/superset/exceptions.py b/superset/exceptions.py index 76da484dc3f95..2a902608a6a97 100644 --- a/superset/exceptions.py +++ b/superset/exceptions.py @@ -194,6 +194,10 @@ class CacheLoadError(SupersetException): status = 404 +class QueryClauseValidationException(SupersetException): + status = 400 + + class DashboardImportException(SupersetException): pass diff --git a/superset/sql_parse.py b/superset/sql_parse.py index 47a9e5cb2d47e..1130763a372c0 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -32,6 +32,8 @@ from sqlparse.tokens import DDL, DML, Keyword, Name, Punctuation, String, Whitespace from sqlparse.utils import imt +from superset.exceptions import QueryClauseValidationException + RESULT_OPERATIONS = {"UNION", "INTERSECT", "EXCEPT", "SELECT"} ON_KEYWORD = "ON" PRECEDES_TABLE_NAME = {"FROM", "JOIN", "DESCRIBE", "WITH", "LEFT JOIN", "RIGHT JOIN"} @@ -378,3 +380,23 @@ def set_or_update_query_limit(self, new_limit: int, force: bool = False) -> str: for i in statement.tokens: str_res += str(i.value) return str_res + + +def validate_filter_clause(clause: str) -> None: + if sqlparse.format(clause, strip_comments=True) != sqlparse.format(clause): + raise QueryClauseValidationException("Filter clause contains comment") + + statements = sqlparse.parse(clause) + if len(statements) != 1: + raise QueryClauseValidationException("Filter clause contains multiple queries") + open_parens = 0 + + for token in statements[0]: + if token.value in (")", "("): + open_parens += 1 if token.value == "(" else -1 + if open_parens < 0: + raise QueryClauseValidationException( + "Closing unclosed parenthesis in filter clause" + ) + if open_parens > 0: + raise QueryClauseValidationException("Unclosed parenthesis in filter clause") diff --git a/superset/viz.py b/superset/viz.py index 53bc333224160..23f2cf336d8d4 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -62,12 +62,14 @@ from superset.exceptions import ( CacheLoadError, NullValueException, + QueryClauseValidationException, QueryObjectValidationError, SpatialException, SupersetSecurityException, ) from superset.extensions import cache_manager, security_manager from superset.models.helpers import QueryResult +from superset.sql_parse import validate_filter_clause from superset.typing import Column, Metric, QueryObjectDict, VizData, VizPayload from superset.utils import core as utils, csv from superset.utils.cache import set_and_log_cache @@ -373,6 +375,15 @@ def query_obj(self) -> QueryObjectDict: # pylint: disable=too-many-locals self.from_dttm = from_dttm self.to_dttm = to_dttm + # validate sql filters + for param in ("where", "having"): + clause = self.form_data.get(param) + if clause: + try: + validate_filter_clause(clause) + except QueryClauseValidationException as ex: + raise QueryObjectValidationError(ex.message) from ex + # extras are used to query elements specific to a datasource type # for instance the extra where clause that applies only to Tables extras = { diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py index 12d667f1611b8..cf6d0b537f145 100644 --- a/tests/integration_tests/charts/data/api_tests.py +++ b/tests/integration_tests/charts/data/api_tests.py @@ -425,6 +425,28 @@ def test_with_invalid_where_parameter__400(self): assert rv.status_code == 400 + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_with_invalid_where_parameter_closing_unclosed__400(self): + self.query_context_payload["queries"][0]["filters"] = [] + self.query_context_payload["queries"][0]["extras"][ + "where" + ] = "state = 'CA') OR (state = 'NY'" + + rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") + + assert rv.status_code == 400 + + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_with_invalid_having_parameter_closing_and_comment__400(self): + self.query_context_payload["queries"][0]["filters"] = [] + self.query_context_payload["queries"][0]["extras"][ + "having" + ] = "COUNT(1) = 0) UNION ALL SELECT 'abc', 1--comment" + + rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") + + assert rv.status_code == 400 + def test_with_invalid_datasource__400(self): self.query_context_payload["datasource"] = "abc" diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index 61ea2e0a32171..f405b9fcda426 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -20,9 +20,16 @@ import unittest from typing import Set +import pytest import sqlparse -from superset.sql_parse import ParsedQuery, strip_comments_from_sql, Table +from superset.exceptions import QueryClauseValidationException +from superset.sql_parse import ( + ParsedQuery, + strip_comments_from_sql, + Table, + validate_filter_clause, +) def extract_tables(query: str) -> Set[Table]: @@ -1144,3 +1151,51 @@ def test_strip_comments_from_sql() -> None: strip_comments_from_sql("SELECT '--abc' as abc, col2 FROM table1\n") == "SELECT '--abc' as abc, col2 FROM table1" ) + + +def test_validate_filter_clause_valid(): + # regular clauses + assert validate_filter_clause("col = 1") is None + assert validate_filter_clause("1=\t\n1") is None + assert validate_filter_clause("(col = 1)") is None + assert validate_filter_clause("(col1 = 1) AND (col2 = 2)") is None + + # Valid literal values that appear to be invalid + assert validate_filter_clause("col = 'col1 = 1) AND (col2 = 2'") is None + assert validate_filter_clause("col = 'select 1; select 2'") is None + assert validate_filter_clause("col = 'abc -- comment'") is None + + +def test_validate_filter_clause_closing_unclosed(): + with pytest.raises(QueryClauseValidationException): + validate_filter_clause("col1 = 1) AND (col2 = 2)") + + +def test_validate_filter_clause_unclosed(): + with pytest.raises(QueryClauseValidationException): + validate_filter_clause("(col1 = 1) AND (col2 = 2") + + +def test_validate_filter_clause_closing_and_unclosed(): + with pytest.raises(QueryClauseValidationException): + validate_filter_clause("col1 = 1) AND (col2 = 2") + + +def test_validate_filter_clause_closing_and_unclosed_nested(): + with pytest.raises(QueryClauseValidationException): + validate_filter_clause("(col1 = 1)) AND ((col2 = 2)") + + +def test_validate_filter_clause_multiple(): + with pytest.raises(QueryClauseValidationException): + validate_filter_clause("TRUE; SELECT 1") + + +def test_validate_filter_clause_comment(): + with pytest.raises(QueryClauseValidationException): + validate_filter_clause("1 = 1 -- comment") + + +def test_validate_filter_clause_subquery_comment(): + with pytest.raises(QueryClauseValidationException): + validate_filter_clause("(1 = 1 -- comment\n)")