Skip to content

Commit

Permalink
run hooks outside of a transaction (#510)
Browse files Browse the repository at this point in the history
* make it possible to issue hooks outside of a transaction

* wip

* fix incremental materializations

* hook contract for on-run-start/on-run-end

* make on-run-* hooks work more sanely

* pep8

* make codeclimate happy(-ier)

* typo

* fix for bq commit signature
  • Loading branch information
drewbanin committed Aug 29, 2017
1 parent 7a5670b commit cda3a68
Show file tree
Hide file tree
Showing 19 changed files with 249 additions and 112 deletions.
2 changes: 1 addition & 1 deletion dbt/adapters/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def begin(cls, profile, name='master'):
pass

@classmethod
def commit(cls, connection):
def commit(cls, profile, connection):
pass

@classmethod
Expand Down
20 changes: 15 additions & 5 deletions dbt/adapters/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ class DefaultAdapter(object):
"truncate",
"add_query",
"expand_target_column_types",
"quote_schema_and_table",
]

raw_functions = [
"get_status",
"get_result_from_cursor",
"quote",
"quote_schema_and_table",
]

###
Expand Down Expand Up @@ -396,6 +396,10 @@ def reload(cls, connection):
def add_begin_query(cls, profile, name):
return cls.add_query(profile, 'BEGIN', name, auto_begin=False)

@classmethod
def add_commit_query(cls, profile, name):
    """Send a ``COMMIT`` statement through ``add_query`` for ``name``.

    The query is added with ``auto_begin=False``. Whatever ``add_query``
    returns is passed back to the caller unchanged.
    """
    commit_result = cls.add_query(profile, 'COMMIT', name, auto_begin=False)
    return commit_result

@classmethod
def begin(cls, profile, name='master'):
global connections_in_use
Expand Down Expand Up @@ -428,10 +432,10 @@ def commit_if_has_connection(cls, profile, name):

connection = cls.get_connection(profile, name, False)

return cls.commit(connection)
return cls.commit(profile, connection)

@classmethod
def commit(cls, connection):
def commit(cls, profile, connection):
global connections_in_use

if dbt.flags.STRICT_MODE:
Expand All @@ -445,7 +449,7 @@ def commit(cls, connection):
'it does not have one open!'.format(connection.get('name')))

logger.debug('On {}: COMMIT'.format(connection.get('name')))
connection.get('handle').commit()
cls.add_commit_query(profile, connection.get('name'))

connection['transaction_open'] = False
connections_in_use[connection.get('name')] = connection
Expand Down Expand Up @@ -512,6 +516,12 @@ def add_query(cls, profile, sql, model_name=None, auto_begin=True):

return connection, cursor

@classmethod
def clear_transaction(cls, profile, conn_name='master'):
    """Begin a transaction on ``conn_name`` and commit it straight away.

    The connection returned by ``begin`` is handed to ``commit``. The
    connection *name* (not the connection) is returned.
    """
    cls.commit(profile, cls.begin(profile, conn_name))
    return conn_name

@classmethod
def execute_one(cls, profile, sql, model_name=None, auto_begin=False):
cls.get_connection(profile, model_name)
Expand Down Expand Up @@ -576,6 +586,6 @@ def quote(cls, identifier):
return '"{}"'.format(identifier)

