fixes to td.py table ref4
mccushjack committed Nov 5, 2021
1 parent 830ca55 commit c4a4898
Showing 4 changed files with 170 additions and 19 deletions.
2 changes: 1 addition & 1 deletion docs/installation.rst
@@ -755,7 +755,7 @@ The connection string for Teradata looks like this ::

The recommended connector library is
[teradatasql](https://github.com/Teradata/python-driver).
Also, see the latest on [PyPi](https://pypi.org/project/teradatasql/)
Also, see the latest on [PyPi](https://pypi.org/project/teradatasql/)

The connection string for Teradata looks like this:

1 change: 0 additions & 1 deletion docs/src/pages/docs/Connecting to Databases/teradata.mdx
@@ -16,4 +16,3 @@ The connection string for Teradata looks like this:
```
teradatasql://{user}:{password}@{host}
```

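As a point of reference, a minimal sketch of how the connection string documented above could be used from SQLAlchemy, assuming the teradatasql driver and its SQLAlchemy dialect are installed; the host and credentials are hypothetical placeholders, and this snippet is illustrative only, not part of the commit.

```
# Illustrative sketch only: open a SQLAlchemy engine with the teradatasql
# dialect referenced in the docs above. Host and credentials are hypothetical.
from sqlalchemy import create_engine, text

engine = create_engine("teradatasql://my_user:my_password@td-host.example.com")

with engine.connect() as conn:
    # SEL is Teradata shorthand for SELECT; TOP caps the number of rows.
    for row in conn.execute(text("SEL TOP 5 * FROM dbc.dbcinfo")):
        print(row)
```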
145 changes: 128 additions & 17 deletions superset/db_engine_specs/teradata.py
@@ -20,8 +20,6 @@
from typing import List, Optional, Set
from urllib import parse

from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod

import sqlparse
from sqlparse.sql import (
    Identifier,
@@ -34,19 +32,18 @@
from sqlparse.tokens import Keyword, Name, Punctuation, String, Whitespace
from sqlparse.utils import imt

"""
from typing import List, Optional, TYPE_CHECKING
from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
from superset.sql_parse import Table

PRECEDES_TABLE_NAME = {"FROM", "JOIN", "DESCRIBE", "WITH", "LEFT JOIN", "RIGHT JOIN"}
CTE_PREFIX = "CTE__"

if TYPE_CHECKING:
    # prevent circular imports
    from superset.models.core import Database
"""

def _extract_limit_from_query_td(statement: TokenList) -> Optional[int]:
    td_limit_keywork = set(["TOP", "SAMPLE"])
    str_statement = str(statement)
    str_statement = str_statement.replace('\n', ' ').replace('\r', '')
    token = str(str_statement).rstrip().split(' ')
    str_statement = str_statement.replace("\n", " ").replace("\r", "")
    token = str(str_statement).rstrip().split(" ")
    token = list(filter(None, token))
    limit = None

@@ -62,7 +59,10 @@ def _extract_limit_from_query_td(statement: TokenList) -> Optional[int]:


class ParsedQuery_td:
    def __init__(self, sql_statement: str, strip_comments: bool = False, uri_type: str = None):
    def __init__(
        self, sql_statement: str, strip_comments: bool = False, uri_type: str = "None"
    ):

        if strip_comments:
            sql_statement = sqlparse.format(sql_statement, strip_comments=True)

@@ -76,10 +76,120 @@ def __init__(self, sql_statement: str, strip_comments: bool = False, uri_type: s
        for statement in self._parsed:
            self._limit = _extract_limit_from_query_td(statement)

    @property
    def tables(self) -> Set[Table]:
        if not self._tables:
            for statement in self._parsed:
                self._extract_from_token(statement)

            self._tables = {
                table for table in self._tables if str(table) not in self._alias_names
            }
        return self._tables

    def stripped(self) -> str:
        return self.sql.strip(" \t\n;")

    def _extract_from_token(self, token: Token) -> None:
        """
        An <Identifier> stores a list of subtokens, and an <IdentifierList> stores
        lists of subtoken lists.

        This method extracts <IdentifierList> and <Identifier> from :param token: and
        loops through all subtokens recursively. It finds table_name_preceding_token
        and passes <IdentifierList> and <Identifier> to self._process_tokenlist to
        populate self._tables.

        :param token: instance of Token or child class, e.g. TokenList, to be processed
        """
        if not hasattr(token, "tokens"):
            return

        table_name_preceding_token = False

        for item in token.tokens:
            if item.is_group and (
                not self._is_identifier(item) or isinstance(item.tokens[0], Parenthesis)
            ):
                self._extract_from_token(item)

            if item.ttype in Keyword and (
                item.normalized in PRECEDES_TABLE_NAME
                or item.normalized.endswith(" JOIN")
            ):
                table_name_preceding_token = True
                continue

            if item.ttype in Keyword:
                table_name_preceding_token = False
                continue
            if table_name_preceding_token:
                if isinstance(item, Identifier):
                    self._process_tokenlist(item)
                elif isinstance(item, IdentifierList):
                    for token2 in item.get_identifiers():
                        if isinstance(token2, TokenList):
                            self._process_tokenlist(token2)
            elif isinstance(item, IdentifierList):
                if any(not self._is_identifier(token2) for token2 in item.tokens):
                    self._extract_from_token(item)

    @staticmethod
    def _get_table(tlist: TokenList) -> Optional[Table]:
        """
        Return the table if valid, i.e., conforms to the [[catalog.]schema.]table
        construct.

        :param tlist: The SQL tokens
        :returns: The table if the name conforms
        """

        # Strip the alias if present.
        idx = len(tlist.tokens)

        if tlist.has_alias():
            ws_idx, _ = tlist.token_next_by(t=Whitespace)

            if ws_idx != -1:
                idx = ws_idx

        tokens = tlist.tokens[:idx]

        if (
            len(tokens) in (1, 3, 5)
            and all(imt(token, t=[Name, String]) for token in tokens[::2])
            and all(imt(token, m=(Punctuation, ".")) for token in tokens[1::2])
        ):
            return Table(*[remove_quotes(token.value) for token in tokens[::-2]])

        return None

    @staticmethod
    def _is_identifier(token: Token) -> bool:
        return isinstance(token, (IdentifierList, Identifier))

    def _process_tokenlist(self, token_list: TokenList) -> None:
        """
        Add table names to table set
        :param token_list: TokenList to be processed
        """
        # exclude subselects
        if "(" not in str(token_list):
            table = self._get_table(token_list)
            if table and not table.table.startswith(CTE_PREFIX):
                self._tables.add(table)
            return

        # store aliases
        if token_list.has_alias():
            self._alias_names.add(token_list.get_alias())

        # some aliases are not parsed properly
        if token_list.tokens[0].ttype == Name:
            self._alias_names.add(token_list.tokens[0].value)
        self._extract_from_token(token_list)

    def set_or_update_query_limit_td(self, new_limit: int) -> str:
        td_sel_keywork = set(["SELECT ", "SEL "])
@@ -94,8 +204,8 @@ def set_or_update_query_limit_td(self, new_limit: int) -> str:
        final_limit = self._limit

        str_statement = str(statement)
        str_statement = str_statement.replace('\n', ' ').replace('\r', '')
        tokens = str(str_statement).rstrip().split(' ')
        str_statement = str_statement.replace("\n", " ").replace("\r", "")
        tokens = str(str_statement).rstrip().split(" ")
        tokens = list(filter(None, tokens))

        next_remove_ind = False
@@ -117,7 +227,6 @@ def set_or_update_query_limit_td(self, new_limit: int) -> str:
        return str_res



class TeradataEngineSpec(BaseEngineSpec):
    """Dialect for Teradata DB."""

@@ -146,7 +255,9 @@ def epoch_to_dttm(cls) -> str:
        )

    @classmethod
    def apply_limit_to_sql(cls, sql: str, limit: int, database: "Database", force: bool = False) -> str:
    def apply_limit_to_sql(
        cls, sql: str, limit: int, database: str = "Database", force: bool = False
    ) -> str:
        """
        Alters the SQL statement to apply a TOP clause.
        This function overrides the similar function in base.py because Teradata
        does not support the LIMIT syntax.
@@ -155,8 +266,8 @@ def apply_limit_to_sql(cls, sql: str, limit: int, database: "Database", force: b
        :param database: Database instance
        :return: SQL query with limit clause
        """

        parsed_query = ParsedQuery_td(sql)
        sql = parsed_query.set_or_update_query_limit_td(limit)

        return sql
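
To make the new behavior concrete, here is a small, hedged usage sketch of the rewrite path added above; the table name is a placeholder, and the expected output mirrors the assertion in the new unit test below rather than a documented contract.

```
# Illustrative sketch only: exercise the Teradata TOP rewrite added in this
# commit. The table name is a placeholder; the expected output mirrors the
# unit test below.
from superset.db_engine_specs.teradata import TeradataEngineSpec

sql = "SEL TOP 1000 * FROM My_table;"

# Teradata has no LIMIT clause, so the spec rewrites the existing TOP 1000
# down to TOP 100 instead of appending LIMIT 100.
limited = TeradataEngineSpec.apply_limit_to_sql(sql, 100, "Database")
print(limited)  # per the unit test below: "SEL TOP 100 * FROM My_table "
```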
41 changes: 41 additions & 0 deletions tests/unit_tests/db_engine_specs/test_teradata.py
@@ -0,0 +1,41 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument, import-outside-toplevel, protected-access

from flask.ctx import AppContext
from pytest_mock import MockFixture

from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod


def test_ParsedQuery_tds(app_context: AppContext) -> None:
    """
    Test the custom ``ParsedQuery_td``, which calls ``_extract_limit_from_query_td()``.
    The class looks for the Teradata limit keywords TOP and SAMPLE rather than the
    LIMIT keyword used in other dialects.
    """
    from superset.db_engine_specs.teradata import ParsedQuery_td, TeradataEngineSpec

    sql2 = "SEL TOP 1000 * FROM My_table;"
    limit = 100

    assert str(TeradataEngineSpec.apply_limit_to_sql(sql2, limit, "Database")) == (
        "SEL TOP 100 * FROM My_table "
    )
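
For readers skimming the diff, the idea behind ``_extract_limit_from_query_td`` and the test above can be summarized with a simplified, self-contained sketch; it handles only the TOP keyword and assumes the limit value immediately follows it, whereas the committed helper also recognizes SAMPLE.

```
# Simplified illustration of the whitespace-token scan used to find a
# Teradata-style limit. Handles TOP only; the real helper also checks SAMPLE.
from typing import Optional


def extract_top_limit(sql: str) -> Optional[int]:
    cleaned = sql.replace("\n", " ").replace("\r", "").rstrip()
    tokens = [tok for tok in cleaned.split(" ") if tok]
    for i, tok in enumerate(tokens):
        if tok.upper() == "TOP" and i + 1 < len(tokens):
            try:
                return int(tokens[i + 1])
            except ValueError:
                return None
    return None


print(extract_top_limit("SEL TOP 1000 * FROM My_table;"))  # -> 1000
```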
