Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[BugFix] Allowing limit ordering by post-aggregation metrics #4646

Merged
merged 9 commits into the base branch (branch names not captured in this page extract)
Apr 3, 2018
67 changes: 52 additions & 15 deletions superset/connectors/druid/models.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -35,7 +35,7 @@

from superset import conf, db, import_util, security_manager, utils
from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric
from superset.exceptions import MetricPermException
from superset.exceptions import MetricPermException, SupersetException
from superset.models.helpers import (
AuditMixinNullable, ImportMixin, QueryResult, set_perm,
)
Expand All @@ -44,6 +44,7 @@
)

DRUID_TZ = conf.get('DRUID_TZ')
POST_AGG_TYPE = 'postagg'


# Function wrapper because bound methods cannot
Expand Down Expand Up @@ -843,7 +844,7 @@ def find_postaggs_for(postagg_names, metrics_dict):
"""Return a list of metrics that are post aggregations"""
postagg_metrics = [
metrics_dict[name] for name in postagg_names
if metrics_dict[name].metric_type == 'postagg'
if metrics_dict[name].metric_type == POST_AGG_TYPE
]
# Remove post aggregations that were found
for postagg in postagg_metrics:
Expand Down Expand Up @@ -893,8 +894,8 @@ def resolve_postagg(postagg, post_aggs, agg_names, visited_postaggs, metrics_dic
missing_postagg, post_aggs, agg_names, visited_postaggs, metrics_dict)
post_aggs[postagg.metric_name] = DruidDatasource.get_post_agg(postagg.json_obj)

@staticmethod
def metrics_and_post_aggs(metrics, metrics_dict):
@classmethod
def metrics_and_post_aggs(cls, metrics, metrics_dict):
# Separate metrics into those that are aggregations
# and those that are post aggregations
saved_agg_names = set()
Expand All @@ -903,7 +904,7 @@ def metrics_and_post_aggs(metrics, metrics_dict):
for metric in metrics:
if utils.is_adhoc_metric(metric):
adhoc_agg_configs.append(metric)
elif metrics_dict[metric].metric_type != 'postagg':
elif metrics_dict[metric].metric_type != POST_AGG_TYPE:
saved_agg_names.add(metric)
else:
postagg_names.append(metric)
Expand All @@ -914,9 +915,10 @@ def metrics_and_post_aggs(metrics, metrics_dict):
for postagg_name in postagg_names:
postagg = metrics_dict[postagg_name]
visited_postaggs.add(postagg_name)
DruidDatasource.resolve_postagg(
cls.resolve_postagg(
postagg, post_aggs, saved_agg_names, visited_postaggs, metrics_dict)
return list(saved_agg_names), adhoc_agg_configs, post_aggs
aggs = cls.get_aggregations(metrics_dict, saved_agg_names, adhoc_agg_configs)
return aggs, post_aggs

def values_for_column(self,
column_name,
Expand Down Expand Up @@ -982,16 +984,35 @@ def druid_type_from_adhoc_metric(adhoc_metric):
else:
return column_type + aggregate.capitalize()

def get_aggregations(self, saved_metrics, adhoc_metrics=[]):
@staticmethod
def get_aggregations(metrics_dict, saved_metrics, adhoc_metrics=[]):
"""
Returns a dictionary of aggregation metric names to aggregation json objects

:param metrics_dict: dictionary of all the metrics
:param saved_metrics: list of saved metric names
:param adhoc_metrics: list of adhoc metric names
:raise SupersetException: if one or more metric names are not aggregations
"""
aggregations = OrderedDict()
for m in self.metrics:
if m.metric_name in saved_metrics:
aggregations[m.metric_name] = m.json_obj
invalid_metric_names = []
for metric_name in saved_metrics:
if metric_name in metrics_dict:
metric = metrics_dict[metric_name]
if metric.metric_type == POST_AGG_TYPE:
invalid_metric_names.append(metric_name)
else:
aggregations[metric_name] = metric.json_obj
else:
invalid_metric_names.append(metric_name)
if len(invalid_metric_names) > 0:
raise SupersetException(
_('Metric(s) {} must be aggregations.').format(invalid_metric_names))
for adhoc_metric in adhoc_metrics:
aggregations[adhoc_metric['label']] = {
'fieldName': adhoc_metric['column']['column_name'],
'fieldNames': [adhoc_metric['column']['column_name']],
'type': self.druid_type_from_adhoc_metric(adhoc_metric),
'type': DruidDatasource.druid_type_from_adhoc_metric(adhoc_metric),
'name': adhoc_metric['label'],
}
return aggregations
Expand Down Expand Up @@ -1087,11 +1108,10 @@ def run_query( # noqa / druid
metrics_dict = {m.metric_name: m for m in self.metrics}
columns_dict = {c.column_name: c for c in self.columns}

saved_metrics, adhoc_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
aggregations, post_aggs = DruidDatasource.metrics_and_post_aggs(
metrics,
metrics_dict)

aggregations = self.get_aggregations(saved_metrics, adhoc_metrics)
self.check_restricted_metrics(aggregations)

# the dimensions list with dimensionSpecs expanded
Expand Down Expand Up @@ -1143,7 +1163,15 @@ def run_query( # noqa / druid
pre_qry = deepcopy(qry)
if timeseries_limit_metric:
order_by = timeseries_limit_metric
pre_qry['aggregations'] = self.get_aggregations([timeseries_limit_metric])
aggs_dict, post_aggs_dict = self.metrics_and_post_aggs(
[timeseries_limit_metric],
metrics_dict)
if phase == 1:
pre_qry['aggregations'].update(aggs_dict)
pre_qry['post_aggregations'].update(post_aggs_dict)
else:
pre_qry['aggregations'] = aggs_dict
pre_qry['post_aggregations'] = post_aggs_dict
else:
order_by = list(qry['aggregations'].keys())[0]
# Limit on the number of timeseries, doing a two-phases query
Expand Down Expand Up @@ -1193,6 +1221,15 @@ def run_query( # noqa / druid

if timeseries_limit_metric:
order_by = timeseries_limit_metric
aggs_dict, post_aggs_dict = self.metrics_and_post_aggs(
[timeseries_limit_metric],
metrics_dict)
if phase == 1:
pre_qry['aggregations'].update(aggs_dict)
pre_qry['post_aggregations'].update(post_aggs_dict)
else:
pre_qry['aggregations'] = aggs_dict
pre_qry['post_aggregations'] = post_aggs_dict

# Limit on the number of timeseries, doing a two-phases query
pre_qry['granularity'] = 'all'
Expand Down
Loading