From 5b22346432026b44bb30b2f4703cffba100cf780 Mon Sep 17 00:00:00 2001 From: Jeffrey Wang Date: Mon, 19 Mar 2018 16:33:46 -0400 Subject: [PATCH 1/9] Allowing limit ordering by post-aggregation metrics --- superset/connectors/druid/models.py | 45 +++++++++++++++++------------ 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index d514a2f4eecf9..1a34e17643826 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -894,7 +894,7 @@ def resolve_postagg(postagg, post_aggs, agg_names, visited_postaggs, metrics_dic post_aggs[postagg.metric_name] = DruidDatasource.get_post_agg(postagg.json_obj) @staticmethod - def metrics_and_post_aggs(metrics, metrics_dict): + def aggs_and_post_aggs(metrics, metrics_dict): # Separate metrics into those that are aggregations # and those that are post aggregations saved_agg_names = set() @@ -916,7 +916,25 @@ def metrics_and_post_aggs(metrics, metrics_dict): visited_postaggs.add(postagg_name) DruidDatasource.resolve_postagg( postagg, post_aggs, saved_agg_names, visited_postaggs, metrics_dict) - return list(saved_agg_names), adhoc_agg_configs, post_aggs + aggs = DruidDatasource.get_aggregations(metrics_dict, saved_agg_names) + return aggs, adhoc_agg_configs, post_aggs + + @staticmethod + def get_aggregations(metrics_dict, saved_metrics, adhoc_metrics=[]): + aggregations = OrderedDict() + for metric_name in saved_metrics: + if metric_name in metrics_dict: + metric = metrics_dict[metric_name] + if metric.metric_type != 'postagg': + aggregations[metric_name] = metric.json_obj + for adhoc_metric in adhoc_metrics: + aggregations[adhoc_metric['label']] = { + 'fieldName': adhoc_metric['column']['column_name'], + 'fieldNames': [adhoc_metric['column']['column_name']], + 'type': DruidDatasource.druid_type_from_adhoc_metric(adhoc_metric), + 'name': adhoc_metric['label'], + } + return aggregations def values_for_column(self, column_name, @@ -982,20 +1000,6 @@ def druid_type_from_adhoc_metric(adhoc_metric): else: return column_type + aggregate.capitalize() - def get_aggregations(self, saved_metrics, adhoc_metrics=[]): - aggregations = OrderedDict() - for m in self.metrics: - if m.metric_name in saved_metrics: - aggregations[m.metric_name] = m.json_obj - for adhoc_metric in adhoc_metrics: - aggregations[adhoc_metric['label']] = { - 'fieldName': adhoc_metric['column']['column_name'], - 'fieldNames': [adhoc_metric['column']['column_name']], - 'type': self.druid_type_from_adhoc_metric(adhoc_metric), - 'name': adhoc_metric['label'], - } - return aggregations - def check_restricted_metrics(self, aggregations): rejected_metrics = [ m.metric_name for m in self.metrics @@ -1087,11 +1091,10 @@ def run_query( # noqa / druid metrics_dict = {m.metric_name: m for m in self.metrics} columns_dict = {c.column_name: c for c in self.columns} - saved_metrics, adhoc_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs( + aggregations, adhoc_metrics, post_aggs = DruidDatasource.aggs_and_post_aggs( metrics, metrics_dict) - aggregations = self.get_aggregations(saved_metrics, adhoc_metrics) self.check_restricted_metrics(aggregations) # the dimensions list with dimensionSpecs expanded @@ -1143,7 +1146,11 @@ def run_query( # noqa / druid pre_qry = deepcopy(qry) if timeseries_limit_metric: order_by = timeseries_limit_metric - pre_qry['aggregations'] = self.get_aggregations([timeseries_limit_metric]) + aggs_dict, post_aggs_dict = DruidDatasource.aggs_and_post_aggs( + [timeseries_limit_metric], + metrics_dict) + pre_qry['aggregations'] = aggs_dict + pre_qry['post_aggregations'] = post_aggs_dict else: order_by = list(qry['aggregations'].keys())[0] # Limit on the number of timeseries, doing a two-phases query From bcd04f297b3de9dc25ad5734dc67a4108b4727bd Mon Sep 17 00:00:00 2001 From: Jeffrey Wang Date: Mon, 19 Mar 2018 19:19:41 -0400 Subject: [PATCH 2/9] update tests --- superset/connectors/druid/models.py | 6 +++--- tests/druid_func_tests.py | 18 +++++++++--------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index 1a34e17643826..a2192b34f75d9 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -894,7 +894,7 @@ def resolve_postagg(postagg, post_aggs, agg_names, visited_postaggs, metrics_dic post_aggs[postagg.metric_name] = DruidDatasource.get_post_agg(postagg.json_obj) @staticmethod - def aggs_and_post_aggs(metrics, metrics_dict): + def metrics_and_post_aggs(metrics, metrics_dict): # Separate metrics into those that are aggregations # and those that are post aggregations saved_agg_names = set() @@ -1091,7 +1091,7 @@ def run_query( # noqa / druid metrics_dict = {m.metric_name: m for m in self.metrics} columns_dict = {c.column_name: c for c in self.columns} - aggregations, adhoc_metrics, post_aggs = DruidDatasource.aggs_and_post_aggs( + aggregations, adhoc_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs( metrics, metrics_dict) @@ -1146,7 +1146,7 @@ def run_query( # noqa / druid pre_qry = deepcopy(qry) if timeseries_limit_metric: order_by = timeseries_limit_metric - aggs_dict, post_aggs_dict = DruidDatasource.aggs_and_post_aggs( + aggs_dict, adhoc_dict, post_aggs_dict = DruidDatasource.metrics_and_post_aggs( [timeseries_limit_metric], metrics_dict) pre_qry['aggregations'] = aggs_dict diff --git a/tests/druid_func_tests.py b/tests/druid_func_tests.py index 22c1f38dc9159..078136be5bc56 100644 --- a/tests/druid_func_tests.py +++ b/tests/druid_func_tests.py @@ -157,9 +157,9 @@ def test_run_query_no_groupby(self): col1 = DruidColumn(column_name='col1') col2 = DruidColumn(column_name='col2') ds.columns = [col1, col2] - all_metrics = [] + aggs = [] post_aggs = ['some_agg'] - ds._metrics_and_post_aggs = Mock(return_value=(all_metrics, post_aggs)) + ds._metrics_and_post_aggs = Mock(return_value=(aggs, post_aggs)) groupby = [] metrics = ['metric1'] ds.get_having_filters = Mock(return_value=[]) @@ -242,9 +242,9 @@ def test_run_query_single_groupby(self): col1 = DruidColumn(column_name='col1') col2 = DruidColumn(column_name='col2') ds.columns = [col1, col2] - all_metrics = ['metric1'] + aggs = ['metric1'] post_aggs = ['some_agg'] - ds._metrics_and_post_aggs = Mock(return_value=(all_metrics, post_aggs)) + ds._metrics_and_post_aggs = Mock(return_value=(aggs, post_aggs)) groupby = ['col1'] metrics = ['metric1'] ds.get_having_filters = Mock(return_value=[]) @@ -316,9 +316,9 @@ def test_run_query_multiple_groupby(self): col1 = DruidColumn(column_name='col1') col2 = DruidColumn(column_name='col2') ds.columns = [col1, col2] - all_metrics = [] + aggs = [] post_aggs = ['some_agg'] - ds._metrics_and_post_aggs = Mock(return_value=(all_metrics, post_aggs)) + ds._metrics_and_post_aggs = Mock(return_value=(aggs, post_aggs)) groupby = ['col1', 'col2'] metrics = ['metric1'] ds.get_having_filters = Mock(return_value=[]) @@ -512,10 +512,10 @@ def depends_on(index, fields): depends_on('I', ['H', 'K']) depends_on('J', 'K') depends_on('K', ['m8', 'm9']) - all_metrics, saved_metrics, postaggs = DruidDatasource.metrics_and_post_aggs( + aggs, saved_metrics, postaggs = DruidDatasource.metrics_and_post_aggs( metrics, metrics_dict) - expected_metrics = set(all_metrics) - self.assertEqual(9, len(all_metrics)) + expected_metrics = set(aggs.keys()) + self.assertEqual(9, len(aggs)) for i in range(1, 10): expected_metrics.remove('m' + str(i)) self.assertEqual(0, len(expected_metrics)) From a712eb507ce61997aa82e6d2b2c29c5a01af6c70 Mon Sep 17 00:00:00 2001 From: Jeffrey Wang Date: Mon, 19 Mar 2018 19:06:15 -0400 Subject: [PATCH 3/9] don't overwrite og dictionaries --- superset/connectors/druid/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index a2192b34f75d9..6a3ee340c30da 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -1149,8 +1149,8 @@ def run_query( # noqa / druid aggs_dict, adhoc_dict, post_aggs_dict = DruidDatasource.metrics_and_post_aggs( [timeseries_limit_metric], metrics_dict) - pre_qry['aggregations'] = aggs_dict - pre_qry['post_aggregations'] = post_aggs_dict + pre_qry['aggregations'].update(aggs_dict) + pre_qry['post_aggregations'].update(post_aggs_dict) else: order_by = list(qry['aggregations'].keys())[0] # Limit on the number of timeseries, doing a two-phases query From 88c8a1e83e2cf362b3f105730349535affb81ad8 Mon Sep 17 00:00:00 2001 From: Jeffrey Wang Date: Mon, 19 Mar 2018 20:00:55 -0400 Subject: [PATCH 4/9] python3 compat --- tests/druid_func_tests.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/druid_func_tests.py b/tests/druid_func_tests.py index 078136be5bc56..44889195826df 100644 --- a/tests/druid_func_tests.py +++ b/tests/druid_func_tests.py @@ -596,7 +596,7 @@ def test_metrics_and_post_aggs(self): saved_metrics, adhoc_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs( metrics, metrics_dict) - assert saved_metrics == ['some_sum'] + assert set(saved_metrics.keys()) == {'some_sum'} assert adhoc_metrics == [] assert post_aggs == {} @@ -604,7 +604,7 @@ def test_metrics_and_post_aggs(self): saved_metrics, adhoc_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs( metrics, metrics_dict) - assert saved_metrics == [] + assert set(saved_metrics.keys()) == set([]) assert adhoc_metrics == [adhoc_metric] assert post_aggs == {} @@ -612,7 +612,7 @@ def test_metrics_and_post_aggs(self): saved_metrics, adhoc_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs( metrics, metrics_dict) - assert saved_metrics == ['some_sum'] + assert set(saved_metrics.keys()) == {'some_sum'} assert adhoc_metrics == [adhoc_metric] assert post_aggs == {} @@ -621,7 +621,7 @@ def test_metrics_and_post_aggs(self): metrics, metrics_dict) result_postaggs = set(['quantile_p95']) - assert saved_metrics == ['a_histogram'] + assert set(saved_metrics.keys()) == {'a_histogram'} assert adhoc_metrics == [] assert set(post_aggs.keys()) == result_postaggs @@ -630,7 +630,7 @@ def test_metrics_and_post_aggs(self): metrics, metrics_dict) result_postaggs = set(['aCustomPostAgg']) - assert saved_metrics == ['aCustomMetric'] + assert set(saved_metrics.keys()) == {'aCustomMetric'} assert adhoc_metrics == [] assert set(post_aggs.keys()) == result_postaggs From 25b9f7560c4f201ecbe04502c4c20c7cd07bf48d Mon Sep 17 00:00:00 2001 From: Jeffrey Wang Date: Tue, 20 Mar 2018 13:51:07 -0400 Subject: [PATCH 5/9] code review comments, add tests, implement it in groupby as well --- superset/connectors/druid/models.py | 25 +++++-- tests/druid_func_tests.py | 108 ++++++++++++++++++++++++++++ 2 files changed, 127 insertions(+), 6 deletions(-) diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index 6a3ee340c30da..c0b04f2130a8a 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -893,8 +893,8 @@ def resolve_postagg(postagg, post_aggs, agg_names, visited_postaggs, metrics_dic missing_postagg, post_aggs, agg_names, visited_postaggs, metrics_dict) post_aggs[postagg.metric_name] = DruidDatasource.get_post_agg(postagg.json_obj) - @staticmethod - def metrics_and_post_aggs(metrics, metrics_dict): + @classmethod + def metrics_and_post_aggs(cls, metrics, metrics_dict): # Separate metrics into those that are aggregations # and those that are post aggregations saved_agg_names = set() @@ -914,9 +914,9 @@ def metrics_and_post_aggs(metrics, metrics_dict): for postagg_name in postagg_names: postagg = metrics_dict[postagg_name] visited_postaggs.add(postagg_name) - DruidDatasource.resolve_postagg( + cls.resolve_postagg( postagg, post_aggs, saved_agg_names, visited_postaggs, metrics_dict) - aggs = DruidDatasource.get_aggregations(metrics_dict, saved_agg_names) + aggs = cls.get_aggregations(metrics_dict, saved_agg_names) return aggs, adhoc_agg_configs, post_aggs @staticmethod @@ -1149,8 +1149,12 @@ def run_query( # noqa / druid aggs_dict, adhoc_dict, post_aggs_dict = DruidDatasource.metrics_and_post_aggs( [timeseries_limit_metric], metrics_dict) - pre_qry['aggregations'].update(aggs_dict) - pre_qry['post_aggregations'].update(post_aggs_dict) + if phase == 1: + pre_qry['aggregations'].update(aggs_dict) + pre_qry['post_aggregations'].update(post_aggs_dict) + else: + pre_qry['aggregations'] = aggs_dict + pre_qry['post_aggregations'] = post_aggs_dict else: order_by = list(qry['aggregations'].keys())[0] # Limit on the number of timeseries, doing a two-phases query @@ -1200,6 +1204,15 @@ def run_query( # noqa / druid if timeseries_limit_metric: order_by = timeseries_limit_metric + aggs_dict, adhoc_dict, post_aggs_dict = DruidDatasource.metrics_and_post_aggs( + [timeseries_limit_metric], + metrics_dict) + if phase == 1: + pre_qry['aggregations'].update(aggs_dict) + pre_qry['post_aggregations'].update(post_aggs_dict) + else: + pre_qry['aggregations'] = aggs_dict + pre_qry['post_aggregations'] = post_aggs_dict # Limit on the number of timeseries, doing a two-phases query pre_qry['granularity'] = 'all' diff --git a/tests/druid_func_tests.py b/tests/druid_func_tests.py index 44889195826df..c5d1cb7272d91 100644 --- a/tests/druid_func_tests.py +++ b/tests/druid_func_tests.py @@ -663,3 +663,111 @@ def test_druid_type_from_adhoc_metric(self): 'label': 'My Adhoc Metric', }) assert(druid_type == 'cardinality') + + def test_run_query_order_by_metrics(self): + client = Mock() + client.query_builder.last_query.query_dict = {'mock': 0} + from_dttm = Mock() + to_dttm = Mock() + ds = DruidDatasource(datasource_name='datasource') + ds.get_having_filters = Mock(return_value=[]) + dim1 = DruidColumn(column_name='dim1') + dim2 = DruidColumn(column_name='dim2') + metrics_dict = { + 'count1': DruidMetric( + metric_name='count1', + metric_type='count', + json=json.dumps({'type': 'count', 'name': 'count1'}), + ), + 'sum1': DruidMetric( + metric_name='sum1', + metric_type='doubleSum', + json=json.dumps({'type': 'doubleSum', 'name': 'sum1'}), + ), + 'sum2': DruidMetric( + metric_name='sum2', + metric_type='doubleSum', + json=json.dumps({'type': 'doubleSum', 'name': 'sum2'}), + ), + 'div1': DruidMetric( + metric_name='div1', + metric_type='postagg', + json=json.dumps({ + 'fn': '/', + 'type': 'arithmetic', + 'name': 'div1', + 'fields': [ + { + 'fieldName': 'sum1', + 'type': 'fieldAccess', + }, + { + 'fieldName': 'sum2', + 'type': 'fieldAccess', + }, + ], + }) + ) + } + ds.columns = [dim1, dim2] + ds.metrics = list(metrics_dict.values()) + + groupby = ['dim1'] + metrics = ['count1'] + granularity = 'all' + # get the counts of the top 5 'dim1's, order by 'sum1' + ds.run_query( + groupby, metrics, granularity, from_dttm, to_dttm, + timeseries_limit=5, timeseries_limit_metric='sum1', + client=client, order_desc=True, filter=[], + ) + qry_obj = client.topn.call_args_list[0][1] + self.assertEqual('dim1', qry_obj['dimension']) + self.assertEqual('sum1', qry_obj['metric']) + aggregations = qry_obj['aggregations'] + post_aggregations = qry_obj['post_aggregations'] + self.assertItemsEqual(['count1', 'sum1'], list(aggregations.keys())) + self.assertItemsEqual([], list(post_aggregations.keys())) + + # get the counts of the top 5 'dim1's, order by 'div1' + ds.run_query( + groupby, metrics, granularity, from_dttm, to_dttm, + timeseries_limit=5, timeseries_limit_metric='div1', + client=client, order_desc=True, filter=[], + ) + qry_obj = client.topn.call_args_list[1][1] + self.assertEqual('dim1', qry_obj['dimension']) + self.assertEqual('div1', qry_obj['metric']) + aggregations = qry_obj['aggregations'] + post_aggregations = qry_obj['post_aggregations'] + self.assertItemsEqual(['count1', 'sum1', 'sum2'], list(aggregations.keys())) + self.assertItemsEqual(['div1'], list(post_aggregations.keys())) + + groupby = ['dim1', 'dim2'] + # get the counts of the top 5 ['dim1', 'dim2']s, order by 'sum1' + ds.run_query( + groupby, metrics, granularity, from_dttm, to_dttm, + timeseries_limit=5, timeseries_limit_metric='sum1', + client=client, order_desc=True, filter=[], + ) + qry_obj = client.groupby.call_args_list[0][1] + self.assertItemsEqual(['dim1', 'dim2'], qry_obj['dimensions']) + self.assertEqual('sum1', qry_obj['limit_spec']['columns'][0]['dimension']) + aggregations = qry_obj['aggregations'] + post_aggregations = qry_obj['post_aggregations'] + self.assertItemsEqual(['count1', 'sum1'], list(aggregations.keys())) + self.assertItemsEqual([], list(post_aggregations.keys())) + + # get the counts of the top 5 ['dim1', 'dim2']s, order by 'div1' + ds.run_query( + groupby, metrics, granularity, from_dttm, to_dttm, + timeseries_limit=5, timeseries_limit_metric='div1', + client=client, order_desc=True, filter=[], + ) + qry_obj = client.groupby.call_args_list[1][1] + self.assertItemsEqual(['dim1', 'dim2'], qry_obj['dimensions']) + self.assertEqual('div1', qry_obj['limit_spec']['columns'][0]['dimension']) + aggregations = qry_obj['aggregations'] + post_aggregations = qry_obj['post_aggregations'] + self.assertItemsEqual(['count1', 'sum1', 'sum2'], list(aggregations.keys())) + self.assertItemsEqual(['div1'], list(post_aggregations.keys())) From 89db89ea30de04f5a3cdceaf3b315a06974f7f8d Mon Sep 17 00:00:00 2001 From: Jeffrey Wang Date: Tue, 20 Mar 2018 14:47:38 -0400 Subject: [PATCH 6/9] python 3 compat for unittest --- tests/druid_func_tests.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/druid_func_tests.py b/tests/druid_func_tests.py index c5d1cb7272d91..c69c627d45f0e 100644 --- a/tests/druid_func_tests.py +++ b/tests/druid_func_tests.py @@ -706,8 +706,8 @@ def test_run_query_order_by_metrics(self): 'type': 'fieldAccess', }, ], - }) - ) + }), + ), } ds.columns = [dim1, dim2] ds.metrics = list(metrics_dict.values()) @@ -726,8 +726,8 @@ def test_run_query_order_by_metrics(self): self.assertEqual('sum1', qry_obj['metric']) aggregations = qry_obj['aggregations'] post_aggregations = qry_obj['post_aggregations'] - self.assertItemsEqual(['count1', 'sum1'], list(aggregations.keys())) - self.assertItemsEqual([], list(post_aggregations.keys())) + self.assertEqual(set(['count1', 'sum1']), set(aggregations.keys())) + self.assertEqual(set([]), set(post_aggregations.keys())) # get the counts of the top 5 'dim1's, order by 'div1' ds.run_query( @@ -740,8 +740,8 @@ def test_run_query_order_by_metrics(self): self.assertEqual('div1', qry_obj['metric']) aggregations = qry_obj['aggregations'] post_aggregations = qry_obj['post_aggregations'] - self.assertItemsEqual(['count1', 'sum1', 'sum2'], list(aggregations.keys())) - self.assertItemsEqual(['div1'], list(post_aggregations.keys())) + self.assertEqual(set(['count1', 'sum1', 'sum2']), set(aggregations.keys())) + self.assertEqual(set(['div1']), set(post_aggregations.keys())) groupby = ['dim1', 'dim2'] # get the counts of the top 5 ['dim1', 'dim2']s, order by 'sum1' @@ -751,12 +751,12 @@ def test_run_query_order_by_metrics(self): client=client, order_desc=True, filter=[], ) qry_obj = client.groupby.call_args_list[0][1] - self.assertItemsEqual(['dim1', 'dim2'], qry_obj['dimensions']) + self.assertEqual(set(['dim1', 'dim2']), set(qry_obj['dimensions'])) self.assertEqual('sum1', qry_obj['limit_spec']['columns'][0]['dimension']) aggregations = qry_obj['aggregations'] post_aggregations = qry_obj['post_aggregations'] - self.assertItemsEqual(['count1', 'sum1'], list(aggregations.keys())) - self.assertItemsEqual([], list(post_aggregations.keys())) + self.assertEqual(set(['count1', 'sum1']), set(aggregations.keys())) + self.assertEqual(set([]), set(post_aggregations.keys())) # get the counts of the top 5 ['dim1', 'dim2']s, order by 'div1' ds.run_query( @@ -765,9 +765,9 @@ def test_run_query_order_by_metrics(self): client=client, order_desc=True, filter=[], ) qry_obj = client.groupby.call_args_list[1][1] - self.assertItemsEqual(['dim1', 'dim2'], qry_obj['dimensions']) + self.assertEqual(set(['dim1', 'dim2']), set(qry_obj['dimensions'])) self.assertEqual('div1', qry_obj['limit_spec']['columns'][0]['dimension']) aggregations = qry_obj['aggregations'] post_aggregations = qry_obj['post_aggregations'] - self.assertItemsEqual(['count1', 'sum1', 'sum2'], list(aggregations.keys())) - self.assertItemsEqual(['div1'], list(post_aggregations.keys())) + self.assertEqual(set(['count1', 'sum1', 'sum2']), set(aggregations.keys())) + self.assertEqual(set(['div1']), set(post_aggregations.keys())) From 98300e5a5c04331b6eb04d8505cc982bea37b428 Mon Sep 17 00:00:00 2001 From: Jeffrey Wang Date: Tue, 20 Mar 2018 15:10:38 -0400 Subject: [PATCH 7/9] more self --- superset/connectors/druid/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index c0b04f2130a8a..25b4121d73c9f 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -1146,7 +1146,7 @@ def run_query( # noqa / druid pre_qry = deepcopy(qry) if timeseries_limit_metric: order_by = timeseries_limit_metric - aggs_dict, adhoc_dict, post_aggs_dict = DruidDatasource.metrics_and_post_aggs( + aggs_dict, adhoc_dict, post_aggs_dict = self.metrics_and_post_aggs( [timeseries_limit_metric], metrics_dict) if phase == 1: @@ -1204,7 +1204,7 @@ def run_query( # noqa / druid if timeseries_limit_metric: order_by = timeseries_limit_metric - aggs_dict, adhoc_dict, post_aggs_dict = DruidDatasource.metrics_and_post_aggs( + aggs_dict, adhoc_dict, post_aggs_dict = self.metrics_and_post_aggs( [timeseries_limit_metric], metrics_dict) if phase == 1: From e85ebfd0d610a5b1ec422244282e1d8f3ec2ad93 Mon Sep 17 00:00:00 2001 From: Jeffrey Wang Date: Fri, 23 Mar 2018 13:06:04 -0400 Subject: [PATCH 8/9] Throw exception when get aggregations is called with postaggs --- superset/connectors/druid/models.py | 25 +++++++++-- tests/druid_func_tests.py | 67 ++++++++++++++++++++++++----- 2 files changed, 78 insertions(+), 14 deletions(-) diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index 25b4121d73c9f..c8d768ac97b4f 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -35,7 +35,7 @@ from superset import conf, db, import_util, security_manager, utils from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric -from superset.exceptions import MetricPermException +from superset.exceptions import MetricPermException, SupersetException from superset.models.helpers import ( AuditMixinNullable, ImportMixin, QueryResult, set_perm, ) @@ -44,6 +44,7 @@ ) DRUID_TZ = conf.get('DRUID_TZ') +POST_AGG_TYPE = 'postagg' # Function wrapper because bound methods cannot @@ -843,7 +844,7 @@ def find_postaggs_for(postagg_names, metrics_dict): """Return a list of metrics that are post aggregations""" postagg_metrics = [ metrics_dict[name] for name in postagg_names - if metrics_dict[name].metric_type == 'postagg' + if metrics_dict[name].metric_type == POST_AGG_TYPE ] # Remove post aggregations that were found for postagg in postagg_metrics: @@ -903,7 +904,7 @@ def metrics_and_post_aggs(cls, metrics, metrics_dict): for metric in metrics: if utils.is_adhoc_metric(metric): adhoc_agg_configs.append(metric) - elif metrics_dict[metric].metric_type != 'postagg': + elif metrics_dict[metric].metric_type != POST_AGG_TYPE: saved_agg_names.add(metric) else: postagg_names.append(metric) @@ -921,12 +922,28 @@ def metrics_and_post_aggs(cls, metrics, metrics_dict): @staticmethod def get_aggregations(metrics_dict, saved_metrics, adhoc_metrics=[]): + """ + Returns a dictionary of aggregation metric names to aggregation json objects + + :param metrics_dict: dictionary of all the metrics + :param saved_metrics: list of saved metric names + :param adhoc_metrics: list of adhoc metric names + :raise SupersetException: if one or more metric names are not aggregations + """ aggregations = OrderedDict() + invalid_metric_names = [] for metric_name in saved_metrics: if metric_name in metrics_dict: metric = metrics_dict[metric_name] - if metric.metric_type != 'postagg': + if metric.metric_type == POST_AGG_TYPE: + invalid_metric_names.append(metric_name) + else: aggregations[metric_name] = metric.json_obj + else: + invalid_metric_names.append(metric_name) + if len(invalid_metric_names) > 0: + raise SupersetException( + _('Metric(s) {} must be aggregations.').format(invalid_metric_names)) for adhoc_metric in adhoc_metrics: aggregations[adhoc_metric['label']] = { 'fieldName': adhoc_metric['column']['column_name'], diff --git a/tests/druid_func_tests.py b/tests/druid_func_tests.py index c69c627d45f0e..7491561f5788d 100644 --- a/tests/druid_func_tests.py +++ b/tests/druid_func_tests.py @@ -14,6 +14,7 @@ from superset.connectors.druid.models import ( DruidColumn, DruidDatasource, DruidMetric, ) +from superset.exceptions import SupersetException def mock_metric(metric_name, is_postagg=False): @@ -726,8 +727,8 @@ def test_run_query_order_by_metrics(self): self.assertEqual('sum1', qry_obj['metric']) aggregations = qry_obj['aggregations'] post_aggregations = qry_obj['post_aggregations'] - self.assertEqual(set(['count1', 'sum1']), set(aggregations.keys())) - self.assertEqual(set([]), set(post_aggregations.keys())) + self.assertEqual({'count1', 'sum1'}, set(aggregations.keys())) + self.assertEqual(set(), set(post_aggregations.keys())) # get the counts of the top 5 'dim1's, order by 'div1' ds.run_query( @@ -740,8 +741,8 @@ def test_run_query_order_by_metrics(self): self.assertEqual('div1', qry_obj['metric']) aggregations = qry_obj['aggregations'] post_aggregations = qry_obj['post_aggregations'] - self.assertEqual(set(['count1', 'sum1', 'sum2']), set(aggregations.keys())) - self.assertEqual(set(['div1']), set(post_aggregations.keys())) + self.assertEqual({'count1', 'sum1', 'sum2'}, set(aggregations.keys())) + self.assertEqual({'div1'}, set(post_aggregations.keys())) groupby = ['dim1', 'dim2'] # get the counts of the top 5 ['dim1', 'dim2']s, order by 'sum1' @@ -751,12 +752,12 @@ def test_run_query_order_by_metrics(self): client=client, order_desc=True, filter=[], ) qry_obj = client.groupby.call_args_list[0][1] - self.assertEqual(set(['dim1', 'dim2']), set(qry_obj['dimensions'])) + self.assertEqual({'dim1', 'dim2'}, set(qry_obj['dimensions'])) self.assertEqual('sum1', qry_obj['limit_spec']['columns'][0]['dimension']) aggregations = qry_obj['aggregations'] post_aggregations = qry_obj['post_aggregations'] - self.assertEqual(set(['count1', 'sum1']), set(aggregations.keys())) - self.assertEqual(set([]), set(post_aggregations.keys())) + self.assertEqual({'count1', 'sum1'}, set(aggregations.keys())) + self.assertEqual(set(), set(post_aggregations.keys())) # get the counts of the top 5 ['dim1', 'dim2']s, order by 'div1' ds.run_query( @@ -765,9 +766,55 @@ def test_run_query_order_by_metrics(self): client=client, order_desc=True, filter=[], ) qry_obj = client.groupby.call_args_list[1][1] - self.assertEqual(set(['dim1', 'dim2']), set(qry_obj['dimensions'])) + self.assertEqual({'dim1', 'dim2'}, set(qry_obj['dimensions'])) self.assertEqual('div1', qry_obj['limit_spec']['columns'][0]['dimension']) aggregations = qry_obj['aggregations'] post_aggregations = qry_obj['post_aggregations'] - self.assertEqual(set(['count1', 'sum1', 'sum2']), set(aggregations.keys())) - self.assertEqual(set(['div1']), set(post_aggregations.keys())) + self.assertEqual({'count1', 'sum1', 'sum2'}, set(aggregations.keys())) + self.assertEqual({'div1'}, set(post_aggregations.keys())) + + def test_get_aggregations(self): + ds = DruidDatasource(datasource_name='datasource') + metrics_dict = { + 'sum1': DruidMetric( + metric_name='sum1', + metric_type='doubleSum', + json=json.dumps({'type': 'doubleSum', 'name': 'sum1'}), + ), + 'sum2': DruidMetric( + metric_name='sum2', + metric_type='doubleSum', + json=json.dumps({'type': 'doubleSum', 'name': 'sum2'}), + ), + 'div1': DruidMetric( + metric_name='div1', + metric_type='postagg', + json=json.dumps({ + 'fn': '/', + 'type': 'arithmetic', + 'name': 'div1', + 'fields': [ + { + 'fieldName': 'sum1', + 'type': 'fieldAccess', + }, + { + 'fieldName': 'sum2', + 'type': 'fieldAccess', + }, + ], + }), + ), + } + metric_names = ['sum1', 'sum2'] + aggs = ds.get_aggregations(metrics_dict, metric_names) + expected_agg = {name: metrics_dict[name].json_obj for name in metric_names} + self.assertEqual(expected_agg, aggs) + + metric_names = ['sum1', 'col1'] + self.assertRaises( + SupersetException, ds.get_aggregations, metrics_dict, metric_names) + + metric_names = ['sum1', 'div1'] + self.assertRaises( + SupersetException, ds.get_aggregations, metrics_dict, metric_names) From 10f4c01317148d56a029c502ab0021666187f70d Mon Sep 17 00:00:00 2001 From: Jeffrey Wang Date: Fri, 30 Mar 2018 18:42:38 -0400 Subject: [PATCH 9/9] Treat adhoc metrics as another aggregation --- superset/connectors/druid/models.py | 76 ++++++++++++++--------------- tests/druid_func_tests.py | 21 +++----- 2 files changed, 46 insertions(+), 51 deletions(-) diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index c8d768ac97b4f..bd684a5988da6 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -917,41 +917,8 @@ def metrics_and_post_aggs(cls, metrics, metrics_dict): visited_postaggs.add(postagg_name) cls.resolve_postagg( postagg, post_aggs, saved_agg_names, visited_postaggs, metrics_dict) - aggs = cls.get_aggregations(metrics_dict, saved_agg_names) - return aggs, adhoc_agg_configs, post_aggs - - @staticmethod - def get_aggregations(metrics_dict, saved_metrics, adhoc_metrics=[]): - """ - Returns a dictionary of aggregation metric names to aggregation json objects - - :param metrics_dict: dictionary of all the metrics - :param saved_metrics: list of saved metric names - :param adhoc_metrics: list of adhoc metric names - :raise SupersetException: if one or more metric names are not aggregations - """ - aggregations = OrderedDict() - invalid_metric_names = [] - for metric_name in saved_metrics: - if metric_name in metrics_dict: - metric = metrics_dict[metric_name] - if metric.metric_type == POST_AGG_TYPE: - invalid_metric_names.append(metric_name) - else: - aggregations[metric_name] = metric.json_obj - else: - invalid_metric_names.append(metric_name) - if len(invalid_metric_names) > 0: - raise SupersetException( - _('Metric(s) {} must be aggregations.').format(invalid_metric_names)) - for adhoc_metric in adhoc_metrics: - aggregations[adhoc_metric['label']] = { - 'fieldName': adhoc_metric['column']['column_name'], - 'fieldNames': [adhoc_metric['column']['column_name']], - 'type': DruidDatasource.druid_type_from_adhoc_metric(adhoc_metric), - 'name': adhoc_metric['label'], - } - return aggregations + aggs = cls.get_aggregations(metrics_dict, saved_agg_names, adhoc_agg_configs) + return aggs, post_aggs def values_for_column(self, column_name, @@ -1017,6 +984,39 @@ def druid_type_from_adhoc_metric(adhoc_metric): else: return column_type + aggregate.capitalize() + @staticmethod + def get_aggregations(metrics_dict, saved_metrics, adhoc_metrics=[]): + """ + Returns a dictionary of aggregation metric names to aggregation json objects + + :param metrics_dict: dictionary of all the metrics + :param saved_metrics: list of saved metric names + :param adhoc_metrics: list of adhoc metric names + :raise SupersetException: if one or more metric names are not aggregations + """ + aggregations = OrderedDict() + invalid_metric_names = [] + for metric_name in saved_metrics: + if metric_name in metrics_dict: + metric = metrics_dict[metric_name] + if metric.metric_type == POST_AGG_TYPE: + invalid_metric_names.append(metric_name) + else: + aggregations[metric_name] = metric.json_obj + else: + invalid_metric_names.append(metric_name) + if len(invalid_metric_names) > 0: + raise SupersetException( + _('Metric(s) {} must be aggregations.').format(invalid_metric_names)) + for adhoc_metric in adhoc_metrics: + aggregations[adhoc_metric['label']] = { + 'fieldName': adhoc_metric['column']['column_name'], + 'fieldNames': [adhoc_metric['column']['column_name']], + 'type': DruidDatasource.druid_type_from_adhoc_metric(adhoc_metric), + 'name': adhoc_metric['label'], + } + return aggregations + def check_restricted_metrics(self, aggregations): rejected_metrics = [ m.metric_name for m in self.metrics @@ -1108,7 +1108,7 @@ def run_query( # noqa / druid metrics_dict = {m.metric_name: m for m in self.metrics} columns_dict = {c.column_name: c for c in self.columns} - aggregations, adhoc_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs( + aggregations, post_aggs = DruidDatasource.metrics_and_post_aggs( metrics, metrics_dict) @@ -1163,7 +1163,7 @@ def run_query( # noqa / druid pre_qry = deepcopy(qry) if timeseries_limit_metric: order_by = timeseries_limit_metric - aggs_dict, adhoc_dict, post_aggs_dict = self.metrics_and_post_aggs( + aggs_dict, post_aggs_dict = self.metrics_and_post_aggs( [timeseries_limit_metric], metrics_dict) if phase == 1: @@ -1221,7 +1221,7 @@ def run_query( # noqa / druid if timeseries_limit_metric: order_by = timeseries_limit_metric - aggs_dict, adhoc_dict, post_aggs_dict = self.metrics_and_post_aggs( + aggs_dict, post_aggs_dict = self.metrics_and_post_aggs( [timeseries_limit_metric], metrics_dict) if phase == 1: diff --git a/tests/druid_func_tests.py b/tests/druid_func_tests.py index 7491561f5788d..c47849433cec8 100644 --- a/tests/druid_func_tests.py +++ b/tests/druid_func_tests.py @@ -513,7 +513,7 @@ def depends_on(index, fields): depends_on('I', ['H', 'K']) depends_on('J', 'K') depends_on('K', ['m8', 'm9']) - aggs, saved_metrics, postaggs = DruidDatasource.metrics_and_post_aggs( + aggs, postaggs = DruidDatasource.metrics_and_post_aggs( metrics, metrics_dict) expected_metrics = set(aggs.keys()) self.assertEqual(9, len(aggs)) @@ -594,45 +594,40 @@ def test_metrics_and_post_aggs(self): } metrics = ['some_sum'] - saved_metrics, adhoc_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs( + saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs( metrics, metrics_dict) assert set(saved_metrics.keys()) == {'some_sum'} - assert adhoc_metrics == [] assert post_aggs == {} metrics = [adhoc_metric] - saved_metrics, adhoc_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs( + saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs( metrics, metrics_dict) - assert set(saved_metrics.keys()) == set([]) - assert adhoc_metrics == [adhoc_metric] + assert set(saved_metrics.keys()) == set([adhoc_metric['label']]) assert post_aggs == {} metrics = ['some_sum', adhoc_metric] - saved_metrics, adhoc_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs( + saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs( metrics, metrics_dict) - assert set(saved_metrics.keys()) == {'some_sum'} - assert adhoc_metrics == [adhoc_metric] + assert set(saved_metrics.keys()) == {'some_sum', adhoc_metric['label']} assert post_aggs == {} metrics = ['quantile_p95'] - saved_metrics, adhoc_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs( + saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs( metrics, metrics_dict) result_postaggs = set(['quantile_p95']) assert set(saved_metrics.keys()) == {'a_histogram'} - assert adhoc_metrics == [] assert set(post_aggs.keys()) == result_postaggs metrics = ['aCustomPostAgg'] - saved_metrics, adhoc_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs( + saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs( metrics, metrics_dict) result_postaggs = set(['aCustomPostAgg']) assert set(saved_metrics.keys()) == {'aCustomMetric'} - assert adhoc_metrics == [] assert set(post_aggs.keys()) == result_postaggs def test_druid_type_from_adhoc_metric(self):