diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py
index 620ac4ee971cf..56766749bcecc 100644
--- a/superset/db_engine_specs.py
+++ b/superset/db_engine_specs.py
@@ -41,8 +41,11 @@
 from flask_babel import lazy_gettext as _
 import pandas
 import sqlalchemy as sqla
-from sqlalchemy import Column, select
+from sqlalchemy import Column, select, types
 from sqlalchemy.engine import create_engine
+from sqlalchemy.engine.base import Engine
+from sqlalchemy.engine.reflection import Inspector
+from sqlalchemy.engine.result import RowProxy
 from sqlalchemy.engine.url import make_url
 from sqlalchemy.sql import quoted_name, text
 from sqlalchemy.sql.expression import TextAsFrom
@@ -52,6 +55,7 @@
 
 from superset import app, conf, db, sql_parse
 from superset.exceptions import SupersetTemplateException
+from superset.models.sql_types.presto_sql_types import type_map as presto_type_map
 from superset.utils import core as utils
 
 QueryStatus = utils.QueryStatus
@@ -105,7 +109,7 @@ class BaseEngineSpec(object):
     """Abstract class for database engine specific configurations"""
 
     engine = 'base'  # str as defined in sqlalchemy.engine.engine
-    time_grain_functions = {}
+    time_grain_functions: dict = {}
     time_groupby_inline = False
     limit_method = LimitMethod.FORCE_LIMIT
     time_secondary_columns = False
@@ -113,8 +117,8 @@ class BaseEngineSpec(object):
     allows_subquery = True
     supports_column_aliases = True
     force_column_alias_quotes = False
-    arraysize = None
-    max_column_name_length = None
+    arraysize = 0
+    max_column_name_length = 0
 
     @classmethod
     def get_time_expr(cls, expr, pdf, time_grain, grain):
@@ -351,6 +355,10 @@ def get_table_names(cls, inspector, schema):
     def get_view_names(cls, inspector, schema):
         return sorted(inspector.get_view_names(schema))
 
+    @classmethod
+    def get_columns(cls, inspector: Inspector, table_name: str, schema: str) -> list:
+        return inspector.get_columns(table_name, schema)
+
     @classmethod
     def where_latest_partition(
             cls, table_name, schema, database, qry, columns=None):
@@ -735,7 +743,7 @@ class MySQLEngineSpec(BaseEngineSpec):
                'INTERVAL DAYOFWEEK(DATE_SUB({col}, INTERVAL 1 DAY)) - 1 DAY))',
     }
 
-    type_code_map = {}  # loaded from get_datatype only if needed
+    type_code_map: dict = {}  # loaded from get_datatype only if needed
 
     @classmethod
     def convert_dttm(cls, target_type, dttm):
