diff --git a/core/dbt/task/run.py b/core/dbt/task/run.py index 171d210ec3c..bed80c0397c 100644 --- a/core/dbt/task/run.py +++ b/core/dbt/task/run.py @@ -1,6 +1,6 @@ import functools import time -from typing import List +from typing import List, Dict, Any, Set, Tuple, Optional from dbt.logger import ( GLOBAL_LOGGER as logger, @@ -148,7 +148,9 @@ def run_hooks(self, adapter, hook_type: RunHookType, extra_context): with TextOnly(): print_timestamped_line("") - def safe_run_hooks(self, adapter, hook_type: RunHookType, extra_context): + def safe_run_hooks( + self, adapter, hook_type: RunHookType, extra_context: Dict[str, Any] + ) -> None: try: self.run_hooks(adapter, hook_type, extra_context) except dbt.exceptions.RuntimeException: @@ -178,18 +180,23 @@ def before_run(self, adapter, selected_uids): self.safe_run_hooks(adapter, RunHookType.Start, {}) def after_run(self, adapter, results): - # in on-run-end hooks, provide the value 'schemas', which is a list of - # unique schemas that successfully executed models were in - # errored failed skipped - schemas = list(set( - r.node.schema for r in results + # in on-run-end hooks, provide the value 'database_schemas', which is a + # list of unique database, schema pairs that successfully executed + # models were in. for backwards compatibility, include the old + # 'schemas', which did not include database information. + database_schema_set: Set[Tuple[Optional[str], str]] = { + (r.node.database, r.node.schema) for r in results if not any((r.error is not None, r.fail, r.skipped)) - )) - + } self._total_executed += len(results) + + extras = { + 'schemas': list({s for _, s in database_schema_set}), + 'results': results, + 'database_schemas': list(database_schema_set), + } with adapter.connection_named('master'): - self.safe_run_hooks(adapter, RunHookType.End, - {'schemas': schemas, 'results': results}) + self.safe_run_hooks(adapter, RunHookType.End, extras) def after_hooks(self, adapter, results, elapsed): self.print_results_line(results, elapsed) diff --git a/core/dbt/task/test.py b/core/dbt/task/test.py index da0182dc309..62c87d12372 100644 --- a/core/dbt/task/test.py +++ b/core/dbt/task/test.py @@ -1,5 +1,7 @@ +from typing import Dict, Any + from dbt.node_runners import TestRunner -from dbt.node_types import NodeType +from dbt.node_types import NodeType, RunHookType from dbt.task.run import RunTask @@ -12,7 +14,9 @@ class TestTask(RunTask): def raise_on_first_error(self): return False - def safe_run_hooks(self, adapter, hook_type, extra_context): + def safe_run_hooks( + self, adapter, hook_type: RunHookType, extra_context: Dict[str, Any] + ) -> None: # Don't execute on-run-* hooks for tests pass diff --git a/test/integration/014_hook_tests/test_run_hooks.py b/test/integration/014_hook_tests/test_run_hooks.py index ba6d974286f..d39dc303ecc 100644 --- a/test/integration/014_hook_tests/test_run_hooks.py +++ b/test/integration/014_hook_tests/test_run_hooks.py @@ -1,5 +1,6 @@ from test.integration.base import DBTIntegrationTest, use_profile + class TestPrePostRunHooks(DBTIntegrationTest): def setUp(self): @@ -45,7 +46,9 @@ def project_config(self): "create table {{ target.schema }}.end_hook_order_test ( id int )", "drop table {{ target.schema }}.end_hook_order_test", "create table {{ target.schema }}.schemas ( schema text )", - "insert into {{ target.schema }}.schemas values ({% for schema in schemas %}( '{{ schema }}' ){% if not loop.last %},{% endif %}{% endfor %})", + "insert into {{ target.schema }}.schemas (schema) values {% for schema in schemas %}( '{{ schema }}' ){% if not loop.last %},{% endif %}{% endfor %}", + "create table {{ target.schema }}.db_schemas ( db text, schema text )", + "insert into {{ target.schema }}.db_schemas (db, schema) values {% for db, schema in database_schemas %}('{{ db }}', '{{ schema }}' ){% if not loop.last %},{% endif %}{% endfor %}", ], 'seeds': { 'quote_columns': False, @@ -73,6 +76,12 @@ def assert_used_schemas(self): self.assertEqual(len(results), 1) self.assertEqual(results[0][0], self.unique_schema()) + db_schemas_query = 'select * from {}.db_schemas'.format(self.unique_schema()) + results = self.run_sql(db_schemas_query, fetch='all') + self.assertEqual(len(results), 1) + self.assertEqual(results[0][0], self.default_database) + self.assertEqual(results[0][1], self.unique_schema()) + def check_hooks(self, state): ctx = self.get_ctx_vars(state)