Skip to content

Commit

Permalink
Improve database type inference (#4724)
Browse files Browse the repository at this point in the history
* Improve database type inference

Python's DBAPI isn't super clear and homogeneous on the
cursor.description specification, and this PR attempts to improve
inferring the datatypes returned in the cursor.

This work started around Presto's TIMESTAMP type being mishandled as
string as the database driver (pyhive) returns it as a string. The work
here fixes this bug and does a better job at inferring MySQL and Presto types.
It also creates a new method in db_engine_specs allowing for other
database engines to implement and become more precise on type-inference
as needed.

* Fixing tests

* Addressing comments

* Using infer_objects

* Removing faulty line

* Addressing PrestoSpec redundant method comment

* Fix rebase issue

* Fix tests
  • Loading branch information
mistercrunch authored Jun 28, 2018
1 parent 04fc1d1 commit 777d876
Show file tree
Hide file tree
Showing 8 changed files with 224 additions and 117 deletions.
78 changes: 62 additions & 16 deletions superset/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from __future__ import unicode_literals

from datetime import date, datetime
import logging

import numpy as np
import pandas as pd
Expand All @@ -26,6 +27,27 @@
INFER_COL_TYPES_SAMPLE_SIZE = 100


def dedup(l, suffix='__'):
    """De-duplicates a list of strings by suffixing a counter

    Always returns the same number of entries as provided, and always returns
    unique values: a generated name is never allowed to collide with a name
    that has already been emitted, whether that name was generated or was
    present verbatim in the input.

    :param l: list of strings to de-duplicate
    :param suffix: separator placed between a repeated name and its counter
    :return: new list, same length as ``l``, with unique entries

    >>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar'])))
    foo,bar,bar__1,bar__2
    >>> print(','.join(dedup(['bar', 'bar', 'bar__1'])))
    bar,bar__1,bar__1__1
    """
    new_l = []
    seen = {}  # original name -> last counter used for that name
    used = set()  # every name emitted so far
    for s in l:
        candidate = s
        if candidate in used:
            count = seen.get(s, 0)
            # Keep bumping the counter until the suffixed name is free;
            # this is what guarantees uniqueness when the input already
            # contains something that looks like a generated name.
            while candidate in used:
                count += 1
                candidate = s + suffix + str(count)
            seen[s] = count
        used.add(candidate)
        new_l.append(candidate)
    return new_l


class SupersetDataFrame(object):
# Mapping numpy dtype.char to generic database types
type_map = {
Expand All @@ -43,19 +65,39 @@ class SupersetDataFrame(object):
'V': None, # raw data (void)
}

def __init__(self, df):
self.__df = df.where((pd.notnull(df)), None)
def __init__(self, data, cursor_description, db_engine_spec):
    """Build the DataFrame and a per-column database type lookup.

    :param data: rows fetched from the cursor; a falsy value is treated
        as an empty result set
    :param cursor_description: the cursor's ``description`` sequence,
        read here as (column name at index 0, type code at index 1);
        may be falsy when the driver does not provide one
    :param db_engine_spec: engine spec providing
        ``get_normalized_column_names`` and ``get_datatype``
    """
    # The normalized, de-duplicated names label the DataFrame columns AND
    # key ``_type_dict``, so a type lookup by DataFrame column always
    # lines up, even when the result set repeats a column name.
    self.column_names = dedup(
        db_engine_spec.get_normalized_column_names(cursor_description))

    data = data or []
    self.df = (
        pd.DataFrame(list(data), columns=self.column_names)
        .infer_objects())

    self._type_dict = {}
    if cursor_description:
        try:
            # Best effort: description entries may not carry a type code
            # (or may be shaped differently per driver); on any failure
            # the lookup simply stays empty.
            self._type_dict = {
                col: db_engine_spec.get_datatype(cursor_description[i][1])
                for i, col in enumerate(self.column_names)
            }
        except Exception as e:
            logging.exception(e)

@property
def size(self):
return len(self.__df.index)
return len(self.df.index)

@property
def data(self):
# work around for https://github.com/pandas-dev/pandas/issues/18372
data = [dict((k, _maybe_box_datetimelike(v))
for k, v in zip(self.__df.columns, np.atleast_1d(row)))
for row in self.__df.values]
for k, v in zip(self.df.columns, np.atleast_1d(row)))
for row in self.df.values]
for d in data:
for k, v in list(d.items()):
# if an int is too big for Java Script to handle
Expand All @@ -70,7 +112,8 @@ def db_type(cls, dtype):
"""Given a numpy dtype, Returns a generic database type"""
if isinstance(dtype, ExtensionDtype):
return cls.type_map.get(dtype.kind)
return cls.type_map.get(dtype.char)
elif hasattr(dtype, 'char'):
return cls.type_map.get(dtype.char)