@classmethod
def quote_schema_and_table(cls, profile, schema, table):
def quote_schema_and_table(cls, profile, schema, table, model_name=None):
return '{}.{}'.format(cls.quote(schema),
cls.quote(table))
4 changes: 2 additions & 2 deletions dbt/adapters/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ def exception_handler(cls, profile, sql, model_name=None,
except psycopg2.DatabaseError as e:
logger.debug('Postgres error: {}'.format(str(e)))

cls.rollback(connection)
cls.release_connection(profile, connection_name)
raise dbt.exceptions.DatabaseException(
dbt.compat.to_string(e).strip())

except Exception as e:
logger.debug("Error running SQL: %s", sql)
logger.debug("Rolling back transaction.")
cls.rollback(connection)
cls.release_connection(profile, connection_name)
raise dbt.exceptions.RuntimeException(e)

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions dbt/adapters/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@ def drop(cls, profile, relation, relation_type, model_name=None):
connection = cls.get_connection(profile, model_name)

if connection.get('transaction_open'):
cls.commit(connection)
cls.commit(profile, connection)

cls.begin(profile, connection.get('name'))

to_return = super(PostgresAdapter, cls).drop(
profile, relation, relation_type, model_name)

cls.commit(connection)
cls.commit(profile, connection)
cls.begin(profile, connection.get('name'))

return to_return
Expand Down
6 changes: 3 additions & 3 deletions dbt/adapters/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,20 @@ def exception_handler(cls, profile, sql, model_name=None,
if 'Empty SQL statement' in msg:
logger.debug("got empty sql statement, moving on")
elif 'This session does not have a current database' in msg:
cls.rollback(connection)
cls.release_connection(profile, connection_name)
raise dbt.exceptions.FailedToConnectException(
('{}\n\nThis error sometimes occurs when invalid '
'credentials are provided, or when your default role '
'does not have access to use the specified database. '
'Please double check your profile and try again.')
.format(msg))
else:
cls.rollback(connection)
cls.release_connection(profile, connection_name)
raise dbt.exceptions.DatabaseException(msg)
except Exception as e:
logger.debug("Error running SQL: %s", sql)
logger.debug("Rolling back transaction.")
cls.rollback(connection)
cls.release_connection(profile, connection_name)
raise dbt.exceptions.RuntimeException(e.msg)

@classmethod
Expand Down
27 changes: 15 additions & 12 deletions dbt/context/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,17 @@
import voluptuous

from dbt.adapters.factory import get_adapter
from dbt.compat import basestring
from dbt.compat import basestring, to_string

import dbt.clients.jinja
import dbt.flags
import dbt.schema
import dbt.tracking
import dbt.utils

from dbt.logger import GLOBAL_LOGGER as logger # noqa


def get_hooks(model, context, hook_key):
hooks = model.get('config', {}).get(hook_key, [])
import dbt.hooks

if isinstance(hooks, basestring):
hooks = [hooks]

return hooks
from dbt.logger import GLOBAL_LOGGER as logger # noqa


class DatabaseWrapper(object):
Expand Down Expand Up @@ -227,6 +220,15 @@ def fn(string):
return fn


def fromjson(node):
    """Build a JSON-decoding helper function.

    ``node`` is accepted but is not used by the returned function.

    The returned ``fn(string, default=None)`` gives back the decoded value
    of ``string``, or ``default`` when decoding raises ``ValueError``.
    """
    def fn(string, default=None):
        try:
            decoded = json.loads(string)
        except ValueError:
            # Not valid JSON: hand back the caller-supplied fallback.
            return default
        return decoded
    return fn


def generate(model, project, flat_graph, provider=None):
"""
Not meant to be called directly. Call with either:
Expand All @@ -248,8 +250,8 @@ def generate(model, project, flat_graph, provider=None):
context = {'env': target}
schema = profile.get('schema', 'public')

pre_hooks = get_hooks(model, context, 'pre-hook')
post_hooks = get_hooks(model, context, 'post-hook')
pre_hooks = model.get('config', {}).get('pre-hook')
post_hooks = model.get('config', {}).get('post-hook')

db_wrapper = DatabaseWrapper(model, adapter, profile)

Expand All @@ -270,6 +272,7 @@ def generate(model, project, flat_graph, provider=None):
"schema": schema,
"sql": model.get('injected_sql'),
"sql_now": adapter.date_function(),
"fromjson": fromjson(model),
"target": target,
"this": dbt.utils.This(
schema,
Expand Down
12 changes: 10 additions & 2 deletions dbt/contracts/graph/parsed.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,16 @@

from dbt.logger import GLOBAL_LOGGER as logger # noqa

# Contract for a single hook: a mapping that must carry both an 'sql'
# entry (basestring) and a 'transaction' entry (bool).
hook_contract = Schema({
Required('sql'): basestring,
Required('transaction'): bool,
})

config_contract = Schema({
Required('enabled'): bool,
Required('materialized'): basestring,
Required('post-hook'): list,
Required('pre-hook'): list,
Required('post-hook'): [hook_contract],
Required('pre-hook'): [hook_contract],
Required('vars'): dict,
}, extra=ALLOW_EXTRA)

Expand Down Expand Up @@ -69,6 +73,10 @@
})


def validate_hook(hook):
    """Validate a single hook against ``hook_contract``.

    ``hook`` is passed to ``validate_with`` together with ``hook_contract``;
    any error raised by ``validate_with`` propagates to the caller.
    """
    # The previous body passed an undefined name (``hooks``), which raised
    # NameError on every call; the parameter is ``hook``.
    validate_with(hook_contract, hook)


def validate_nodes(parsed_nodes):
    """Validate ``parsed_nodes`` against ``parsed_nodes_contract``.

    Delegates to ``validate_with``; nothing is returned.
    """
    validate_with(parsed_nodes_contract, parsed_nodes)

Expand Down
9 changes: 1 addition & 8 deletions dbt/graph/selector.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import networkx as nx
from dbt.logger import GLOBAL_LOGGER as logger

from dbt.utils import is_enabled, get_materialization
from dbt.utils import is_enabled, get_materialization, coalesce
from dbt.node_types import NodeType

SELECTOR_PARENTS = '+'
Expand Down Expand Up @@ -43,13 +43,6 @@ def parse_spec(node_spec):
}


def coalesce(*args):
    """Return the first argument that is not ``None``.

    ``None`` is returned when no arguments are given or all are ``None``.
    Falsy values other than ``None`` (``0``, ``''``, ``[]``) count as
    present.
    """
    present = (candidate for candidate in args if candidate is not None)
    return next(present, None)


def get_package_names(graph):
    """Collect the distinct package names found in ``graph``.

    Every entry yielded by ``graph.nodes()`` is split on "." and its second
    segment is taken as the package name. Returns a ``set``.
    """
    packages = set()
    for node_id in graph.nodes():
        packages.add(node_id.split(".")[1])
    return packages

Expand Down
40 changes: 40 additions & 0 deletions dbt/hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@

import json
from dbt.compat import to_string


class ModelHookType:
    """Names of the model-level hook configuration keys."""

    PreHook = 'pre-hook'
    PostHook = 'post-hook'

    # Both hook keys, pre-hook first.
    Both = [PreHook, PostHook]


def _parse_hook_to_dict(hook_string):
try:
hook_dict = json.loads(hook_string)
except ValueError as e:
hook_dict = {"sql": hook_string}

if 'transaction' not in hook_dict:
hook_dict['transaction'] = True

return hook_dict


def get_hook_dict(hook):
    """Return ``hook`` in dict form.

    A dict is handed back as-is (the same object). Any other value is
    converted with ``to_string`` and parsed by ``_parse_hook_to_dict``.
    """
    if not isinstance(hook, dict):
        return _parse_hook_to_dict(to_string(hook))
    return hook


def get_hooks(model, hook_key):
    """Return the hooks configured under ``hook_key`` as a list of dicts.

    The value is read from ``model['config'][hook_key]`` and defaults to an
    empty list when either key is absent. A value that is not a list or
    tuple is treated as a single hook. Each hook is converted with
    ``get_hook_dict``.
    """
    configured = model.get('config', {}).get(hook_key, [])

    if isinstance(configured, (list, tuple)):
        hook_list = configured
    else:
        # A lone hook: wrap it so it is processed like a one-item list.
        hook_list = [configured]

    wrapped = []
    for entry in hook_list:
        wrapped.append(get_hook_dict(entry))
    return wrapped
4 changes: 2 additions & 2 deletions dbt/include/global_project/macros/core.sql
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{% macro statement(name=None, fetch_result=False) -%}
{% macro statement(name=None, fetch_result=False, auto_begin=True) -%}
{%- if execute: -%}
{%- set sql = render(caller()) -%}

Expand All @@ -7,7 +7,7 @@
{{ write(sql) }}
{%- endif -%}

{%- set _, cursor = adapter.add_query(sql) -%}
{%- set _, cursor = adapter.add_query(sql, auto_begin=auto_begin) -%}
{%- if name is not none -%}
{%- set result = [] if not fetch_result else adapter.get_result_from_cursor(cursor) -%}
{{ store_result(name, status=adapter.get_status(cursor), data=result) }}
Expand Down
28 changes: 24 additions & 4 deletions dbt/include/global_project/macros/materializations/helpers.sql
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{% macro run_hooks(hooks) %}
{% for hook in hooks %}
{% call statement() %}
{{ hook }};
{% macro run_hooks(hooks, inside_transaction=True) %}
{% for hook in hooks | selectattr('transaction', 'equalto', inside_transaction) %}
{% call statement(auto_begin=inside_transaction) %}
{{ hook.get('sql') }}
{% endcall %}
{% endfor %}
{% endmacro %}
Expand All @@ -21,6 +21,26 @@
{% endmacro %}


{% macro make_hook_config(sql, inside_transaction) %}
{{ {"sql": sql, "transaction": inside_transaction} | tojson }}
{% endmacro %}


{% macro before_begin(sql) %}
{{ make_hook_config(sql, inside_transaction=False) }}
{% endmacro %}


{% macro in_transaction(sql) %}
{{ make_hook_config(sql, inside_transaction=True) }}
{% endmacro %}


{% macro after_commit(sql) %}
{{ make_hook_config(sql, inside_transaction=False) }}
{% endmacro %}


{% macro drop_if_exists(existing, name) %}
{% set existing_type = existing.get(name) %}
{% if existing_type is not none %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@
{{ adapter.drop(identifier, existing_type) }}
{%- endif %}

{{ run_hooks(pre_hooks) }}
{{ run_hooks(pre_hooks, inside_transaction=False) }}

-- `BEGIN` happens here:
{{ run_hooks(pre_hooks, inside_transaction=True) }}

-- build model
{% if force_create or not adapter.already_exists(schema, identifier) -%}
Expand Down Expand Up @@ -79,8 +82,11 @@
{% endcall %}
{%- endif %}

{{ run_hooks(post_hooks) }}
{{ run_hooks(post_hooks, inside_transaction=True) }}

-- `COMMIT` happens here
{{ adapter.commit() }}

{{ run_hooks(post_hooks, inside_transaction=False) }}

{%- endmaterialization %}
Loading

0 comments on commit cda3a68

Please sign in to comment.