Force quoted column aliases for Oracle-like databases (apache#5686)
* Replace dataframe label override logic with table column override

* Add mutation to any_date_col

* Linting

* Add mutation to oracle and redshift

* Fine tune how and which labels are mutated

* Implement alias quoting logic for oracle-like databases

* Fix and align column and metric sqla_col methods

* Clean up typos and redundant logic

* Move new attribute to old location

* Linting

* Replace old sqla_col property references with function calls

* Remove redundant calls to mutate_column_label

* Move duplicated logic to common function

* Add db_engine_specs to all sqla_col calls

* Add missing mydb

* Add note about snowflake-sqlalchemy regression

* Make db_engine_spec mandatory in sqla_col

* Small refactoring and cleanup

* Remove db_engine_spec from get_from_clause call

* Make db_engine_spec mandatory in adhoc_metric_to_sa

* Remove redundant mutate_expression_label call

* Add missing db_engine_specs to adhoc_metric_to_sa

* Rename arg label_name to label in get_column_label()

* Rename label function and add docstring

* Remove redundant db_engine_spec args

* Rename col_label to label

* Remove get_column_name wrapper and make direct calls to db_engine_spec

* Remove unneeded db_engine_specs

* Rename sa_ vars to sqla_

(cherry picked from commit 77fe9ef)
villebro authored and betodealmeida committed Oct 30, 2018
1 parent 1cc98e6 commit 7e36077
Showing 5 changed files with 76 additions and 108 deletions.
4 changes: 4 additions & 0 deletions docs/installation.rst
@@ -389,6 +389,10 @@ Make sure the user has privileges to access and use all required
 databases/schemas/tables/views/warehouses, as the Snowflake SQLAlchemy engine does
 not test for user rights during engine creation.
 
+*Note*: At the time of writing, there is a regression in the current stable version (1.1.2) of
+the snowflake-sqlalchemy package that causes problems when used with Superset. It is recommended
+to use version 1.1.0 or try a newer version.
+
 See `Snowflake SQLAlchemy <https://github.com/snowflakedb/snowflake-sqlalchemy>`_.

Caching
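A quick way to check which snowflake-sqlalchemy release is installed in the Superset environment (a sketch using pkg_resources; the fix itself is the version pin described in the note above):

```python
import pkg_resources

version = pkg_resources.get_distribution('snowflake-sqlalchemy').version
if version == '1.1.2':
    print('snowflake-sqlalchemy 1.1.2 has a known regression with Superset; '
          'pin 1.1.0 or try a newer release')
```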
106 changes: 57 additions & 49 deletions superset/connectors/sqla/models.py
@@ -99,21 +99,21 @@ class TableColumn(Model, BaseColumn):
         s for s in export_fields if s not in ('table_id',)]
     export_parent = 'table'
 
-    @property
-    def sqla_col(self):
-        name = self.column_name
+    def get_sqla_col(self, label=None):
+        db_engine_spec = self.table.database.db_engine_spec
+        label = db_engine_spec.make_label_compatible(label if label else self.column_name)
         if not self.expression:
-            col = column(self.column_name).label(name)
+            col = column(self.column_name).label(label)
         else:
-            col = literal_column(self.expression).label(name)
+            col = literal_column(self.expression).label(label)
         return col
 
     @property
     def datasource(self):
         return self.table
 
     def get_time_filter(self, start_dttm, end_dttm):
-        col = self.sqla_col.label('__time')
+        col = self.get_sqla_col(label='__time')
         l = []  # noqa: E741
         if start_dttm:
             l.append(col >= text(self.dttm_sql_literal(start_dttm)))
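Outside the diff, the effect of the new signature can be sketched with plain SQLAlchemy (the column name `num` and the forced quoting below are illustrative assumptions, not part of the commit):

```python
# Minimal sketch of what get_sqla_col() now does, using plain SQLAlchemy.
from sqlalchemy import column, select
from sqlalchemy.sql import quoted_name


def make_label_compatible(label, force_quotes=False):
    # Stand-in for db_engine_spec.make_label_compatible()
    return quoted_name(label, True) if force_quotes else label


label = make_label_compatible('num', force_quotes=True)  # Oracle-like engine
col = column('num').label(label)
print(select([col]))  # renders roughly as: SELECT num AS "num"
```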
@@ -231,10 +231,10 @@ class SqlMetric(Model, BaseMetric):
         s for s in export_fields if s not in ('table_id', )])
     export_parent = 'table'
 
-    @property
-    def sqla_col(self):
-        name = self.metric_name
-        return literal_column(self.expression).label(name)
+    def get_sqla_col(self, label=None):
+        db_engine_spec = self.table.database.db_engine_spec
+        label = db_engine_spec.make_label_compatible(label if label else self.metric_name)
+        return literal_column(self.expression).label(label)
 
     @property
     def perm(self):
@@ -421,11 +421,10 @@ def values_for_column(self, column_name, limit=10000):
         cols = {col.column_name: col for col in self.columns}
         target_col = cols[column_name]
         tp = self.get_template_processor()
-        db_engine_spec = self.database.db_engine_spec
 
         qry = (
-            select([target_col.sqla_col])
-            .select_from(self.get_from_clause(tp, db_engine_spec))
+            select([target_col.get_sqla_col()])
+            .select_from(self.get_from_clause(tp))
             .distinct()
         )
         if limit:
@@ -474,7 +473,7 @@ def get_sqla_table(self):
             tbl.schema = self.schema
         return tbl
 
-    def get_from_clause(self, template_processor=None, db_engine_spec=None):
+    def get_from_clause(self, template_processor=None):
         # Supporting arbitrary SQL statements in place of tables
         if self.sql:
             from_sql = self.sql
@@ -484,7 +483,7 @@ def get_from_clause(self, template_processor=None, db_engine_spec=None):
             return TextAsFrom(sa.text(from_sql), []).alias('expr_qry')
         return self.get_sqla_table()
 
-    def adhoc_metric_to_sa(self, metric, cols):
+    def adhoc_metric_to_sqla(self, metric, cols):
         """
         Turn an adhoc metric into a sqlalchemy column.
@@ -493,22 +492,25 @@ def adhoc_metric_to_sa(self, metric, cols):
         :returns: The metric defined as a sqlalchemy column
         :rtype: sqlalchemy.sql.column
         """
-        expressionType = metric.get('expressionType')
-        if expressionType == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']:
+        expression_type = metric.get('expressionType')
+        db_engine_spec = self.database.db_engine_spec
+        label = db_engine_spec.make_label_compatible(metric.get('label'))
+
+        if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']:
             column_name = metric.get('column').get('column_name')
-            sa_column = column(column_name)
+            sqla_column = column(column_name)
             table_column = cols.get(column_name)
 
             if table_column:
-                sa_column = table_column.sqla_col
-
-            sa_metric = self.sqla_aggregations[metric.get('aggregate')](sa_column)
-            sa_metric = sa_metric.label(metric.get('label'))
-            return sa_metric
-        elif expressionType == utils.ADHOC_METRIC_EXPRESSION_TYPES['SQL']:
-            sa_metric = literal_column(metric.get('sqlExpression'))
-            sa_metric = sa_metric.label(metric.get('label'))
-            return sa_metric
+                sqla_column = table_column.get_sqla_col()
+
+            sqla_metric = self.sqla_aggregations[metric.get('aggregate')](sqla_column)
+            sqla_metric = sqla_metric.label(label)
+            return sqla_metric
+        elif expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SQL']:
+            sqla_metric = literal_column(metric.get('sqlExpression'))
+            sqla_metric = sqla_metric.label(label)
+            return sqla_metric
         else:
             return None
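For reference, the adhoc-metric dictionaries handled by adhoc_metric_to_sqla() look roughly like the following (the keys match those read in the code above; the concrete values are made up):

```python
# Illustrative shapes only -- the values are invented for the example.
simple_metric = {
    'expressionType': 'SIMPLE',
    'column': {'column_name': 'num'},
    'aggregate': 'SUM',
    'label': 'SUM(num)',
}
sql_metric = {
    'expressionType': 'SQL',
    'sqlExpression': 'COUNT(DISTINCT num)',
    'label': 'distinct_num',
}
# datasource.adhoc_metric_to_sqla(simple_metric, cols) would yield SUM(num)
# labelled with whatever make_label_compatible('SUM(num)') returns.
```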

@@ -566,15 +568,16 @@ def get_sqla_query(  # sqla
         metrics_exprs = []
         for m in metrics:
             if utils.is_adhoc_metric(m):
-                metrics_exprs.append(self.adhoc_metric_to_sa(m, cols))
+                metrics_exprs.append(self.adhoc_metric_to_sqla(m, cols))
             elif m in metrics_dict:
-                metrics_exprs.append(metrics_dict.get(m).sqla_col)
+                metrics_exprs.append(metrics_dict.get(m).get_sqla_col())
             else:
                 raise Exception(_("Metric '{}' is not valid".format(m)))
         if metrics_exprs:
             main_metric_expr = metrics_exprs[0]
         else:
-            main_metric_expr = literal_column('COUNT(*)').label('ccount')
+            main_metric_expr = literal_column('COUNT(*)').label(
+                db_engine_spec.make_label_compatible('count'))
 
         select_exprs = []
         groupby_exprs = []
@@ -585,16 +588,16 @@ def get_sqla_query(  # sqla
             inner_groupby_exprs = []
             for s in groupby:
                 col = cols[s]
-                outer = col.sqla_col
-                inner = col.sqla_col.label(col.column_name + '__')
+                outer = col.get_sqla_col()
+                inner = col.get_sqla_col(col.column_name + '__')
 
                 groupby_exprs.append(outer)
                 select_exprs.append(outer)
                 inner_groupby_exprs.append(inner)
                 inner_select_exprs.append(inner)
         elif columns:
             for s in columns:
-                select_exprs.append(cols[s].sqla_col)
+                select_exprs.append(cols[s].get_sqla_col())
             metrics_exprs = []
 
         if granularity:
@@ -618,7 +621,7 @@ def get_sqla_query(  # sqla
         select_exprs += metrics_exprs
         qry = sa.select(select_exprs)
 
-        tbl = self.get_from_clause(template_processor, db_engine_spec)
+        tbl = self.get_from_clause(template_processor)
 
         if not columns:
             qry = qry.group_by(*groupby_exprs)
@@ -638,33 +641,34 @@ def get_sqla_query(  # sqla
                         target_column_is_numeric=col_obj.is_num,
                         is_list_target=is_list_target)
                     if op in ('in', 'not in'):
-                        cond = col_obj.sqla_col.in_(eq)
+                        cond = col_obj.get_sqla_col().in_(eq)
                         if '<NULL>' in eq:
-                            cond = or_(cond, col_obj.sqla_col == None)  # noqa
+                            cond = or_(cond, col_obj.get_sqla_col() == None)  # noqa
                         if op == 'not in':
                             cond = ~cond
                         where_clause_and.append(cond)
                     else:
                         if col_obj.is_num:
                             eq = utils.string_to_num(flt['val'])
                         if op == '==':
-                            where_clause_and.append(col_obj.sqla_col == eq)
+                            where_clause_and.append(col_obj.get_sqla_col() == eq)
                         elif op == '!=':
-                            where_clause_and.append(col_obj.sqla_col != eq)
+                            where_clause_and.append(col_obj.get_sqla_col() != eq)
                         elif op == '>':
-                            where_clause_and.append(col_obj.sqla_col > eq)
+                            where_clause_and.append(col_obj.get_sqla_col() > eq)
                         elif op == '<':
-                            where_clause_and.append(col_obj.sqla_col < eq)
+                            where_clause_and.append(col_obj.get_sqla_col() < eq)
                         elif op == '>=':
-                            where_clause_and.append(col_obj.sqla_col >= eq)
+                            where_clause_and.append(col_obj.get_sqla_col() >= eq)
                         elif op == '<=':
-                            where_clause_and.append(col_obj.sqla_col <= eq)
+                            where_clause_and.append(col_obj.get_sqla_col() <= eq)
                         elif op == 'LIKE':
-                            where_clause_and.append(col_obj.sqla_col.like(eq))
+                            where_clause_and.append(col_obj.get_sqla_col().like(eq))
                         elif op == 'IS NULL':
-                            where_clause_and.append(col_obj.sqla_col == None)  # noqa
+                            where_clause_and.append(col_obj.get_sqla_col() == None)  # noqa
                         elif op == 'IS NOT NULL':
-                            where_clause_and.append(col_obj.sqla_col != None)  # noqa
+                            where_clause_and.append(
+                                col_obj.get_sqla_col() != None)  # noqa
         if extras:
             where = extras.get('where')
             if where:
@@ -686,7 +690,7 @@ def get_sqla_query(  # sqla
             for col, ascending in orderby:
                 direction = asc if ascending else desc
                 if utils.is_adhoc_metric(col):
-                    col = self.adhoc_metric_to_sa(col, cols)
+                    col = self.adhoc_metric_to_sqla(col, cols)
                 qry = qry.order_by(direction(col))
 
         if row_limit:
@@ -712,12 +716,12 @@ def get_sqla_query(  # sqla
                 ob = inner_main_metric_expr
                 if timeseries_limit_metric:
                     if utils.is_adhoc_metric(timeseries_limit_metric):
-                        ob = self.adhoc_metric_to_sa(timeseries_limit_metric, cols)
+                        ob = self.adhoc_metric_to_sqla(timeseries_limit_metric, cols)
                     elif timeseries_limit_metric in metrics_dict:
                         timeseries_limit_metric = metrics_dict.get(
                             timeseries_limit_metric,
                         )
-                        ob = timeseries_limit_metric.sqla_col
+                        ob = timeseries_limit_metric.get_sqla_col()
                     else:
                         raise Exception(_("Metric '{}' is not valid".format(m)))
                 direction = desc if order_desc else asc
@@ -762,7 +766,7 @@ def _get_top_groups(self, df, dimensions):
             group = []
             for dimension in dimensions:
                 col_obj = cols.get(dimension)
-                group.append(col_obj.sqla_col == row[dimension])
+                group.append(col_obj.get_sqla_col() == row[dimension])
             groups.append(and_(*group))
 
         return or_(*groups)
@@ -816,6 +820,7 @@ def fetch_metadata(self):
             .filter(or_(TableColumn.column_name == col.name
                         for col in table.columns)))
         dbcols = {dbcol.column_name: dbcol for dbcol in dbcols}
+        db_engine_spec = self.database.db_engine_spec
 
         for col in table.columns:
             try:
@@ -848,6 +853,9 @@ def fetch_metadata(self):
         ))
         if not self.main_dttm_col:
             self.main_dttm_col = any_date_col
+        for metric in metrics:
+            metric.metric_name = db_engine_spec.mutate_expression_label(
+                metric.metric_name)
         self.add_missing_metrics(metrics)
         db.session.merge(self)
         db.session.commit()
4 changes: 1 addition & 3 deletions superset/dataframe.py
@@ -73,9 +73,7 @@ def __init__(self, data, cursor_description, db_engine_spec):
         if cursor_description:
             column_names = [col[0] for col in cursor_description]
 
-        case_sensitive = db_engine_spec.consistent_case_sensitivity
-        self.column_names = dedup(column_names,
-                                  case_sensitive=case_sensitive)
+        self.column_names = dedup(column_names)
 
         data = data or []
         self.df = (
66 changes: 14 additions & 52 deletions superset/db_engine_specs.py
@@ -35,7 +35,7 @@
 from sqlalchemy import select
 from sqlalchemy.engine import create_engine
 from sqlalchemy.engine.url import make_url
-from sqlalchemy.sql import text
+from sqlalchemy.sql import quoted_name, text
 from sqlalchemy.sql.expression import TextAsFrom
 import sqlparse
 from tableschema import Table
@@ -101,7 +101,7 @@ class BaseEngineSpec(object):
     time_secondary_columns = False
     inner_joins = True
     allows_subquery = True
-    consistent_case_sensitivity = True  # do results have same case as qry for col names?
+    force_column_alias_quotes = False
     arraysize = None
 
     @classmethod
@@ -376,55 +376,15 @@ def execute(cls, cursor, query, **kwargs):
             cursor.execute(query)
 
     @classmethod
-    def adjust_df_column_names(cls, df, fd):
-        """Based of fields in form_data, return dataframe with new column names
-        Usually sqla engines return column names whose case matches that of the
-        original query. For example:
-            SELECT 1 as col1, 2 as COL2, 3 as Col_3
-        will usually result in the following df.columns:
-            ['col1', 'COL2', 'Col_3'].
-        For these engines there is no need to adjust the dataframe column names
-        (default behavior). However, some engines (at least Snowflake, Oracle and
-        Redshift) return column names with different case than in the original query,
-        usually all uppercase. For these the column names need to be adjusted to
-        correspond to the case of the fields specified in the form data for Viz
-        to work properly. This adjustment can be done here.
+    def make_label_compatible(cls, label):
         """
-        if cls.consistent_case_sensitivity:
-            return df
-        else:
-            return cls.align_df_col_names_with_form_data(df, fd)
-
-    @staticmethod
-    def align_df_col_names_with_form_data(df, fd):
-        """Helper function to rename columns that have changed case during query.
-        Returns a dataframe where column names have been adjusted to correspond with
-        column names in form data (case insensitive). Examples:
-            dataframe: 'col1', form_data: 'col1' -> no change
-            dataframe: 'COL1', form_data: 'col1' -> dataframe column renamed: 'col1'
-            dataframe: 'col1', form_data: 'Col1' -> dataframe column renamed: 'Col1'
+        Return a sqlalchemy.sql.elements.quoted_name if the engine requires
+        quoting of aliases to ensure that select query and query results
+        have same case.
         """
 
-        columns = set()
-        lowercase_mapping = {}
-
-        metrics = utils.get_metric_names(fd.get('metrics', []))
-        groupby = fd.get('groupby', [])
-        other_cols = [utils.DTTM_ALIAS]
-        for col in metrics + groupby + other_cols:
-            columns.add(col)
-            lowercase_mapping[col.lower()] = col
-
-        rename_cols = {}
-        for col in df.columns:
-            if col not in columns:
-                orig_col = lowercase_mapping.get(col.lower())
-                if orig_col:
-                    rename_cols[col] = orig_col
-
-        return df.rename(index=str, columns=rename_cols)
+        if cls.force_column_alias_quotes is True:
+            return quoted_name(label, True)
+        return label
 
     @staticmethod
     def mutate_expression_label(label):
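The effect of the quoting is easiest to see by compiling a labelled expression against SQLAlchemy's bundled Oracle dialect; a small sketch (the rendered SQL in the comments is approximate):

```python
from sqlalchemy import literal_column, select
from sqlalchemy.dialects import oracle
from sqlalchemy.sql import quoted_name

plain = literal_column('COUNT(*)').label('count')
quoted = literal_column('COUNT(*)').label(quoted_name('count', True))

print(select([plain]).compile(dialect=oracle.dialect()))
# SELECT COUNT(*) AS count      -- Oracle would fold the unquoted alias to COUNT
print(select([quoted]).compile(dialect=oracle.dialect()))
# SELECT COUNT(*) AS "count"    -- quoting preserves the requested case
```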
@@ -478,7 +438,8 @@ def get_table_names(cls, schema, inspector):
 
 class SnowflakeEngineSpec(PostgresBaseEngineSpec):
     engine = 'snowflake'
-    consistent_case_sensitivity = False
+    force_column_alias_quotes = True
+
     time_grain_functions = {
         None: '{col}',
         'PT1S': "DATE_TRUNC('SECOND', {col})",
@@ -515,13 +476,13 @@ class VerticaEngineSpec(PostgresBaseEngineSpec):
 
 class RedshiftEngineSpec(PostgresBaseEngineSpec):
     engine = 'redshift'
-    consistent_case_sensitivity = False
+    force_column_alias_quotes = True
 
 
 class OracleEngineSpec(PostgresBaseEngineSpec):
     engine = 'oracle'
     limit_method = LimitMethod.WRAP_SQL
-    consistent_case_sensitivity = False
+    force_column_alias_quotes = True
 
     time_grain_functions = {
         None: '{col}',
@@ -545,6 +506,7 @@ def convert_dttm(cls, target_type, dttm):
 class Db2EngineSpec(BaseEngineSpec):
     engine = 'ibm_db_sa'
     limit_method = LimitMethod.WRAP_SQL
+    force_column_alias_quotes = True
     time_grain_functions = {
         None: '{col}',
         'PT1S': 'CAST({col} as TIMESTAMP)'
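Engine specs for other databases that upper-case unquoted aliases can opt in the same way; a hypothetical example (neither the class nor the engine name exists in Superset):

```python
class MyFoldingEngineSpec(BaseEngineSpec):
    """Hypothetical spec for a database that upper-cases unquoted aliases."""
    engine = 'myfoldingdb'            # made-up SQLAlchemy dialect name
    force_column_alias_quotes = True  # labels are wrapped in quoted_name(label, True)
```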