From 4ab7bfa972f4a655526590554052ba8506478b64 Mon Sep 17 00:00:00 2001 From: timifasubaa <30888507+timifasubaa@users.noreply.github.com> Date: Thu, 23 Aug 2018 11:20:25 -0700 Subject: [PATCH] [sqlparse] fix sqlparse bug (#5703) * fix sqlparse bug * add one more test case (cherry picked from commit 5c49514e2f18ac18b16e3befe9c4b9b900b49908) --- superset/sql_parse.py | 4 +++- tests/sql_parse_tests.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/superset/sql_parse.py b/superset/sql_parse.py index ae33453b25b87..ccef3505b46b3 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -12,6 +12,7 @@ from sqlparse.tokens import Keyword, Name RESULT_OPERATIONS = {'UNION', 'INTERSECT', 'EXCEPT'} +ON_KEYWORD = 'ON' PRECEDES_TABLE_NAME = {'FROM', 'JOIN', 'DESC', 'DESCRIBE', 'WITH'} @@ -128,7 +129,8 @@ def __extract_from_token(self, token): continue if item.ttype in Keyword: - if self.__is_result_operation(item.value): + if (self.__is_result_operation(item.value) or + item.value.upper() == ON_KEYWORD): table_name_preceding_token = False continue # FROM clause is over diff --git a/tests/sql_parse_tests.py b/tests/sql_parse_tests.py index 71ee29406ad90..5306760a49fc0 100644 --- a/tests/sql_parse_tests.py +++ b/tests/sql_parse_tests.py @@ -310,3 +310,31 @@ def test_explain(self): self.assertEquals(True, sql.is_explain()) self.assertEquals(False, sql.is_select()) self.assertEquals(True, sql.is_readonly()) + + def test_complex_extract_tables(self): + query = """SELECT sum(m_examples) AS "sum__m_example" + FROM + (SELECT COUNT(DISTINCT id_userid) AS m_examples, + some_more_info + FROM my_b_table b + JOIN my_t_table t ON b.ds=t.ds + JOIN my_l_table l ON b.uid=l.uid + WHERE b.rid IN + (SELECT other_col + FROM inner_table) + AND l.bla IN ('x', 'y') + GROUP BY 2 + ORDER BY 2 ASC) AS "meh" + ORDER BY "sum__m_example" DESC + LIMIT 10;""" + self.assertEquals( + {'my_l_table', 'my_b_table', 'my_t_table', 'inner_table'}, + self.extract_tables(query)) + + def test_complex_extract_tables2(self): + query = """SELECT * + FROM table_a AS a, table_b AS b, table_c as c + WHERE a.id = b.id and b.id = c.id""" + self.assertEquals( + {'table_a', 'table_b', 'table_c'}, + self.extract_tables(query))