@@ -814,6 +822,185 @@ def get_view_names(cls, inspector, schema):
         """
         return []
 
+    @classmethod
+    def _create_column_info(cls, column: RowProxy, name: str, data_type: str) -> dict:
+        """
+        Create column info object
+        :param column: column object
+        :param name: column name
+        :param data_type: column data type
+        :return: column info object
+        """
+        return {
+            'name': name,
+            'type': data_type,
+            # newer Presto no longer includes this column
+            'nullable': getattr(column, 'Null', True),
+            'default': None,
+        }
+
+    @classmethod
+    def _get_full_name(cls, names: list) -> str:
+        """
+        Get the full column name
+        :param names: list of all individual column names
+        :return: full column name
+        """
+        return '.'.join(row_type[0] for row_type in names if row_type[0] is not None)
+
+    @classmethod
+    def _has_nested_data_types(cls, component_type: str) -> bool:
+        """
+        Check if string contains a data type. We determine if there is a data type by
+        whitespace or multiple data types by commas
+        :param component_type: data type
+        :return: boolean
+        """
+        comma_regex = r',(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)'
+        white_space_regex = r'\s(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)'
+        return re.search(comma_regex, component_type) is not None \
+            or re.search(white_space_regex, component_type) is not None
+
+    @classmethod
+    def _split_data_type(cls, data_type: str, delimiter: str) -> list:
+        """
+        Split data type based on given delimiter. Do not split the string if the
+        delimiter is enclosed in quotes
+        :param data_type: data type
+        :param delimiter: string separator (i.e. open parenthesis, closed parenthesis,
+               comma, whitespace)
+        :return: list of strings after breaking it by the delimiter
+        """
+        return re.split(
+            r'{}(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)'.format(delimiter), data_type)
+
+    @classmethod
+    def _parse_structural_column(cls, column: RowProxy, result: list) -> None:
+        """
+        Parse a row or array column
+        :param column: column
+        :param result: list tracking the results
+        """
+        full_data_type = '{} {}'.format(column.Column, column.Type)
+        # split on open parenthesis ( to get the structural
+        # data type and its component types
+        data_types = cls._split_data_type(full_data_type, r'\(')
+        stack: list = []
+        for data_type in data_types:
+            # split on closed parenthesis ) to track which component
+            # types belong to what structural data type
+            inner_types = cls._split_data_type(data_type, r'\)')
+            for inner_type in inner_types:
+                # We have finished parsing multiple structural data types
+                if not inner_type and len(stack) > 0:
+                    stack.pop()
+                elif cls._has_nested_data_types(inner_type):
+                    # split on comma , to get individual data types
+                    single_fields = cls._split_data_type(inner_type, ', ')
+                    for single_field in single_fields:
+                        # If component type starts with a comma, the first single field
+                        # will be an empty string. Disregard this empty string.
+                        if not single_field:
+                            continue
+                        # split on whitespace to get field name and data type
+                        field_info = cls._split_data_type(single_field, r'\s')
+                        # check if there is a structural data type within
+                        # overall structural data type
+                        if field_info[1] == 'array' or field_info[1] == 'row':
+                            stack.append((field_info[0], field_info[1]))
+                            full_parent_path = cls._get_full_name(stack)
+                            result.append(cls._create_column_info(
+                                column, full_parent_path,
+                                presto_type_map[field_info[1]]()))
+                        else:  # otherwise this field is a basic data type
+                            full_parent_path = cls._get_full_name(stack)
+                            column_name = '{}.{}'.format(full_parent_path, field_info[0])
+                            result.append(cls._create_column_info(
+                                column, column_name, presto_type_map[field_info[1]]()))
+                    # If the component type ends with a structural data type, do not pop
+                    # the stack. We have run across a structural data type within the
+                    # overall structural data type. Otherwise, we have completely parsed
+                    # through the entire structural data type and can move on.
+                    if not (inner_type.endswith('array') or inner_type.endswith('row')):
+                        stack.pop()
+                # We have an array of row objects (i.e. array(row(...)))
+                elif 'array' == inner_type or 'row' == inner_type:
+                    # Push a dummy object to represent the structural data type
+                    stack.append((None, inner_type))
+                # We have an array of basic data types (i.e. array(varchar)).
+                elif len(stack) > 0:
+                    # Because it is an array of a basic data type. We have finished
+                    # parsing the structural data type and can move on.
+                    stack.pop()
+
+    @classmethod
+    def _show_columns(cls, inspector: Inspector, table_name: str, schema: str) -> list:
+        """
+        Show presto column names
+        :param inspector: object that performs database schema inspection
+        :param table_name: table name
+        :param schema: schema name
+        :return: list of column objects
+        """
+        quote = inspector.engine.dialect.identifier_preparer.quote_identifier
+        full_table = quote(table_name)
+        if schema:
+            full_table = '{}.{}'.format(quote(schema), full_table)
+        columns = inspector.bind.execute('SHOW COLUMNS FROM {}'.format(full_table))
+        return columns
+
+    @classmethod
+    def get_columns(cls, inspector: Inspector, table_name: str, schema: str) -> list:
+        """
+        Get columns from a Presto data source. This includes handling row and
+        array data types
+        :param inspector: object that performs database schema inspection
+        :param table_name: table name
+        :param schema: schema name
+        :return: a list of results that contain column info
+                (i.e. column name and data type)
+        """
+        columns = cls._show_columns(inspector, table_name, schema)
+        result: list = []
+        for column in columns:
+            try:
+                # parse column if it is a row or array
+                if 'array' in column.Type or 'row' in column.Type:
+                    # Parse into a scratch list so that an unrecognized nested type
+                    # does not leave half-parsed entries behind in the result
+                    structural_columns: list = []
+                    cls._parse_structural_column(column, structural_columns)
+                    result.extend(structural_columns)
+                    continue
+                else:  # otherwise column is a basic data type
+                    # Drop type parameters (i.e. varchar(255) -> varchar)
+                    base_type = column.Type.split('(')[0]
+                    column_type = presto_type_map[base_type]()
+            except KeyError:
+                print('Did not recognize type {} of column {}'.format(
+                    column.Type, column.Column))
+                column_type = types.NullType()
+            result.append(cls._create_column_info(column, column.Column, column_type))
+        return result
+
+    @classmethod
+    def select_star(cls, my_db, table_name: str, engine: Engine, schema: str = None,
+                    limit: int = 100, show_cols: bool = False, indent: bool = True,
+                    latest_partition: bool = True, cols: list = None) -> str:
+        """
+        Temporary method until we have a function that can handle row and array columns
+        """
+        # cols defaults to None so that no list object is shared between calls
+        presto_cols = cols or []
+        if show_cols:
+            dot_regex = r'\.(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)'
+            presto_cols = [
+                col for col in presto_cols if re.search(dot_regex, col['name']) is None]
+        return super(PrestoEngineSpec, cls).select_star(
+            my_db, table_name, engine, schema, limit,
+            show_cols, indent, latest_partition, presto_cols,
+        )
+
     @classmethod
     def adjust_database_uri(cls, uri, selected_schema=None):
         database = uri.database
@@ -1323,6 +1510,10 @@ def handle_cursor(cls, cursor, query, session):
             time.sleep(hive_poll_interval)
             polled = cursor.poll()
 
+    @classmethod
+    def get_columns(cls, inspector: Inspector, table_name: str, schema: str) -> list:
+        return inspector.get_columns(table_name, schema)
+
     @classmethod
     def where_latest_partition(
             cls, table_name, schema, database, qry, columns=None):
diff --git a/superset/models/core.py b/superset/models/core.py
index fb5850c0fa1b0..e16a234bfd723 100644
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -1075,7 +1075,7 @@ def get_table(self, table_name, schema=None):
             autoload_with=self.get_sqla_engine())
 
     def get_columns(self, table_name, schema=None):
-        return self.inspector.get_columns(table_name, schema)
+        return self.db_engine_spec.get_columns(self.inspector, table_name, schema)
 
     def get_indexes(self, table_name, schema=None):
         return self.inspector.get_indexes(table_name, schema)
diff --git a/superset/models/sql_types/presto_sql_types.py b/superset/models/sql_types/presto_sql_types.py
new file mode 100644
index 0000000000000..021c15cffa51e
--- /dev/null
+++ b/superset/models/sql_types/presto_sql_types.py
@@ -0,0 +1,88 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from sqlalchemy import types
+from sqlalchemy.sql.sqltypes import Integer
+from sqlalchemy.sql.type_api import TypeEngine
+
+
+# _compiler_dispatch is defined to help with type compilation
+
+class TinyInteger(Integer):
+    """
+    A type for tiny ``int`` integers.
+    """
+    def _compiler_dispatch(self, visitor, **kw):
+        return 'TINYINT'
+
+
+class Interval(TypeEngine):
+    """
+    A type for intervals.
+    """
+    def _compiler_dispatch(self, visitor, **kw):
+        return 'INTERVAL'
+
+
+class Array(TypeEngine):
+
+    """
+    A type for arrays.
+    """
+    def _compiler_dispatch(self, visitor, **kw):
+        return 'ARRAY'
+
+
+class Map(TypeEngine):
+
+    """
+    A type for maps.
+    """
+    def _compiler_dispatch(self, visitor, **kw):
+        return 'MAP'
+
+
+class Row(TypeEngine):
+
+    """
+    A type for rows.
+    """
+    def _compiler_dispatch(self, visitor, **kw):
+        return 'ROW'
+
+
+type_map = {
+    'boolean': types.Boolean,
+    'tinyint': TinyInteger,
+    'smallint': types.SmallInteger,
+    'integer': types.Integer,
+    'bigint': types.BigInteger,
+    'real': types.Float,
+    'double': types.Float,
+    'decimal': types.DECIMAL,
+    'varchar': types.String,
+    'char': types.CHAR,
+    'varbinary': types.VARBINARY,
+    'json': types.JSON,
+    'date': types.DATE,
+    'time': types.Time,
+    'timestamp': types.TIMESTAMP,
+    'interval': Interval,
+    'array': Array,
+    'map': Map,
+    'row': Row,
+}
diff --git a/tests/db_engine_specs_test.py b/tests/db_engine_specs_test.py
index e1286076bb623..ef9d6bc17da1a 100644
--- a/tests/db_engine_specs_test.py
+++ b/tests/db_engine_specs_test.py
@@ -19,6 +19,7 @@
 
 from sqlalchemy import column, select, table
 from sqlalchemy.dialects.mssql import pymssql
+from sqlalchemy.engine.result import RowProxy
 from sqlalchemy.types import String, UnicodeText
 
 from superset import db_engine_specs
@@ -322,6 +323,71 @@ def test_engine_time_grain_validity(self):
     def test_presto_get_view_names_return_empty_list(self):
         self.assertEquals([], PrestoEngineSpec.get_view_names(mock.ANY, mock.ANY))
 
+    def verify_presto_column(self, column, expected_results):
+        inspector = mock.Mock()
+        inspector.engine.dialect.identifier_preparer.quote_identifier = mock.Mock()
+        keymap = {'Column': (None, None, 0),
+                  'Type': (None, None, 1),
+                  'Null': (None, None, 2)}
+        row = RowProxy(mock.Mock(), column, [None, None, None, None], keymap)
+        inspector.bind.execute = mock.Mock(return_value=[row])
+        results = PrestoEngineSpec.get_columns(inspector, '', '')
+        self.assertEqual(len(expected_results), len(results))
+        for expected_result, result in zip(expected_results, results):
+            self.assertEqual(expected_result[0], result['name'])
+            self.assertEqual(expected_result[1], str(result['type']))
+
+    def test_presto_get_column(self):
+        presto_column = ('column_name', 'boolean', '')
+        expected_results = [('column_name', 'BOOLEAN')]
+        self.verify_presto_column(presto_column, expected_results)
+
+    def test_presto_get_parameterized_column(self):
+        presto_column = ('column_name', 'varchar(255)', '')
+        expected_results = [('column_name', 'VARCHAR')]
+        self.verify_presto_column(presto_column, expected_results)
+
+    def test_presto_get_simple_row_column(self):
+        presto_column = ('column_name', 'row(nested_obj double)', '')
+        expected_results = [
+            ('column_name', 'ROW'),
+            ('column_name.nested_obj', 'FLOAT')]
+        self.verify_presto_column(presto_column, expected_results)
+
+    def test_presto_get_simple_row_column_with_tricky_name(self):
+        presto_column = ('column_name', 'row("Field Name(Tricky, Name)" double)', '')
+        expected_results = [
+            ('column_name', 'ROW'),
+            ('column_name."Field Name(Tricky, Name)"', 'FLOAT')]
+        self.verify_presto_column(presto_column, expected_results)
+
+    def test_presto_get_simple_array_column(self):
+        presto_column = ('column_name', 'array(double)', '')
+        expected_results = [('column_name', 'ARRAY')]
+        self.verify_presto_column(presto_column, expected_results)
+
+    def test_presto_get_row_within_array_within_row_column(self):
+        presto_column = (
+            'column_name',
+            'row(nested_array array(row(nested_row double)), nested_obj double)', '')
+        expected_results = [
+            ('column_name', 'ROW'),
+            ('column_name.nested_array', 'ARRAY'),
+            ('column_name.nested_array.nested_row', 'FLOAT'),
+            ('column_name.nested_obj', 'FLOAT'),
+        ]
+        self.verify_presto_column(presto_column, expected_results)
+
+    def test_presto_get_array_within_row_within_array_column(self):
+        presto_column = (
+            'column_name',
+            'array(row(nested_array array(double), nested_obj double))', '')
+        expected_results = [
+            ('column_name', 'ARRAY'),
+            ('column_name.nested_array', 'ARRAY'),
+            ('column_name.nested_obj', 'FLOAT')]
+        self.verify_presto_column(presto_column, expected_results)
+
     def test_hive_get_view_names_return_empty_list(self):
         self.assertEquals([], HiveEngineSpec.get_view_names(mock.ANY, mock.ANY))
 