Skip to content

Commit

Permalink
[BugFix] Allowing limit ordering by post-aggregation metrics (apache#4646)
Browse files Browse the repository at this point in the history

* Allowing limit ordering by post-aggregation metrics

* don't overwrite og dictionaries

* update tests

* python3 compat

* code review comments, add tests, implement it in groupby as well

* python 3 compat for unittest

* more self

* Throw exception when get aggregations is called with postaggs

* Treat adhoc metrics as another aggregation
  • Loading branch information
jeffreythewang authored and michellethomas committed May 23, 2018
1 parent 3ab398f commit ce3c9cc
Show file tree
Hide file tree
Showing 2 changed files with 226 additions and 39 deletions.
67 changes: 52 additions & 15 deletions superset/connectors/druid/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

from superset import conf, db, import_util, security_manager, utils
from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric
from superset.exceptions import MetricPermException
from superset.exceptions import MetricPermException, SupersetException
from superset.models.helpers import (
AuditMixinNullable, ImportMixin, QueryResult, set_perm,
)
Expand All @@ -44,6 +44,7 @@
)

DRUID_TZ = conf.get('DRUID_TZ')
POST_AGG_TYPE = 'postagg'


# Function wrapper because bound methods cannot
Expand Down Expand Up @@ -843,7 +844,7 @@ def find_postaggs_for(postagg_names, metrics_dict):
"""Return a list of metrics that are post aggregations"""
postagg_metrics = [
metrics_dict[name] for name in postagg_names
if metrics_dict[name].metric_type == 'postagg'
if metrics_dict[name].metric_type == POST_AGG_TYPE
]
# Remove post aggregations that were found
for postagg in postagg_metrics:
Expand Down Expand Up @@ -893,8 +894,8 @@ def resolve_postagg(postagg, post_aggs, agg_names, visited_postaggs, metrics_dic
missing_postagg, post_aggs, agg_names, visited_postaggs, metrics_dict)
post_aggs[postagg.metric_name] = DruidDatasource.get_post_agg(postagg.json_obj)

@staticmethod
def metrics_and_post_aggs(metrics, metrics_dict):
@classmethod
def metrics_and_post_aggs(cls, metrics, metrics_dict):
# Separate metrics into those that are aggregations
# and those that are post aggregations
saved_agg_names = set()
Expand All @@ -903,7 +904,7 @@ def metrics_and_post_aggs(metrics, metrics_dict):
for metric in metrics:
if utils.is_adhoc_metric(metric):
adhoc_agg_configs.append(metric)
elif metrics_dict[metric].metric_type != 'postagg':
elif metrics_dict[metric].metric_type != POST_AGG_TYPE:
saved_agg_names.add(metric)
else:
postagg_names.append(metric)
Expand All @@ -914,9 +915,10 @@ def metrics_and_post_aggs(metrics, metrics_dict):
for postagg_name in postagg_names:
postagg = metrics_dict[postagg_name]
visited_postaggs.add(postagg_name)
DruidDatasource.resolve_postagg(
cls.resolve_postagg(
postagg, post_aggs, saved_agg_names, visited_postaggs, metrics_dict)
return list(saved_agg_names), adhoc_agg_configs, post_aggs
aggs = cls.get_aggregations(metrics_dict, saved_agg_names, adhoc_agg_configs)
return aggs, post_aggs

def values_for_column(self,
column_name,
Expand Down Expand Up @@ -982,16 +984,35 @@ def druid_type_from_adhoc_metric(adhoc_metric):
else:
return column_type + aggregate.capitalize()

@staticmethod
def get_aggregations(metrics_dict, saved_metrics, adhoc_metrics=None):
    """
    Returns a dictionary of aggregation metric names to aggregation json objects

    :param metrics_dict: dictionary of all the metrics
    :param saved_metrics: list of saved metric names
    :param adhoc_metrics: list of adhoc metric configs (each a dict with
        'label' and 'column' keys); defaults to no adhoc metrics
    :raise SupersetException: if one or more metric names are not aggregations
    """
    # Use None instead of a mutable [] default: a shared list default would
    # persist across calls if anyone ever mutated it.
    if adhoc_metrics is None:
        adhoc_metrics = []
    aggregations = OrderedDict()
    invalid_metric_names = []
    for metric_name in saved_metrics:
        if metric_name in metrics_dict:
            metric = metrics_dict[metric_name]
            if metric.metric_type == POST_AGG_TYPE:
                # Post-aggregations cannot be sent to Druid as aggregations
                invalid_metric_names.append(metric_name)
            else:
                aggregations[metric_name] = metric.json_obj
        else:
            # Unknown metric names are reported together with post-aggs
            invalid_metric_names.append(metric_name)
    if invalid_metric_names:
        raise SupersetException(
            _('Metric(s) {} must be aggregations.').format(invalid_metric_names))
    for adhoc_metric in adhoc_metrics:
        aggregations[adhoc_metric['label']] = {
            'fieldName': adhoc_metric['column']['column_name'],
            'fieldNames': [adhoc_metric['column']['column_name']],
            'type': DruidDatasource.druid_type_from_adhoc_metric(adhoc_metric),
            'name': adhoc_metric['label'],
        }
    return aggregations
Expand Down Expand Up @@ -1087,11 +1108,10 @@ def run_query( # noqa / druid
metrics_dict = {m.metric_name: m for m in self.metrics}
columns_dict = {c.column_name: c for c in self.columns}

saved_metrics, adhoc_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
aggregations, post_aggs = DruidDatasource.metrics_and_post_aggs(
metrics,
metrics_dict)

aggregations = self.get_aggregations(saved_metrics, adhoc_metrics)
self.check_restricted_metrics(aggregations)

# the dimensions list with dimensionSpecs expanded
Expand Down Expand Up @@ -1143,7 +1163,15 @@ def run_query( # noqa / druid
pre_qry = deepcopy(qry)
if timeseries_limit_metric:
order_by = timeseries_limit_metric
pre_qry['aggregations'] = self.get_aggregations([timeseries_limit_metric])
aggs_dict, post_aggs_dict = self.metrics_and_post_aggs(
[timeseries_limit_metric],
metrics_dict)
if phase == 1:
pre_qry['aggregations'].update(aggs_dict)
pre_qry['post_aggregations'].update(post_aggs_dict)
else:
pre_qry['aggregations'] = aggs_dict
pre_qry['post_aggregations'] = post_aggs_dict
else:
order_by = list(qry['aggregations'].keys())[0]
# Limit on the number of timeseries, doing a two-phases query
Expand Down Expand Up @@ -1193,6 +1221,15 @@ def run_query( # noqa / druid

if timeseries_limit_metric:
order_by = timeseries_limit_metric
aggs_dict, post_aggs_dict = self.metrics_and_post_aggs(
[timeseries_limit_metric],
metrics_dict)
if phase == 1:
pre_qry['aggregations'].update(aggs_dict)
pre_qry['post_aggregations'].update(post_aggs_dict)
else:
pre_qry['aggregations'] = aggs_dict
pre_qry['post_aggregations'] = post_aggs_dict

# Limit on the number of timeseries, doing a two-phases query
pre_qry['granularity'] = 'all'
Expand Down
Loading

0 comments on commit ce3c9cc

Please sign in to comment.