From c7e20720f16992079ffc91cc5ff3780e21e03fb2 Mon Sep 17 00:00:00 2001 From: Maxime Beauchemin Date: Wed, 27 Jun 2018 21:35:12 -0700 Subject: [PATCH] Improve database type inference (#4724) * Improve database type inference Python's DBAPI isn't super clear and homogeneous on the cursor.description specification, and this PR attempts to improve inferring the datatypes returned in the cursor. This work started around Presto's TIMESTAMP type being mishandled as string as the database driver (pyhive) returns it as a string. The work here fixes this bug and does a better job at inferring MySQL and Presto types. It also creates a new method in db_engine_specs allowing for other databases engines to implement and become more precise on type-inference as needed. * Fixing tests * Adressing comments * Using infer_objects * Removing faulty line * Addressing PrestoSpec redundant method comment * Fix rebase issue * Fix tests --- superset/dataframe.py | 78 ++++++++++++++++++----- superset/db_engine_specs.py | 24 +++++++ superset/sql_lab.py | 44 +------------ tests/celery_tests.py | 51 +-------------- tests/core_tests.py | 4 +- tests/dataframe_test.py | 115 ++++++++++++++++++++++++++++++++++ tests/db_engine_specs_test.py | 10 ++- tests/sqllab_tests.py | 15 +++-- 8 files changed, 224 insertions(+), 117 deletions(-) create mode 100644 tests/dataframe_test.py diff --git a/superset/dataframe.py b/superset/dataframe.py index 79a2c3d564bdf..5fba4ffed6372 100644 --- a/superset/dataframe.py +++ b/superset/dataframe.py @@ -13,6 +13,7 @@ from __future__ import unicode_literals from datetime import date, datetime +import logging import numpy as np import pandas as pd @@ -26,6 +27,27 @@ INFER_COL_TYPES_SAMPLE_SIZE = 100 +def dedup(l, suffix='__'): + """De-duplicates a list of string by suffixing a counter + + Always returns the same number of entries as provided, and always returns + unique values. + + >>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar']))) + foo,bar,bar__1,bar__2 + """ + new_l = [] + seen = {} + for s in l: + if s in seen: + seen[s] += 1 + s += suffix + str(seen[s]) + else: + seen[s] = 0 + new_l.append(s) + return new_l + + class SupersetDataFrame(object): # Mapping numpy dtype.char to generic database types type_map = { @@ -43,19 +65,39 @@ class SupersetDataFrame(object): 'V': None, # raw data (void) } - def __init__(self, df): - self.__df = df.where((pd.notnull(df)), None) + def __init__(self, data, cursor_description, db_engine_spec): + column_names = [] + if cursor_description: + column_names = [col[0] for col in cursor_description] + + self.column_names = dedup( + db_engine_spec.get_normalized_column_names(cursor_description)) + + data = data or [] + self.df = ( + pd.DataFrame(list(data), columns=column_names).infer_objects()) + + self._type_dict = {} + try: + # The driver may not be passing a cursor.description + self._type_dict = { + col: db_engine_spec.get_datatype(cursor_description[i][1]) + for i, col in enumerate(self.column_names) + if cursor_description + } + except Exception as e: + logging.exception(e) @property def size(self): - return len(self.__df.index) + return len(self.df.index) @property def data(self): # work around for https://github.com/pandas-dev/pandas/issues/18372 data = [dict((k, _maybe_box_datetimelike(v)) - for k, v in zip(self.__df.columns, np.atleast_1d(row))) - for row in self.__df.values] + for k, v in zip(self.df.columns, np.atleast_1d(row))) + for row in self.df.values] for d in data: for k, v in list(d.items()): # if an int is too big for Java Script to handle @@ -70,7 +112,8 @@ def db_type(cls, dtype): """Given a numpy dtype, Returns a generic database type""" if isinstance(dtype, ExtensionDtype): return cls.type_map.get(dtype.kind) - return cls.type_map.get(dtype.char) + elif hasattr(dtype, 'char'): + return cls.type_map.get(dtype.char) @classmethod def datetime_conversion_rate(cls, data_series): @@ -105,7 +148,7 @@ def agg_func(cls, dtype, column_name): # consider checking for key substring too. if cls.is_id(column_name): return 'count_distinct' - if (issubclass(dtype.type, np.generic) and + if (hasattr(dtype, 'type') and issubclass(dtype.type, np.generic) and np.issubdtype(dtype, np.number)): return 'sum' return None @@ -116,22 +159,25 @@ def columns(self): :return: dict, with the fields name, type, is_date, is_dim and agg. """ - if self.__df.empty: + if self.df.empty: return None columns = [] - sample_size = min(INFER_COL_TYPES_SAMPLE_SIZE, len(self.__df.index)) - sample = self.__df + sample_size = min(INFER_COL_TYPES_SAMPLE_SIZE, len(self.df.index)) + sample = self.df if sample_size: - sample = self.__df.sample(sample_size) - for col in self.__df.dtypes.keys(): - col_db_type = self.db_type(self.__df.dtypes[col]) + sample = self.df.sample(sample_size) + for col in self.df.dtypes.keys(): + col_db_type = ( + self._type_dict.get(col) or + self.db_type(self.df.dtypes[col]) + ) column = { 'name': col, - 'agg': self.agg_func(self.__df.dtypes[col], col), + 'agg': self.agg_func(self.df.dtypes[col], col), 'type': col_db_type, - 'is_date': self.is_date(self.__df.dtypes[col]), - 'is_dim': self.is_dimension(self.__df.dtypes[col], col), + 'is_date': self.is_date(self.df.dtypes[col]), + 'is_dim': self.is_dimension(self.df.dtypes[col], col), } if column['type'] in ('OBJECT', None): diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py index 4f6b22e305460..4181c49d67f5b 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -30,6 +30,7 @@ from flask import g from flask_babel import lazy_gettext as _ import pandas +from past.builtins import basestring import sqlalchemy as sqla from sqlalchemy import select from sqlalchemy.engine import create_engine @@ -85,6 +86,11 @@ def epoch_to_dttm(cls): def epoch_ms_to_dttm(cls): return cls.epoch_to_dttm().replace('{col}', '({col}/1000.0)') + @classmethod + def get_datatype(cls, type_code): + if isinstance(type_code, basestring) and len(type_code): + return type_code.upper() + @classmethod def extra_table_metadata(cls, database, table_name, schema_name): """Returns engine-specific table metadata""" @@ -592,6 +598,7 @@ class MySQLEngineSpec(BaseEngineSpec): 'INTERVAL DAYOFWEEK(DATE_SUB({col}, INTERVAL 1 DAY)) - 1 DAY))', 'P1W'), ) + type_code_map = {} # loaded from get_datatype only if needed @classmethod def convert_dttm(cls, target_type, dttm): @@ -606,6 +613,23 @@ def adjust_database_uri(cls, uri, selected_schema=None): uri.database = selected_schema return uri + @classmethod + def get_datatype(cls, type_code): + if not cls.type_code_map: + # only import and store if needed at least once + import MySQLdb + ft = MySQLdb.constants.FIELD_TYPE + cls.type_code_map = { + getattr(ft, k): k + for k in dir(ft) + if not k.startswith('_') + } + datatype = type_code + if isinstance(type_code, int): + datatype = cls.type_code_map.get(type_code) + if datatype and isinstance(datatype, basestring) and len(datatype): + return datatype + @classmethod def epoch_to_dttm(cls): return 'from_unixtime({col})' diff --git a/superset/sql_lab.py b/superset/sql_lab.py index df00a2b6b19e0..34a9eeb9e3767 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -10,8 +10,6 @@ from celery.exceptions import SoftTimeLimitExceeded from contextlib2 import contextmanager -import numpy as np -import pandas as pd import sqlalchemy from sqlalchemy.orm import sessionmaker from sqlalchemy.pool import NullPool @@ -31,27 +29,6 @@ class SqlLabException(Exception): pass -def dedup(l, suffix='__'): - """De-duplicates a list of string by suffixing a counter - - Always returns the same number of entries as provided, and always returns - unique values. - - >>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar']))) - foo,bar,bar__1,bar__2 - """ - new_l = [] - seen = {} - for s in l: - if s in seen: - seen[s] += 1 - s += suffix + str(seen[s]) - else: - seen[s] = 0 - new_l.append(s) - return new_l - - def get_query(query_id, session, retry_count=5): """attemps to get the query and retry if it cannot""" query = None @@ -96,24 +73,6 @@ def session_scope(nullpool): session.close() -def convert_results_to_df(column_names, data): - """Convert raw query results to a DataFrame.""" - column_names = dedup(column_names) - - # check whether the result set has any nested dict columns - if data: - first_row = data[0] - has_dict_col = any([isinstance(c, dict) for c in first_row]) - df_data = list(data) if has_dict_col else np.array(data, dtype=object) - else: - df_data = [] - - cdf = dataframe.SupersetDataFrame( - pd.DataFrame(df_data, columns=column_names)) - - return cdf - - @celery_app.task(bind=True, soft_time_limit=SQLLAB_TIMEOUT) def get_sql_results( ctask, query_id, rendered_query, return_results=True, store_results=False, @@ -233,7 +192,6 @@ def handle_error(msg): return handle_error(db_engine_spec.extract_error_message(e)) logging.info('Fetching cursor description') - column_names = db_engine_spec.get_normalized_column_names(cursor.description) if conn is not None: conn.commit() @@ -242,7 +200,7 @@ def handle_error(msg): if query.status == utils.QueryStatus.STOPPED: return handle_error('The query has been stopped') - cdf = convert_results_to_df(column_names, data) + cdf = dataframe.SupersetDataFrame(data, cursor.description, db_engine_spec) query.rows = cdf.size query.progress = 100 diff --git a/tests/celery_tests.py b/tests/celery_tests.py index 39b7749ae88f7..afaeea9dfb3a2 100644 --- a/tests/celery_tests.py +++ b/tests/celery_tests.py @@ -14,7 +14,7 @@ import pandas as pd from past.builtins import basestring -from superset import app, cli, dataframe, db, security_manager +from superset import app, cli, db, security_manager from superset.models.helpers import QueryStatus from superset.models.sql_lab import Query from superset.sql_parse import SupersetQuery @@ -245,55 +245,6 @@ def str_if_basestring(o): def dictify_list_of_dicts(cls, l, k): return {str(o[k]): cls.de_unicode_dict(o) for o in l} - def test_get_columns(self): - main_db = self.get_main_database(db.session) - df = main_db.get_df('SELECT * FROM multiformat_time_series', None) - cdf = dataframe.SupersetDataFrame(df) - - # Making ordering non-deterministic - cols = self.dictify_list_of_dicts(cdf.columns, 'name') - - if main_db.sqlalchemy_uri.startswith('sqlite'): - self.assertEqual(self.dictify_list_of_dicts([ - {'is_date': True, 'type': 'STRING', 'name': 'ds', - 'is_dim': False}, - {'is_date': True, 'type': 'STRING', 'name': 'ds2', - 'is_dim': False}, - {'agg': 'sum', 'is_date': False, 'type': 'INT', - 'name': 'epoch_ms', 'is_dim': False}, - {'agg': 'sum', 'is_date': False, 'type': 'INT', - 'name': 'epoch_s', 'is_dim': False}, - {'is_date': True, 'type': 'STRING', 'name': 'string0', - 'is_dim': False}, - {'is_date': False, 'type': 'STRING', - 'name': 'string1', 'is_dim': True}, - {'is_date': True, 'type': 'STRING', 'name': 'string2', - 'is_dim': False}, - {'is_date': False, 'type': 'STRING', - 'name': 'string3', 'is_dim': True}], 'name'), - cols, - ) - else: - self.assertEqual(self.dictify_list_of_dicts([ - {'is_date': True, 'type': 'DATETIME', 'name': 'ds', - 'is_dim': False}, - {'is_date': True, 'type': 'DATETIME', - 'name': 'ds2', 'is_dim': False}, - {'agg': 'sum', 'is_date': False, 'type': 'INT', - 'name': 'epoch_ms', 'is_dim': False}, - {'agg': 'sum', 'is_date': False, 'type': 'INT', - 'name': 'epoch_s', 'is_dim': False}, - {'is_date': True, 'type': 'STRING', 'name': 'string0', - 'is_dim': False}, - {'is_date': False, 'type': 'STRING', - 'name': 'string1', 'is_dim': True}, - {'is_date': True, 'type': 'STRING', 'name': 'string2', - 'is_dim': False}, - {'is_date': False, 'type': 'STRING', - 'name': 'string3', 'is_dim': True}], 'name'), - cols, - ) - if __name__ == '__main__': unittest.main() diff --git a/tests/core_tests.py b/tests/core_tests.py index 6a4f153eb81fd..f1a01796b734f 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -24,6 +24,7 @@ from superset import dataframe, db, jinja_context, security_manager, sql_lab, utils from superset.connectors.sqla.models import SqlaTable +from superset.db_engine_specs import BaseEngineSpec from superset.models import core as models from superset.models.sql_lab import Query from superset.views.core import DatabaseView @@ -626,8 +627,7 @@ def test_dataframe_timezone(self): (datetime.datetime(2017, 11, 18, 21, 53, 0, 219225, tzinfo=tz),), (datetime.datetime(2017, 11, 18, 22, 6, 30, 61810, tzinfo=tz),), ] - df = dataframe.SupersetDataFrame(pd.DataFrame(data=list(data), - columns=['data'])) + df = dataframe.SupersetDataFrame(list(data), [['data']], BaseEngineSpec) data = df.data self.assertDictEqual( data[0], diff --git a/tests/dataframe_test.py b/tests/dataframe_test.py new file mode 100644 index 0000000000000..b56770240b919 --- /dev/null +++ b/tests/dataframe_test.py @@ -0,0 +1,115 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from superset.dataframe import dedup, SupersetDataFrame +from superset.db_engine_specs import BaseEngineSpec +from .base_tests import SupersetTestCase + + +class SupersetDataFrameTestCase(SupersetTestCase): + def test_dedup(self): + self.assertEquals( + dedup(['foo', 'bar']), + ['foo', 'bar'], + ) + self.assertEquals( + dedup(['foo', 'bar', 'foo', 'bar']), + ['foo', 'bar', 'foo__1', 'bar__1'], + ) + self.assertEquals( + dedup(['foo', 'bar', 'bar', 'bar']), + ['foo', 'bar', 'bar__1', 'bar__2'], + ) + + def test_get_columns_basic(self): + data = [ + ('a1', 'b1', 'c1'), + ('a2', 'b2', 'c2'), + ] + cursor_descr = ( + ('a', 'string'), + ('b', 'string'), + ('c', 'string'), + ) + cdf = SupersetDataFrame(data, cursor_descr, BaseEngineSpec) + self.assertEqual( + cdf.columns, + [ + { + 'is_date': False, + 'type': 'STRING', + 'name': 'a', + 'is_dim': True, + }, { + 'is_date': False, + 'type': 'STRING', + 'name': 'b', + 'is_dim': True, + }, { + 'is_date': False, + 'type': 'STRING', + 'name': 'c', + 'is_dim': True, + }, + ], + ) + + def test_get_columns_with_int(self): + data = [ + ('a1', 1), + ('a2', 2), + ] + cursor_descr = ( + ('a', 'string'), + ('b', 'int'), + ) + cdf = SupersetDataFrame(data, cursor_descr, BaseEngineSpec) + self.assertEqual( + cdf.columns, + [ + { + 'is_date': False, + 'type': 'STRING', + 'name': 'a', + 'is_dim': True, + }, { + 'is_date': False, + 'type': 'INT', + 'name': 'b', + 'is_dim': False, + 'agg': 'sum', + }, + ], + ) + + def test_get_columns_type_inference(self): + data = [ + (1.2, 1), + (3.14, 2), + ] + cursor_descr = ( + ('a', None), + ('b', None), + ) + cdf = SupersetDataFrame(data, cursor_descr, BaseEngineSpec) + self.assertEqual( + cdf.columns, + [ + { + 'is_date': False, + 'type': 'FLOAT', + 'name': 'a', + 'is_dim': False, + 'agg': 'sum', + }, { + 'is_date': False, + 'type': 'INT', + 'name': 'b', + 'is_dim': False, + 'agg': 'sum', + }, + ], + ) diff --git a/tests/db_engine_specs_test.py b/tests/db_engine_specs_test.py index bdce0b060d020..447914ed5f840 100644 --- a/tests/db_engine_specs_test.py +++ b/tests/db_engine_specs_test.py @@ -7,7 +7,9 @@ import textwrap from superset.db_engine_specs import ( - HiveEngineSpec, MssqlEngineSpec, MySQLEngineSpec) + BaseEngineSpec, HiveEngineSpec, MssqlEngineSpec, + MySQLEngineSpec, PrestoEngineSpec, +) from superset.models.core import Database from .base_tests import SupersetTestCase @@ -193,3 +195,9 @@ def test_limit_expr_and_semicolon(self): FROM table LIMIT 1000"""), ) + + def test_get_datatype(self): + self.assertEquals('STRING', PrestoEngineSpec.get_datatype('string')) + self.assertEquals('TINY', MySQLEngineSpec.get_datatype(1)) + self.assertEquals('VARCHAR', MySQLEngineSpec.get_datatype(15)) + self.assertEquals('VARCHAR', BaseEngineSpec.get_datatype('VARCHAR')) diff --git a/tests/sqllab_tests.py b/tests/sqllab_tests.py index 49926f80def1b..a3bb564dd870e 100644 --- a/tests/sqllab_tests.py +++ b/tests/sqllab_tests.py @@ -12,8 +12,9 @@ from flask_appbuilder.security.sqla import models as ab_models from superset import db, security_manager, utils +from superset.dataframe import SupersetDataFrame +from superset.db_engine_specs import BaseEngineSpec from superset.models.sql_lab import Query -from superset.sql_lab import convert_results_to_df from .base_tests import SupersetTestCase @@ -203,9 +204,13 @@ def test_alias_duplicate(self): raise_on_error=True) def test_df_conversion_no_dict(self): - cols = ['string_col', 'int_col', 'float_col'] + cols = [ + ['string_col', 'string'], + ['int_col', 'int'], + ['float_col', 'float'], + ] data = [['a', 4, 4.0]] - cdf = convert_results_to_df(cols, data) + cdf = SupersetDataFrame(data, cols, BaseEngineSpec) self.assertEquals(len(data), cdf.size) self.assertEquals(len(cols), len(cdf.columns)) @@ -213,7 +218,7 @@ def test_df_conversion_no_dict(self): def test_df_conversion_tuple(self): cols = ['string_col', 'int_col', 'list_col', 'float_col'] data = [(u'Text', 111, [123], 1.0)] - cdf = convert_results_to_df(cols, data) + cdf = SupersetDataFrame(data, cols, BaseEngineSpec) self.assertEquals(len(data), cdf.size) self.assertEquals(len(cols), len(cdf.columns)) @@ -221,7 +226,7 @@ def test_df_conversion_tuple(self): def test_df_conversion_dict(self): cols = ['string_col', 'dict_col', 'int_col'] data = [['a', {'c1': 1, 'c2': 2, 'c3': 3}, 4]] - cdf = convert_results_to_df(cols, data) + cdf = SupersetDataFrame(data, cols, BaseEngineSpec) self.assertEquals(len(data), cdf.size) self.assertEquals(len(cols), len(cdf.columns))