diff --git a/tests/druid_tests.py b/tests/druid_tests.py index 6947a6cd7d1b6..0759afb7db198 100644 --- a/tests/druid_tests.py +++ b/tests/druid_tests.py @@ -346,38 +346,38 @@ def test_refresh_metadata(self, PyDruid): self.login(username='admin') cluster = self.get_cluster(PyDruid) cluster.refresh_datasources() + datasource = cluster.datasources[0] - for i, datasource in enumerate(cluster.datasources): - cols = ( - db.session.query(DruidColumn) - .filter(DruidColumn.datasource_id == datasource.id) + cols = ( + db.session.query(DruidColumn) + .filter(DruidColumn.datasource_id == datasource.id) + ) + + for col in cols: + self.assertIn( + col.column_name, + SEGMENT_METADATA[0]['columns'].keys(), ) - for col in cols: - self.assertIn( - col.column_name, - SEGMENT_METADATA[i]['columns'].keys(), - ) + metrics = ( + db.session.query(DruidMetric) + .filter(DruidMetric.datasource_id == datasource.id) + .filter(DruidMetric.metric_name.like('%__metric1')) + ) - metrics = ( - db.session.query(DruidMetric) - .filter(DruidMetric.datasource_id == datasource.id) - .filter(DruidMetric.metric_name.like('%__metric1')) - ) + self.assertEqual( + {metric.metric_name for metric in metrics}, + {'max__metric1', 'min__metric1', 'sum__metric1'}, + ) + + for metric in metrics: + agg, _ = metric.metric_name.split('__') self.assertEqual( - {metric.metric_name for metric in metrics}, - {'max__metric1', 'min__metric1', 'sum__metric1'}, + json.loads(metric.json)['type'], + 'double{}'.format(agg.capitalize()), ) - for metric in metrics: - agg, _ = metric.metric_name.split('__') - - self.assertEqual( - json.loads(metric.json)['type'], - 'double{}'.format(agg.capitalize()), - ) - @patch('superset.connectors.druid.models.PyDruid') def test_refresh_metadata_augment_type(self, PyDruid): self.login(username='admin') @@ -389,52 +389,60 @@ def test_refresh_metadata_augment_type(self, PyDruid): instance = PyDruid.return_value instance.segment_metadata.return_value = metadata cluster.refresh_datasources() + datasource = cluster.datasources[0] - for i, datasource in enumerate(cluster.datasources): - metrics = ( - db.session.query(DruidMetric) - .filter(DruidMetric.datasource_id == datasource.id) - .filter(DruidMetric.metric_name.like('%__metric1')) - ) + column = ( + db.session.query(DruidColumn) + .filter(DruidColumn.datasource_id == datasource.id) + .filter(DruidColumn.column_name == 'metric1') + ).one() - for metric in metrics: - agg, _ = metric.metric_name.split('__') + self.assertEqual(column.type, 'LONG') - self.assertEqual( - metric.json_obj['type'], - 'long{}'.format(agg.capitalize()), - ) + metrics = ( + db.session.query(DruidMetric) + .filter(DruidMetric.datasource_id == datasource.id) + .filter(DruidMetric.metric_name.like('%__metric1')) + ) + + for metric in metrics: + agg, _ = metric.metric_name.split('__') + + self.assertEqual( + metric.json_obj['type'], + 'long{}'.format(agg.capitalize()), + ) @patch('superset.connectors.druid.models.PyDruid') def test_refresh_metadata_augment_verbose_name(self, PyDruid): self.login(username='admin') cluster = self.get_cluster(PyDruid) cluster.refresh_datasources() + datasource = cluster.datasources[0] - for i, datasource in enumerate(cluster.datasources): - metrics = ( - db.session.query(DruidMetric) - .filter(DruidMetric.datasource_id == datasource.id) - .filter(DruidMetric.metric_name.like('%__metric1')) - ) + metrics = ( + db.session.query(DruidMetric) + .filter(DruidMetric.datasource_id == datasource.id) + .filter(DruidMetric.metric_name.like('%__metric1')) + ) - for metric in metrics: - metric.verbose_name = metric.metric_name + for metric in metrics: + metric.verbose_name = metric.metric_name - db.session.commit() + db.session.commit() # The verbose name should not change during a refresh. cluster.refresh_datasources() + datasource = cluster.datasources[0] - for i, datasource in enumerate(cluster.datasources): - metrics = ( - db.session.query(DruidMetric) - .filter(DruidMetric.datasource_id == datasource.id) - .filter(DruidMetric.metric_name.like('%__metric1')) - ) + metrics = ( + db.session.query(DruidMetric) + .filter(DruidMetric.datasource_id == datasource.id) + .filter(DruidMetric.metric_name.like('%__metric1')) + ) - for metric in metrics: - self.assertEqual(metric.verbose_name, metric.metric_name) + for metric in metrics: + self.assertEqual(metric.verbose_name, metric.metric_name) def test_urls(self): cluster = self.get_test_cluster_obj()