@classmethod
def datetime_conversion_rate(cls, data_series):
Expand Down Expand Up @@ -105,7 +148,7 @@ def agg_func(cls, dtype, column_name):
# consider checking for key substring too.
if cls.is_id(column_name):
return 'count_distinct'
if (issubclass(dtype.type, np.generic) and
if (hasattr(dtype, 'type') and issubclass(dtype.type, np.generic) and
np.issubdtype(dtype, np.number)):
return 'sum'
return None
Expand All @@ -116,22 +159,25 @@ def columns(self):
:return: dict, with the fields name, type, is_date, is_dim and agg.
"""
if self.__df.empty:
if self.df.empty:
return None

columns = []
sample_size = min(INFER_COL_TYPES_SAMPLE_SIZE, len(self.__df.index))
sample = self.__df
sample_size = min(INFER_COL_TYPES_SAMPLE_SIZE, len(self.df.index))
sample = self.df
if sample_size:
sample = self.__df.sample(sample_size)
for col in self.__df.dtypes.keys():
col_db_type = self.db_type(self.__df.dtypes[col])
sample = self.df.sample(sample_size)
for col in self.df.dtypes.keys():
col_db_type = (
self._type_dict.get(col) or
self.db_type(self.df.dtypes[col])
)
column = {
'name': col,
'agg': self.agg_func(self.__df.dtypes[col], col),
'agg': self.agg_func(self.df.dtypes[col], col),
'type': col_db_type,
'is_date': self.is_date(self.__df.dtypes[col]),
'is_dim': self.is_dimension(self.__df.dtypes[col], col),
'is_date': self.is_date(self.df.dtypes[col]),
'is_dim': self.is_dimension(self.df.dtypes[col], col),
}

if column['type'] in ('OBJECT', None):
Expand Down
24 changes: 24 additions & 0 deletions superset/db_engine_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from flask import g
from flask_babel import lazy_gettext as _
import pandas
from past.builtins import basestring
import sqlalchemy as sqla
from sqlalchemy import select
from sqlalchemy.engine import create_engine
Expand Down Expand Up @@ -85,6 +86,11 @@ def epoch_to_dttm(cls):
def epoch_ms_to_dttm(cls):
return cls.epoch_to_dttm().replace('{col}', '({col}/1000.0)')

@classmethod
def get_datatype(cls, type_code):
    """Return the upper-cased type code when it is a non-empty string.

    Any other value (non-string, or an empty string) yields ``None``.
    """
    if not isinstance(type_code, basestring):
        return None
    if not type_code:
        return None
    return type_code.upper()

@classmethod
def extra_table_metadata(cls, database, table_name, schema_name):
"""Returns engine-specific table metadata"""
Expand Down Expand Up @@ -592,6 +598,7 @@ class MySQLEngineSpec(BaseEngineSpec):
'INTERVAL DAYOFWEEK(DATE_SUB({col}, INTERVAL 1 DAY)) - 1 DAY))',
'P1W'),
)
type_code_map = {} # loaded from get_datatype only if needed

@classmethod
def convert_dttm(cls, target_type, dttm):
Expand All @@ -606,6 +613,23 @@ def adjust_database_uri(cls, uri, selected_schema=None):
uri.database = selected_schema
return uri

@classmethod
def get_datatype(cls, type_code):
    """Translate a cursor type code into a type name.

    Integer codes are looked up in a map built from the public attribute
    names of ``MySQLdb.constants.FIELD_TYPE``; the map is built on first
    use and stored on ``cls``. Non-integer codes are passed through as is.
    The result is returned only if it is a non-empty string, else ``None``.
    """
    if not cls.type_code_map:
        # only import and store if needed at least once
        import MySQLdb
        field_types = MySQLdb.constants.FIELD_TYPE
        public_names = [
            name for name in dir(field_types) if not name.startswith('_')
        ]
        cls.type_code_map = {
            getattr(field_types, name): name for name in public_names
        }

    if isinstance(type_code, int):
        resolved = cls.type_code_map.get(type_code)
    else:
        resolved = type_code

    if isinstance(resolved, basestring) and resolved:
        return resolved
    return None

@classmethod
def epoch_to_dttm(cls):
return 'from_unixtime({col})'
Expand Down
44 changes: 1 addition & 43 deletions superset/sql_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@

from celery.exceptions import SoftTimeLimitExceeded
from contextlib2 import contextmanager
import numpy as np
import pandas as pd
import sqlalchemy
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool
Expand All @@ -31,27 +29,6 @@ class SqlLabException(Exception):
pass


def dedup(l, suffix='__'):
    """De-duplicates a list of strings by suffixing a counter

    Always returns the same number of entries as provided, and always returns
    unique values: a generated name is never allowed to collide with a name
    that has already been emitted, whether that name was generated or was
    present verbatim in the input.

    :param l: list of strings to de-duplicate
    :param suffix: separator placed between a repeated name and its counter
    :return: new list, same length as ``l``, with unique entries

    >>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar'])))
    foo,bar,bar__1,bar__2
    >>> print(','.join(dedup(['bar', 'bar', 'bar__1'])))
    bar,bar__1,bar__1__1
    """
    new_l = []
    seen = {}  # original name -> last counter used for that name
    used = set()  # every name emitted so far
    for s in l:
        candidate = s
        if candidate in used:
            count = seen.get(s, 0)
            # Keep bumping the counter until the suffixed name is free;
            # this is what guarantees uniqueness when the input already
            # contains something that looks like a generated name.
            while candidate in used:
                count += 1
                candidate = s + suffix + str(count)
            seen[s] = count
        used.add(candidate)
        new_l.append(candidate)
    return new_l


def get_query(query_id, session, retry_count=5):
"""attemps to get the query and retry if it cannot"""
query = None
Expand Down Expand Up @@ -96,24 +73,6 @@ def session_scope(nullpool):
session.close()


def convert_results_to_df(column_names, data):
    """Wrap raw query results in a ``SupersetDataFrame``.

    Column names are de-duplicated first. Rows are handed to pandas as a
    plain list when the first row contains a dict value, and as an
    object-dtype numpy array otherwise; no data yields an empty frame.
    """
    unique_names = dedup(column_names)

    if not data:
        rows = []
    elif any(isinstance(cell, dict) for cell in data[0]):
        # nested dict column detected in the first row
        rows = list(data)
    else:
        rows = np.array(data, dtype=object)

    frame = pd.DataFrame(rows, columns=unique_names)
    return dataframe.SupersetDataFrame(frame)


@celery_app.task(bind=True, soft_time_limit=SQLLAB_TIMEOUT)
def get_sql_results(
ctask, query_id, rendered_query, return_results=True, store_results=False,
Expand Down Expand Up @@ -233,7 +192,6 @@ def handle_error(msg):
return handle_error(db_engine_spec.extract_error_message(e))

logging.info('Fetching cursor description')
column_names = db_engine_spec.get_normalized_column_names(cursor.description)

if conn is not None:
conn.commit()
Expand All @@ -242,7 +200,7 @@ def handle_error(msg):
if query.status == utils.QueryStatus.STOPPED:
return handle_error('The query has been stopped')

cdf = convert_results_to_df(column_names, data)
cdf = dataframe.SupersetDataFrame(data, cursor.description, db_engine_spec)

query.rows = cdf.size
query.progress = 100
Expand Down
51 changes: 1 addition & 50 deletions tests/celery_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import pandas as pd
from past.builtins import basestring

from superset import app, cli, dataframe, db, security_manager
from superset import app, cli, db, security_manager
from superset.models.helpers import QueryStatus
from superset.models.sql_lab import Query
from superset.sql_parse import SupersetQuery
Expand Down Expand Up @@ -245,55 +245,6 @@ def str_if_basestring(o):
def dictify_list_of_dicts(cls, l, k):
return {str(o[k]): cls.de_unicode_dict(o) for o in l}

def test_get_columns(self):
    """Check the column metadata inferred for ``multiformat_time_series``."""
    main_db = self.get_main_database(db.session)
    df = main_db.get_df('SELECT * FROM multiformat_time_series', None)
    cdf = dataframe.SupersetDataFrame(df)

    # Key both sides by column name so the comparison ignores ordering
    cols = self.dictify_list_of_dicts(cdf.columns, 'name')

    # Only the expected type of ``ds``/``ds2`` depends on the backend
    if main_db.sqlalchemy_uri.startswith('sqlite'):
        ds_type = 'STRING'
    else:
        ds_type = 'DATETIME'

    expected = [
        {'is_date': True, 'type': ds_type, 'name': 'ds',
         'is_dim': False},
        {'is_date': True, 'type': ds_type, 'name': 'ds2',
         'is_dim': False},
        {'agg': 'sum', 'is_date': False, 'type': 'INT',
         'name': 'epoch_ms', 'is_dim': False},
        {'agg': 'sum', 'is_date': False, 'type': 'INT',
         'name': 'epoch_s', 'is_dim': False},
        {'is_date': True, 'type': 'STRING', 'name': 'string0',
         'is_dim': False},
        {'is_date': False, 'type': 'STRING',
         'name': 'string1', 'is_dim': True},
        {'is_date': True, 'type': 'STRING', 'name': 'string2',
         'is_dim': False},
        {'is_date': False, 'type': 'STRING',
         'name': 'string3', 'is_dim': True},
    ]
    self.assertEqual(self.dictify_list_of_dicts(expected, 'name'), cols)


if __name__ == '__main__':
unittest.main()
4 changes: 2 additions & 2 deletions tests/core_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from superset import dataframe, db, jinja_context, security_manager, sql_lab, utils
from superset.connectors.sqla.models import SqlaTable
from superset.db_engine_specs import BaseEngineSpec
from superset.models import core as models
from superset.models.sql_lab import Query
from superset.views.core import DatabaseView
Expand Down Expand Up @@ -626,8 +627,7 @@ def test_dataframe_timezone(self):
(datetime.datetime(2017, 11, 18, 21, 53, 0, 219225, tzinfo=tz),),
(datetime.datetime(2017, 11, 18, 22, 6, 30, 61810, tzinfo=tz),),
]
df = dataframe.SupersetDataFrame(pd.DataFrame(data=list(data),
columns=['data']))
df = dataframe.SupersetDataFrame(list(data), [['data']], BaseEngineSpec)
data = df.data
self.assertDictEqual(
data[0],
Expand Down
Loading

0 comments on commit 777d876

Please sign in to comment.