From e93dbf753558b882a04c031fa7d33b2e38ebefe9 Mon Sep 17 00:00:00 2001 From: Cameron Hurst Date: Wed, 18 Dec 2019 12:41:01 -0500 Subject: [PATCH 1/7] feat: added refactored for integration with alembic Reworked location of some function calls to support multple purposes Added migration calls to be manually added into alembic migrations --- postgresql_audit/base.py | 96 ++++++---------------------------- postgresql_audit/migrations.py | 23 ++++++++ postgresql_audit/utils.py | 75 ++++++++++++++++++++++++++ 3 files changed, 115 insertions(+), 79 deletions(-) create mode 100644 postgresql_audit/utils.py diff --git a/postgresql_audit/base.py b/postgresql_audit/base.py index e79133a..70a097b 100644 --- a/postgresql_audit/base.py +++ b/postgresql_audit/base.py @@ -1,7 +1,6 @@ -import os -import string import warnings from contextlib import contextmanager +from functools import partial from weakref import WeakSet import sqlalchemy as sa @@ -18,7 +17,9 @@ from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy_utils import get_class_by_table -HERE = os.path.dirname(os.path.abspath(__file__)) +from postgresql_audit.utils import render_tmpl, StatementExecutor, create_audit_table, create_operators, \ + build_register_table_query + cached_statements = {} @@ -29,23 +30,6 @@ class ImproperlyConfigured(Exception): class ClassNotVersioned(Exception): pass - -class StatementExecutor(object): - def __init__(self, stmt): - self.stmt = stmt - - def __call__(self, target, bind, **kwargs): - tx = bind.begin() - bind.execute(self.stmt) - tx.commit() - - -def read_file(file_): - with open(os.path.join(HERE, file_)) as f: - s = f.read() - return s - - def assign_actor(base, cls, actor_cls): if hasattr(cls, 'actor_id'): return @@ -194,10 +178,11 @@ def __init__( ), ) self.schema_name = schema_name + self.use_statement_level_triggers = use_statement_level_triggers self.table_listeners = self.get_table_listeners() self.pending_classes = WeakSet() self.cached_ddls = {} - 
self.use_statement_level_triggers = use_statement_level_triggers + def get_transaction_values(self): return self.values @@ -214,70 +199,28 @@ def disable(self, session): "SET LOCAL postgresql_audit.enable_versioning = 'true'" ) - def render_tmpl(self, tmpl_name): - file_contents = read_file( - 'templates/{}'.format(tmpl_name) - ).replace('%', '%%').replace('$$', '$$$$') - tmpl = string.Template(file_contents) - context = dict(schema_name=self.schema_name) - - if self.schema_name is None: - context['schema_prefix'] = '' - context['revoke_cmd'] = '' - else: - context['schema_prefix'] = '{}.'.format(self.schema_name) - context['revoke_cmd'] = ( - 'REVOKE ALL ON {schema_prefix}activity FROM public;' - ).format(**context) - - temp = tmpl.substitute(**context) - return temp - - def create_operators(self, target, bind, **kwargs): - if bind.dialect.server_version_info < (9, 5, 0): - StatementExecutor(self.render_tmpl('operators_pre95.sql'))( - target, bind, **kwargs - ) - if bind.dialect.server_version_info < (9, 6, 0): - StatementExecutor(self.render_tmpl('operators_pre96.sql'))( - target, bind, **kwargs - ) - if bind.dialect.server_version_info < (10, 0): - operators_template = self.render_tmpl('operators_pre100.sql') - StatementExecutor(operators_template)(target, bind, **kwargs) - operators_template = self.render_tmpl('operators.sql') - StatementExecutor(operators_template)(target, bind, **kwargs) - - def create_audit_table(self, target, bind, **kwargs): - sql = '' - if ( - self.use_statement_level_triggers and - bind.dialect.server_version_info >= (10, 0) - ): - sql += self.render_tmpl('create_activity_stmt_level.sql') - sql += self.render_tmpl('audit_table_stmt_level.sql') - else: - sql += self.render_tmpl('create_activity_row_level.sql') - sql += self.render_tmpl('audit_table_row_level.sql') - StatementExecutor(sql)(target, bind, **kwargs) - def get_table_listeners(self): listeners = {'transaction': []} listeners['activity'] = [ ('after_create', sa.schema.DDL( - 
self.render_tmpl('jsonb_change_key_name.sql') + render_tmpl('jsonb_change_key_name.sql', self.schema_name) )), - ('after_create', self.create_audit_table), - ('after_create', self.create_operators) + ('after_create', partial( + create_audit_table, + schema_name=self.schema_name, + use_statement_level_triggers=self.use_statement_level_triggers + ) + ), + ('after_create', partial(create_operators, schema_name=self.schema_name)) ] if self.schema_name is not None: listeners['transaction'] = [ ('before_create', sa.schema.DDL( - self.render_tmpl('create_schema.sql') + render_tmpl('create_schema.sql', self.schema_name) )), ('after_drop', sa.schema.DDL( - self.render_tmpl('drop_schema.sql') + render_tmpl('drop_schema.sql', self.schema_name) )), ] return listeners @@ -294,12 +237,7 @@ def audit_table(self, table, exclude_columns=None): ) ) args.append(array(exclude_columns)) - - if self.schema_name is None: - func = sa.func.audit_table - else: - func = getattr(getattr(sa.func, self.schema_name), 'audit_table') - query = sa.select([func(*args)]) + query = build_register_table_query(self.schema_name, *args) if query not in cached_statements: cached_statements[query] = StatementExecutor(query) listener = (table, 'after_create', cached_statements[query]) diff --git a/postgresql_audit/migrations.py b/postgresql_audit/migrations.py index 6fab915..56adab6 100644 --- a/postgresql_audit/migrations.py +++ b/postgresql_audit/migrations.py @@ -1,6 +1,8 @@ import sqlalchemy as sa from sqlalchemy.dialects.postgresql import JSONB +from postgresql_audit.utils import render_tmpl, create_audit_table, create_operators, \ + build_register_table_query from .expressions import jsonb_change_key_name @@ -16,6 +18,27 @@ def get_activity_table(schema=None): schema=schema, ) +def init_activity_table_triggers(conn, schema_name = None, use_statement_level_triggers=True): + conn.execute(render_tmpl('jsonb_change_key_name.sql', schema_name)) + create_audit_table(None, conn, schema_name, 
use_statement_level_triggers) + create_operators(None, conn, schema_name) + + if schema_name: + conn.execute(render_tmpl('create_schema.sql', schema_name)) + +def rollback_create_transaction(conn, schema_name=None): + if schema_name: + conn.execute(render_tmpl('drop_schema.sql', schema_name)) + +def init_before_create_transaction(conn, schema_name=None): + if schema_name: + conn.execute(render_tmpl('create_schema.sql', schema_name)) + + +def register_table(conn, table_name, exclude_columns, schema_name=None): + sql = build_register_table_query(schema_name, table_name, exclude_columns) + conn.execute(sql) + def alter_column(conn, table, column_name, func, schema=None): """ diff --git a/postgresql_audit/utils.py b/postgresql_audit/utils.py new file mode 100644 index 0000000..127e694 --- /dev/null +++ b/postgresql_audit/utils.py @@ -0,0 +1,75 @@ +import os +import string +import sqlalchemy as sa + +HERE = os.path.dirname(os.path.abspath(__file__)) + + +class StatementExecutor(object): + def __init__(self, stmt): + self.stmt = stmt + + def __call__(self, target, bind, **kwargs): + tx = bind.begin() + bind.execute(self.stmt) + tx.commit() + +def read_file(file_): + with open(os.path.join(HERE, file_)) as f: + s = f.read() + return s + +def render_tmpl(tmpl_name, schema_name=None): + file_contents = read_file( + 'templates/{}'.format(tmpl_name) + ).replace('%', '%%').replace('$$', '$$$$') + tmpl = string.Template(file_contents) + context = dict(schema_name=schema_name) + + if schema_name is None: + context['schema_prefix'] = '' + context['revoke_cmd'] = '' + else: + context['schema_prefix'] = '{}.'.format(schema_name) + context['revoke_cmd'] = ( + 'REVOKE ALL ON {schema_prefix}activity FROM public;' + ).format(**context) + + return tmpl.substitute(**context) + + +def create_operators(target, bind, schema_name, **kwargs): + if bind.dialect.server_version_info < (9, 5, 0): + StatementExecutor(render_tmpl('operators_pre95.sql', schema_name))( + target, bind, **kwargs + ) + 
if bind.dialect.server_version_info < (9, 6, 0): + StatementExecutor(render_tmpl('operators_pre96.sql', schema_name))( + target, bind, **kwargs + ) + if bind.dialect.server_version_info < (10, 0): + operators_template = render_tmpl('operators_pre100.sql', schema_name) + StatementExecutor(operators_template)(target, bind, **kwargs) + operators_template = render_tmpl('operators.sql', schema_name) + StatementExecutor(operators_template)(target, bind, **kwargs) + +def create_audit_table(target, bind, schema_name, use_statement_level_triggers, **kwargs): + sql = '' + if ( + use_statement_level_triggers and + bind.dialect.server_version_info >= (10, 0) + ): + sql += render_tmpl('create_activity_stmt_level.sql', schema_name) + sql += render_tmpl('audit_table_stmt_level.sql', schema_name) + else: + sql += render_tmpl('create_activity_row_level.sql', schema_name) + sql += render_tmpl('audit_table_row_level.sql', schema_name) + StatementExecutor(sql)(target, bind, **kwargs) + + +def build_register_table_query(schema_name, *args): + if schema_name is None: + func = sa.func.audit_table + else: + func = getattr(getattr(sa.func, schema_name), 'audit_table') + return sa.select([func(*args)]) From 43e99f45659f726633f8550319f06561e491234a Mon Sep 17 00:00:00 2001 From: Cameron Hurst Date: Thu, 19 Dec 2019 12:28:09 -0500 Subject: [PATCH 2/7] feat: Modularized the different aspects of VersioningManager Session related features are not under SessionManager There is a BaseVersioningManager for use with Alembic VersioningManager and FlaskVersioningManager remain api compatible --- postgresql_audit/base.py | 307 +++++++++++++++++++++++--------------- postgresql_audit/flask.py | 8 +- 2 files changed, 187 insertions(+), 128 deletions(-) diff --git a/postgresql_audit/base.py b/postgresql_audit/base.py index 70a097b..3ca069f 100644 --- a/postgresql_audit/base.py +++ b/postgresql_audit/base.py @@ -148,102 +148,22 @@ def convert_callables(values): } -class VersioningManager(object): - 
_actor_cls = None - - def __init__( - self, - actor_cls=None, - schema_name=None, - use_statement_level_triggers=True - ): - if actor_cls is not None: - self._actor_cls = actor_cls - self.values = {} +class SessionManager(object): + def __init__(self, transaction_cls, values=None): + self.transaction_cls = transaction_cls + self.values = values or {} + self._marked_transactions = set() self.listeners = ( - ( - orm.mapper, - 'instrument_class', - self.instrument_versioned_classes - ), - ( - orm.mapper, - 'after_configured', - self.configure_versioned_classes - ), ( orm.session.Session, 'before_flush', - self.receive_before_flush, + self.before_flush, ), ) - self.schema_name = schema_name - self.use_statement_level_triggers = use_statement_level_triggers - self.table_listeners = self.get_table_listeners() - self.pending_classes = WeakSet() - self.cached_ddls = {} - def get_transaction_values(self): return self.values - @contextmanager - def disable(self, session): - session.execute( - "SET LOCAL postgresql_audit.enable_versioning = 'false'" - ) - try: - yield - finally: - session.execute( - "SET LOCAL postgresql_audit.enable_versioning = 'true'" - ) - - def get_table_listeners(self): - listeners = {'transaction': []} - - listeners['activity'] = [ - ('after_create', sa.schema.DDL( - render_tmpl('jsonb_change_key_name.sql', self.schema_name) - )), - ('after_create', partial( - create_audit_table, - schema_name=self.schema_name, - use_statement_level_triggers=self.use_statement_level_triggers - ) - ), - ('after_create', partial(create_operators, schema_name=self.schema_name)) - ] - if self.schema_name is not None: - listeners['transaction'] = [ - ('before_create', sa.schema.DDL( - render_tmpl('create_schema.sql', self.schema_name) - )), - ('after_drop', sa.schema.DDL( - render_tmpl('drop_schema.sql', self.schema_name) - )), - ] - return listeners - - def audit_table(self, table, exclude_columns=None): - args = [table.name] - if exclude_columns: - for column in 
exclude_columns: - if column not in table.c: - raise ImproperlyConfigured( - "Could not configure versioning. Table '{}'' does " - "not have a column named '{}'.".format( - table.name, column - ) - ) - args.append(array(exclude_columns)) - query = build_register_table_query(self.schema_name, *args) - if query not in cached_statements: - cached_statements[query] = StatementExecutor(query) - listener = (table, 'after_create', cached_statements[query]) - if not sa.event.contains(*listener): - sa.event.listen(*listener) - def set_activity_values(self, session): dialect = session.bind.engine.dialect table = self.transaction_cls.__table__ @@ -303,40 +223,50 @@ def is_modified(self, obj_or_session): if hasattr(entity, '__versioned__') ) - def receive_before_flush(self, session, flush_context, instances): + def before_flush(self, session, flush_context, instances): + if session.transaction in self._marked_transactions: + return + if session.transaction: + self.add_entry_and_mark_transaction(session) + + def add_entry_and_mark_transaction(self, session): if self.is_modified(session): + self._marked_transactions.add(session.transaction) self.set_activity_values(session) - def instrument_versioned_classes(self, mapper, cls): - """ - Collect versioned class and add it to pending_classes list. - - :mapper mapper: SQLAlchemy mapper object - :cls cls: SQLAlchemy declarative class - """ - if hasattr(cls, '__versioned__') and cls not in self.pending_classes: - self.pending_classes.add(cls) + def attach_listeners(self): + for listener in self.listeners: + sa.event.listen(*listener) - def configure_versioned_classes(self): - """ - Configures all versioned classes that were collected during - instrumentation process. 
- """ - for cls in self.pending_classes: - self.audit_table(cls.__table__, cls.__versioned__.get('exclude')) - assign_actor(self.base, self.transaction_cls, self.actor_cls) + def remove_listeners(self): + for listener in self.listeners: + sa.event.remove(*listener) - def attach_table_listeners(self): - for values in self.table_listeners['transaction']: - sa.event.listen(self.transaction_cls.__table__, *values) - for values in self.table_listeners['activity']: - sa.event.listen(self.activity_cls.__table__, *values) +class BasicVersioningManager(object): + _actor_cls = None + _session_manager_factory = partial(SessionManager, values={}) - def remove_table_listeners(self): - for values in self.table_listeners['transaction']: - sa.event.remove(self.transaction_cls.__table__, *values) - for values in self.table_listeners['activity']: - sa.event.remove(self.activity_cls.__table__, *values) + def __init__( + self, + actor_cls=None, + session_manager_factory=None, + schema_name=None, + use_statement_level_triggers=True + ): + if actor_cls is not None: + self._actor_cls = actor_cls + if session_manager_factory is not None: + self._session_manager_factory = session_manager_factory + self.values = {} + self.listeners = ( + ( + orm.mapper, + 'after_configured', + self.after_configured + ), + ) + self.schema_name = schema_name + self.use_statement_level_triggers = use_statement_level_triggers @property def actor_cls(self): @@ -362,15 +292,8 @@ def actor_cls(self): ) return self._actor_cls - def attach_listeners(self): - self.attach_table_listeners() - for listener in self.listeners: - sa.event.listen(*listener) - - def remove_listeners(self): - self.remove_table_listeners() - for listener in self.listeners: - sa.event.remove(*listener) + def after_configured(self): + assign_actor(self.base, self.transaction_cls, self.actor_cls) def activity_model_factory(self, base, transaction_cls): class Activity(activity_base(base, self.schema_name, transaction_cls)): @@ -384,6 +307,28 @@ 
class Transaction(transaction_base(base, self.schema_name)): return Transaction + def attach_listeners(self): + for listener in self.listeners: + sa.event.listen(*listener) + self.session_manager.attach_listeners() + + def remove_listeners(self): + for listener in self.listeners: + sa.event.remove(*listener) + self.session_manager.remove_listeners() + + @contextmanager + def disable(self, session): + session.execute( + "SET LOCAL postgresql_audit.enable_versioning = 'false'" + ) + try: + yield + finally: + session.execute( + "SET LOCAL postgresql_audit.enable_versioning = 'true'" + ) + def init(self, base): self.base = base self.transaction_cls = self.transaction_model_factory(base) @@ -391,7 +336,123 @@ def init(self, base): base, self.transaction_cls ) + self.session_manager = self._session_manager_factory(self.transaction_cls) self.attach_listeners() +class VersioningManager(BasicVersioningManager): + def __init__( + self, + actor_cls=None, + session_manager_factory=None, + schema_name=None, + use_statement_level_triggers=True + ): + super().__init__( + actor_cls=actor_cls, + schema_name=schema_name, + use_statement_level_triggers=use_statement_level_triggers, + session_manager_factory=session_manager_factory + ) + self.listeners = ( + ( + orm.mapper, + 'instrument_class', + self.instrument_versioned_classes + ), + ( + orm.mapper, + 'after_configured', + self.configure_versioned_classes + ), + ) + self.table_listeners = self.get_table_listeners() + self.pending_classes = WeakSet() + self.cached_ddls = {} + + def get_table_listeners(self): + listeners = {'transaction': []} + + listeners['activity'] = [ + ('after_create', sa.schema.DDL( + render_tmpl('jsonb_change_key_name.sql', self.schema_name) + )), + ('after_create', partial( + create_audit_table, + schema_name=self.schema_name, + use_statement_level_triggers=self.use_statement_level_triggers + ) + ), + ('after_create', partial(create_operators, schema_name=self.schema_name)) + ] + if self.schema_name is not 
None: + listeners['transaction'] = [ + ('before_create', sa.schema.DDL( + render_tmpl('create_schema.sql', self.schema_name) + )), + ('after_drop', sa.schema.DDL( + render_tmpl('drop_schema.sql', self.schema_name) + )), + ] + return listeners + + def audit_table(self, table, exclude_columns=None): + args = [table.name] + if exclude_columns: + for column in exclude_columns: + if column not in table.c: + raise ImproperlyConfigured( + "Could not configure versioning. Table '{}'' does " + "not have a column named '{}'.".format( + table.name, column + ) + ) + args.append(array(exclude_columns)) + query = build_register_table_query(self.schema_name, *args) + if query not in cached_statements: + cached_statements[query] = StatementExecutor(query) + listener = (table, 'after_create', cached_statements[query]) + if not sa.event.contains(*listener): + sa.event.listen(*listener) + + def instrument_versioned_classes(self, mapper, cls): + """ + Collect versioned class and add it to pending_classes list. + + :mapper mapper: SQLAlchemy mapper object + :cls cls: SQLAlchemy declarative class + """ + if hasattr(cls, '__versioned__') and cls not in self.pending_classes: + self.pending_classes.add(cls) + + def configure_versioned_classes(self): + """ + Configures all versioned classes that were collected during + instrumentation process. 
+ """ + for cls in self.pending_classes: + self.audit_table(cls.__table__, cls.__versioned__.get('exclude')) + assign_actor(self.base, self.transaction_cls, self.actor_cls) + + def attach_table_listeners(self): + for values in self.table_listeners['transaction']: + sa.event.listen(self.transaction_cls.__table__, *values) + for values in self.table_listeners['activity']: + sa.event.listen(self.activity_cls.__table__, *values) + + def remove_table_listeners(self): + for values in self.table_listeners['transaction']: + sa.event.remove(self.transaction_cls.__table__, *values) + for values in self.table_listeners['activity']: + sa.event.remove(self.activity_cls.__table__, *values) + + def attach_listeners(self): + self.attach_table_listeners() + super().attach_listeners() + + def remove_listeners(self): + self.remove_table_listeners() + super().remove_listeners() + + versioning_manager = VersioningManager() diff --git a/postgresql_audit/flask.py b/postgresql_audit/flask.py index 8c788ed..049b68e 100644 --- a/postgresql_audit/flask.py +++ b/postgresql_audit/flask.py @@ -6,12 +6,10 @@ from flask import g, request from flask.globals import _app_ctx_stack, _request_ctx_stack -from .base import VersioningManager as BaseVersioningManager +from .base import VersioningManager, SessionManager -class VersioningManager(BaseVersioningManager): - _actor_cls = 'User' - +class FlaskSessionManager(SessionManager): def get_transaction_values(self): values = copy(self.values) if context_available() and hasattr(g, 'activity_values'): @@ -65,4 +63,4 @@ def activity_values(**values): del g.activity_values -versioning_manager = VersioningManager() +versioning_manager = VersioningManager(actor_cls="User", session_manager_factory=FlaskSessionManager) From eeab5de24718d7763fc6ffa560d7efc0730e2714 Mon Sep 17 00:00:00 2001 From: Cameron Hurst Date: Sun, 5 Jan 2020 16:14:38 -0500 Subject: [PATCH 3/7] feature: reworked migrations to allow autogeneration with alembic --- 
postgresql_audit/alembic/__init__.py | 130 ++++++++++++++++++ .../alembic/init_activity_table_triggers.py | 97 +++++++++++++ postgresql_audit/alembic/migration_ops.py | 71 ++++++++++ .../register_table_for_version_tracking.py | 75 ++++++++++ postgresql_audit/base.py | 6 +- postgresql_audit/migrations.py | 23 ---- 6 files changed, 376 insertions(+), 26 deletions(-) create mode 100644 postgresql_audit/alembic/__init__.py create mode 100644 postgresql_audit/alembic/init_activity_table_triggers.py create mode 100644 postgresql_audit/alembic/migration_ops.py create mode 100644 postgresql_audit/alembic/register_table_for_version_tracking.py diff --git a/postgresql_audit/alembic/__init__.py b/postgresql_audit/alembic/__init__.py new file mode 100644 index 0000000..0314043 --- /dev/null +++ b/postgresql_audit/alembic/__init__.py @@ -0,0 +1,130 @@ +import re +from itertools import groupby + +from alembic.autogenerate import comparators, rewriter +from alembic.operations import ops + +from postgresql_audit.alembic.init_activity_table_triggers import InitActivityTableTriggersOp, \ + RemoveActivityTableTriggersOp +from postgresql_audit.alembic.migration_ops import AddColumnToActivityOp, RemoveColumnFromRemoveActivityOp +from postgresql_audit.alembic.register_table_for_version_tracking import RegisterTableForVersionTrackingOp, \ + DeregisterTableForVersionTrackingOp + + +@comparators.dispatch_for("schema") +def compare_timestamp_schema(autogen_context, upgrade_ops, schemas): + routines = set() + for sch in schemas: + schema_name = autogen_context.dialect.default_schema_name if sch is None else sch + routines.update([ + (sch, *row) for row in autogen_context.connection.execute( + "select routine_name, routine_definition from information_schema.routines " + f"where routines.specific_schema='{schema_name}' " + )]) + + for sch in schemas: + should_track_versions = any("versioned" in table.info for table in autogen_context.sorted_tables if table.info and table.schema == sch) + 
schema_prefix = f"{sch}." if sch else "" + + a = next((v for k, v in groupby(routines, key=lambda x: x[0]) if k == sch), None) + a = list(a) if a else [] + if should_track_versions: + if f"{schema_prefix}audit_table" not in (x[1] for x in a): + upgrade_ops.ops.append( + InitActivityTableTriggersOp(False, schema=sch) + ) + else: + if f"{schema_prefix}audit_table" in (x[1] for x in a): + upgrade_ops.ops.append( + RemoveActivityTableTriggersOp(schema=sch) + ) + + +@comparators.dispatch_for("table") +def compare_timestamp_table(autogen_context, modify_ops, schemaname, tablename, conn_table, metadata_table): + if metadata_table is None: + return + meta_info = metadata_table.info or {} + # TODO: Query triggers on the table + schema_name = autogen_context.dialect.default_schema_name if schemaname is None else schemaname + + triggers = [row for row in autogen_context.connection.execute(f""" +select event_object_schema as table_schema, + event_object_table as table_name, + trigger_schema, + trigger_name, + string_agg(event_manipulation, ',') as event, + action_timing as activation, + action_condition as condition, + action_statement as definition +from information_schema.triggers +where event_object_table = '{tablename}' and trigger_schema = '{schema_name}' +group by 1,2,3,4,6,7,8 +order by table_schema, table_name; + """)] + + trigger_name = "audit_trigger" + + if "versioned" in meta_info: + excluded_columns = metadata_table.info["versioned"].get("exclude", tuple()) + trigger = next((trigger for trigger in triggers if trigger_name in trigger[3]), None) + original_excluded_columns = __get_existing_excluded_columns(trigger) + + if trigger and set(original_excluded_columns) == set(excluded_columns): + return + + modify_ops.ops.insert(0, + RegisterTableForVersionTrackingOp(tablename, excluded_columns, original_excluded_columns, schema=schema_name) + ) + else: + trigger = next((trigger for trigger in triggers if trigger_name in trigger[3]), None) + original_excluded_columns = 
__get_existing_excluded_columns(trigger) + + if trigger: + modify_ops.ops.append( + DeregisterTableForVersionTrackingOp(tablename, original_excluded_columns, schema=schema_name) + ) + + +def __get_existing_excluded_columns(trigger): + original_excluded_columns = () + if trigger: + arguments_match = re.search(r"EXECUTE FUNCTION create_activity\('{(.+)}'\)", trigger[7]) + if arguments_match: + original_excluded_columns = arguments_match.group(1).split(",") + return original_excluded_columns + + +writer = rewriter.Rewriter() + +@writer.rewrites(ops.AddColumnOp) +def add_column_rewrite(context, revision, op): + table_info = op.column.table.info or {} + if "versioned" in table_info and op.column.name not in table_info["versioned"].get("exclude", []): + return [ + op, + AddColumnToActivityOp( + op.table_name, + op.column.name, + schema=op.column.table.schema, + ), + ] + else: + return op + +@writer.rewrites(ops.DropColumnOp) +def drop_column_rewrite(context, revision, op): + column = op._orig_column + table_info = column.table.info or {} + if "versioned" in table_info and column.name not in table_info["versioned"].get("exclude", []): + return [ + op, + RemoveColumnFromRemoveActivityOp( + op.table_name, + column.name, + schema=column.table.schema, + ), + ] + else: + return op + diff --git a/postgresql_audit/alembic/init_activity_table_triggers.py b/postgresql_audit/alembic/init_activity_table_triggers.py new file mode 100644 index 0000000..c7d01c8 --- /dev/null +++ b/postgresql_audit/alembic/init_activity_table_triggers.py @@ -0,0 +1,97 @@ +from alembic.autogenerate import renderers +from alembic.operations import Operations, MigrateOperation + +from postgresql_audit.utils import render_tmpl, create_audit_table, create_operators + + +@Operations.register_operation("init_activity_table_triggers") +class InitActivityTableTriggersOp(MigrateOperation): + """Initialize Activity Table Triggers""" + + def __init__(self, use_statement_level_triggers, schema=None): + self.schema = 
schema + self.use_statement_level_triggers = use_statement_level_triggers + + @classmethod + def init_activity_table_triggers(cls, operations, use_statement_level_triggers, **kwargs): + op = InitActivityTableTriggersOp(use_statement_level_triggers, **kwargs) + return operations.invoke(op) + + def reverse(self): + # only needed to support autogenerate + return RemoveActivityTableTriggersOp(schema=self.schema) + +@Operations.register_operation("remove_activity_table_triggers") +class RemoveActivityTableTriggersOp(MigrateOperation): + """Drop Activity Table Triggers""" + + def __init__(self, use_statement_level_triggers, schema=None): + self.schema = schema + self.use_statement_level_triggers = use_statement_level_triggers + + + @classmethod + def remove_activity_table_triggers(cls, operations, use_statement_level_triggers, **kwargs): + op = RemoveActivityTableTriggersOp(use_statement_level_triggers, **kwargs) + return operations.invoke(op) + + def reverse(self): + # only needed to support autogenerate + return InitActivityTableTriggersOp(self.use_statement_level_triggers, schema=self.schema) + + +@Operations.implementation_for(InitActivityTableTriggersOp) +def init_activity_table_triggers(operations, operation): + conn = operations.connection + + if operation.schema: + conn.execute(render_tmpl('create_schema.sql', operation.schema)) + + conn.execute(render_tmpl('jsonb_change_key_name.sql', operation.schema)) + create_audit_table(None, conn, operation.schema, operation.use_statement_level_triggers) + create_operators(None, conn, operation.schema) + + +@Operations.implementation_for(RemoveActivityTableTriggersOp) +def remove_activity_table_triggers(operations, operation): + conn = operations.connection + bind = conn.bind + + if operation.schema: + conn.execute(render_tmpl('drop_schema.sql', operation.schema)) + + conn.execute("DROP FUNCTION jsonb_change_key_name(data jsonb, old_key text, new_key text)") + schema_prefix = f"{operation.schema}." 
if operation.schema else "" + + conn.execute(f"DROP FUNCTION {schema_prefix}audit_table(target_table regclass, ignored_cols text[])") + conn.execute(f"DROP FUNCTION {schema_prefix}create_activity()") + + + if bind.dialect.server_version_info < (9, 5, 0): + conn.execute(f"""DROP FUNCTION jsonb_subtract(jsonb, TEXT)""") + conn.execute(f"""DROP OPERATOR IF EXISTS - (jsonb, text);""") + conn.execute(f"""DROP FUNCTION jsonb_merge(jsonb, jsonb)""") + conn.execute(f"""DROP OPERATOR IF EXISTS || (jsonb, jsonb);""") + if bind.dialect.server_version_info < (9, 6, 0): + conn.execute(f"""DROP FUNCTION current_setting(TEXT, BOOL)""") + if bind.dialect.server_version_info < (10, 0): + conn.execute(f"""DROP FUNCTION jsonb_subtract(jsonb, TEXT[])""") + conn.execute(f"""DROP OPERATOR IF EXISTS - (jsonb, text[])""") + + conn.execute(f"""DROP FUNCTION jsonb_subtract(jsonb,jsonb)""") + conn.execute(f"""DROP OPERATOR IF EXISTS - (jsonb, jsonb)""") + conn.execute(f"""DROP FUNCTION get_setting(text, text)""") + + +@renderers.dispatch_for(InitActivityTableTriggersOp) +def render_init_activity_table_triggers(autogen_context, op): + return "op.init_activity_table_triggers(%r, **%r)" % ( + op.use_statement_level_triggers, + {"schema": op.schema} + ) + +@renderers.dispatch_for(RemoveActivityTableTriggersOp) +def render_remove_activity_table_triggers(autogen_context, op): + return "op.remove_activity_table_triggers(**%r)" % ( + {"schema": op.schema} + ) diff --git a/postgresql_audit/alembic/migration_ops.py b/postgresql_audit/alembic/migration_ops.py new file mode 100644 index 0000000..23d4b78 --- /dev/null +++ b/postgresql_audit/alembic/migration_ops.py @@ -0,0 +1,71 @@ +from alembic.autogenerate import renderers +from alembic.operations import Operations, MigrateOperation + +from postgresql_audit import add_column, remove_column + + +@Operations.register_operation("add_column_to_activity") +class AddColumnToActivityOp(MigrateOperation): + """Initialize Activity Table Triggers""" + + def 
__init__(self, table_name, column_name, default_value=None, schema=None): + self.schema = schema + self.table_name = table_name + self.column_name = column_name + self.default_value = default_value + + @classmethod + def add_column_to_activity(cls, operations, table_name, column_name, **kwargs): + op = AddColumnToActivityOp(table_name, column_name, **kwargs) + return operations.invoke(op) + + def reverse(self): + # only needed to support autogenerate + return RemoveColumnFromRemoveActivityOp(self.table_name, self.column_name, default_value=self.default_value, schema=self.schema) + +@Operations.register_operation("remove_column_from_activity") +class RemoveColumnFromRemoveActivityOp(MigrateOperation): + """Drop Activity Table Triggers""" + + def __init__(self, table_name, column_name, default_value=None, schema=None): + self.schema = schema + self.table_name = table_name + self.column_name = column_name + self.default_value = default_value + + @classmethod + def remove_column_from_activity(cls, operations, table_name, column_name, **kwargs): + op = RemoveColumnFromRemoveActivityOp(table_name, column_name, **kwargs) + return operations.invoke(op) + + def reverse(self): + # only needed to support autogenerate + return AddColumnToActivityOp(self.table_name, self.column_name, default_value=self.default_value, schema=self.schema) + + +@Operations.implementation_for(AddColumnToActivityOp) +def add_column_to_activity(operations, operation): + conn = operations.connection + add_column(conn, operation.table_name, operation.column_name, default_value=operation.default_value, schema=operation.schema) + + +@Operations.implementation_for(RemoveColumnFromRemoveActivityOp) +def remove_column_from_activity(operations, operation): + conn = operations.connection + remove_column(conn, operation.table_name, operation.column_name, operation.schema) + +@renderers.dispatch_for(AddColumnToActivityOp) +def render_add_column_to_activity(autogen_context, op): + return 
"op.add_column_to_activity(%r, %r, **%r)" % ( + op.table_name, + op.column_name, + {"schema": op.schema, "default_value": op.default_value} + ) + +@renderers.dispatch_for(RemoveColumnFromRemoveActivityOp) +def render_remove_column_from_activitys(autogen_context, op): + return "op.remove_column_from_activity(%r, %r, **%r)" % ( + op.table_name, + op.column_name, + {"schema": op.schema} + ) diff --git a/postgresql_audit/alembic/register_table_for_version_tracking.py b/postgresql_audit/alembic/register_table_for_version_tracking.py new file mode 100644 index 0000000..dd6a944 --- /dev/null +++ b/postgresql_audit/alembic/register_table_for_version_tracking.py @@ -0,0 +1,75 @@ +import sqlalchemy as sa + +from alembic.autogenerate import renderers +from alembic.operations import Operations, MigrateOperation + + +@Operations.register_operation("register_for_version_tracking") +class RegisterTableForVersionTrackingOp(MigrateOperation): + """Register Table for Version Tracking""" + + def __init__(self, tablename, excluded_columns, original_excluded_columns=None, schema=None): + self.schema = schema + self.tablename = tablename + self.excluded_columns = excluded_columns + self.original_excluded_columns = original_excluded_columns + + @classmethod + def register_for_version_tracking(cls, operations, tablename, exclude_columns, **kwargs): + op = RegisterTableForVersionTrackingOp(tablename, exclude_columns, **kwargs) + return operations.invoke(op) + + def reverse(self): + # only needed to support autogenerate + return DeregisterTableForVersionTrackingOp(self.tablename, self.original_excluded_columns, schema=self.schema) + +@Operations.register_operation("deregister_for_version_tracking") +class DeregisterTableForVersionTrackingOp(MigrateOperation): + """Drop Table from Version Tracking""" + + def __init__(self, tablename, excluded_columns, schema=None): + self.schema = schema + self.tablename = tablename + self.excluded_columns = excluded_columns + + + @classmethod + def 
deregister_for_version_tracking(cls, operations, tablename, **kwargs): + op = DeregisterTableForVersionTrackingOp(tablename, (), **kwargs) + return operations.invoke(op) + + def reverse(self): + # only needed to support autogenerate + return RegisterTableForVersionTrackingOp(self.tablename, self.excluded_columns, (), schema=self.schema) + + +@Operations.implementation_for(RegisterTableForVersionTrackingOp) +def register_for_version_tracking(operations, operation): + if operation.schema is None: + func = sa.func.audit_table + else: + func = getattr(getattr(sa.func, operation.schema), 'audit_table') + operations.execute(sa.select([func(operation.tablename, list(operation.excluded_columns))])) + + +@Operations.implementation_for(DeregisterTableForVersionTrackingOp) +def deregister_for_version_tracking(operations, operation): + operations.execute(f"drop trigger audit_trigger_insert on {operation.tablename} ") + operations.execute(f"drop trigger audit_trigger_update on {operation.tablename} ") + operations.execute(f"drop trigger audit_trigger_delete on {operation.tablename} ") + + +@renderers.dispatch_for(RegisterTableForVersionTrackingOp) +def render_register_for_version_tracking(autogen_context, op): + return "op.register_for_version_tracking(%r, %r, **%r)" % ( + op.tablename, + op.excluded_columns, + {"schema": op.schema} + ) + +@renderers.dispatch_for(DeregisterTableForVersionTrackingOp) +def render_deregister_for_version_tracking(autogen_context, op): + return "op.deregister_for_version_tracking(%r, **%r)" % ( + op.tablename, + {"schema": op.schema} + ) diff --git a/postgresql_audit/base.py b/postgresql_audit/base.py index 3ca069f..a061438 100644 --- a/postgresql_audit/base.py +++ b/postgresql_audit/base.py @@ -207,9 +207,9 @@ def modified_columns(self, obj): def is_modified(self, obj_or_session): if hasattr(obj_or_session, '__mapper__'): - if not hasattr(obj_or_session, '__versioned__'): + if not (hasattr(obj_or_session, '__versioned__') or getattr(obj_or_session, 
'__table_args__', {}).get("versioned", None)): raise ClassNotVersioned(obj_or_session.__class__.__name__) - excluded = obj_or_session.__versioned__.get('exclude', []) + excluded = getattr(obj_or_session, "__versioned__", obj_or_session.__table_args__["versioned"]).get('exclude', []) return bool( set([ column.name @@ -220,7 +220,7 @@ def is_modified(self, obj_or_session): return any( self.is_modified(entity) or entity in obj_or_session.deleted for entity in obj_or_session - if hasattr(entity, '__versioned__') + if hasattr(entity, '__versioned__') or getattr(obj_or_session, '__table_args__', {}).get("versioned", None) ) def before_flush(self, session, flush_context, instances): diff --git a/postgresql_audit/migrations.py b/postgresql_audit/migrations.py index 56adab6..6fab915 100644 --- a/postgresql_audit/migrations.py +++ b/postgresql_audit/migrations.py @@ -1,8 +1,6 @@ import sqlalchemy as sa from sqlalchemy.dialects.postgresql import JSONB -from postgresql_audit.utils import render_tmpl, create_audit_table, create_operators, \ - build_register_table_query from .expressions import jsonb_change_key_name @@ -18,27 +16,6 @@ def get_activity_table(schema=None): schema=schema, ) -def init_activity_table_triggers(conn, schema_name = None, use_statement_level_triggers=True): - conn.execute(render_tmpl('jsonb_change_key_name.sql', schema_name)) - create_audit_table(None, conn, schema_name, use_statement_level_triggers) - create_operators(None, conn, schema_name) - - if schema_name: - conn.execute(render_tmpl('create_schema.sql', schema_name)) - -def rollback_create_transaction(conn, schema_name=None): - if schema_name: - conn.execute(render_tmpl('drop_schema.sql', schema_name)) - -def init_before_create_transaction(conn, schema_name=None): - if schema_name: - conn.execute(render_tmpl('create_schema.sql', schema_name)) - - -def register_table(conn, table_name, exclude_columns, schema_name=None): - sql = build_register_table_query(schema_name, table_name, exclude_columns) - 
conn.execute(sql) - def alter_column(conn, table, column_name, func, schema=None): """ From 708461717a7438925ae87b3f586131d2bb36399f Mon Sep 17 00:00:00 2001 From: Cameron Hurst Date: Sun, 5 Jan 2020 16:30:59 -0500 Subject: [PATCH 4/7] fix: cleanup some lingering issues --- postgresql_audit/alembic/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/postgresql_audit/alembic/__init__.py b/postgresql_audit/alembic/__init__.py index 0314043..bb1d84a 100644 --- a/postgresql_audit/alembic/__init__.py +++ b/postgresql_audit/alembic/__init__.py @@ -36,7 +36,7 @@ def compare_timestamp_schema(autogen_context, upgrade_ops, schemas): else: if f"{schema_prefix}audit_table" in (x[1] for x in a): upgrade_ops.ops.append( - RemoveActivityTableTriggersOp(schema=sch) + RemoveActivityTableTriggersOp(False, schema=sch) ) @@ -45,7 +45,6 @@ def compare_timestamp_table(autogen_context, modify_ops, schemaname, tablename, if metadata_table is None: return meta_info = metadata_table.info or {} - # TODO: Query triggers on the table schema_name = autogen_context.dialect.default_schema_name if schemaname is None else schemaname triggers = [row for row in autogen_context.connection.execute(f""" From d6cf220be12545f9878f1eaa860dded39937c9bb Mon Sep 17 00:00:00 2001 From: Cameron Hurst Date: Fri, 17 Jan 2020 07:37:52 -0700 Subject: [PATCH 5/7] fix: minor fixes to variables to make code executable --- postgresql_audit/alembic/__init__.py | 2 +- .../alembic/init_activity_table_triggers.py | 19 ++++++++++--------- .../register_table_for_version_tracking.py | 7 ++++--- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/postgresql_audit/alembic/__init__.py b/postgresql_audit/alembic/__init__.py index bb1d84a..e0509da 100644 --- a/postgresql_audit/alembic/__init__.py +++ b/postgresql_audit/alembic/__init__.py @@ -30,7 +30,7 @@ def compare_timestamp_schema(autogen_context, upgrade_ops, schemas): a = list(a) if a else [] if should_track_versions: if 
f"{schema_prefix}audit_table" not in (x[1] for x in a): - upgrade_ops.ops.append( + upgrade_ops.ops.insert(0, InitActivityTableTriggersOp(False, schema=sch) ) else: diff --git a/postgresql_audit/alembic/init_activity_table_triggers.py b/postgresql_audit/alembic/init_activity_table_triggers.py index c7d01c8..a40364a 100644 --- a/postgresql_audit/alembic/init_activity_table_triggers.py +++ b/postgresql_audit/alembic/init_activity_table_triggers.py @@ -19,7 +19,7 @@ def init_activity_table_triggers(cls, operations, use_statement_level_triggers, def reverse(self): # only needed to support autogenerate - return RemoveActivityTableTriggersOp(schema=self.schema) + return RemoveActivityTableTriggersOp(self.use_statement_level_triggers, schema=self.schema) @Operations.register_operation("remove_activity_table_triggers") class RemoveActivityTableTriggersOp(MigrateOperation): @@ -31,8 +31,8 @@ def __init__(self, use_statement_level_triggers, schema=None): @classmethod - def remove_activity_table_triggers(cls, operations, use_statement_level_triggers, **kwargs): - op = RemoveActivityTableTriggersOp(use_statement_level_triggers, **kwargs) + def remove_activity_table_triggers(cls, operations, **kwargs): + op = RemoveActivityTableTriggersOp(False, **kwargs) return operations.invoke(op) def reverse(self): @@ -42,20 +42,21 @@ def reverse(self): @Operations.implementation_for(InitActivityTableTriggersOp) def init_activity_table_triggers(operations, operation): - conn = operations.connection + conn = operations + bind = conn.get_bind() if operation.schema: conn.execute(render_tmpl('create_schema.sql', operation.schema)) conn.execute(render_tmpl('jsonb_change_key_name.sql', operation.schema)) - create_audit_table(None, conn, operation.schema, operation.use_statement_level_triggers) - create_operators(None, conn, operation.schema) + create_audit_table(None, bind, operation.schema, operation.use_statement_level_triggers) + create_operators(None, bind, operation.schema) 
@Operations.implementation_for(RemoveActivityTableTriggersOp) def remove_activity_table_triggers(operations, operation): - conn = operations.connection - bind = conn.bind + conn = operations + bind = conn.get_bind() if operation.schema: conn.execute(render_tmpl('drop_schema.sql', operation.schema)) @@ -78,9 +79,9 @@ def remove_activity_table_triggers(operations, operation): conn.execute(f"""DROP FUNCTION jsonb_subtract(jsonb, TEXT[])""") conn.execute(f"""DROP OPERATOR IF EXISTS - (jsonb, text[])""") - conn.execute(f"""DROP FUNCTION jsonb_subtract(jsonb,jsonb)""") conn.execute(f"""DROP OPERATOR IF EXISTS - (jsonb, jsonb)""") conn.execute(f"""DROP FUNCTION get_setting(text, text)""") + conn.execute(f"""DROP FUNCTION jsonb_subtract(jsonb,jsonb)""") @renderers.dispatch_for(InitActivityTableTriggersOp) diff --git a/postgresql_audit/alembic/register_table_for_version_tracking.py b/postgresql_audit/alembic/register_table_for_version_tracking.py index dd6a944..70a158a 100644 --- a/postgresql_audit/alembic/register_table_for_version_tracking.py +++ b/postgresql_audit/alembic/register_table_for_version_tracking.py @@ -54,9 +54,10 @@ def register_for_version_tracking(operations, operation): @Operations.implementation_for(DeregisterTableForVersionTrackingOp) def deregister_for_version_tracking(operations, operation): - operations.execute(f"drop trigger audit_trigger_insert on {operation.tablename} ") - operations.execute(f"drop trigger audit_trigger_update on {operation.tablename} ") - operations.execute(f"drop trigger audit_trigger_delete on {operation.tablename} ") + operations.execute(f"drop trigger if exists audit_trigger_insert on {operation.tablename} ") + operations.execute(f"drop trigger if exists audit_trigger_update on {operation.tablename} ") + operations.execute(f"drop trigger if exists audit_trigger_delete on {operation.tablename} ") + operations.execute(f"drop trigger if exists audit_trigger_row on {operation.tablename} ") 
@renderers.dispatch_for(RegisterTableForVersionTrackingOp) From f768855eca59585fc0404fbd5991ea3940cbe841 Mon Sep 17 00:00:00 2001 From: Cameron Hurst Date: Fri, 7 Feb 2020 13:28:20 -0500 Subject: [PATCH 6/7] fix: using wrong ops --- postgresql_audit/alembic/migration_ops.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/postgresql_audit/alembic/migration_ops.py b/postgresql_audit/alembic/migration_ops.py index 23d4b78..f22d091 100644 --- a/postgresql_audit/alembic/migration_ops.py +++ b/postgresql_audit/alembic/migration_ops.py @@ -45,8 +45,7 @@ def reverse(self): @Operations.implementation_for(AddColumnToActivityOp) def add_column_to_activity(operations, operation): - conn = operations.connection - add_column(conn, operation.table_name, operation.column_name, default_value=operation.default_value, schema=operation.schema) + add_column(operations, operation.table_name, operation.column_name, default_value=operation.default_value, schema=operation.schema) @Operations.implementation_for(RemoveColumnFromRemoveActivityOp) From 05ecb1f48c5a6450d1cab43e52e83fae9ae58009 Mon Sep 17 00:00:00 2001 From: Cameron Hurst Date: Sun, 26 Apr 2020 17:55:25 -0400 Subject: [PATCH 7/7] fix: getting versioned info works properly now with __table_args__ --- postgresql_audit/base.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/postgresql_audit/base.py b/postgresql_audit/base.py index a061438..d69715d 100644 --- a/postgresql_audit/base.py +++ b/postgresql_audit/base.py @@ -1,4 +1,5 @@ import warnings +from collections import Sequence from contextlib import contextmanager from functools import partial from weakref import WeakSet @@ -207,9 +208,10 @@ def modified_columns(self, obj): def is_modified(self, obj_or_session): if hasattr(obj_or_session, '__mapper__'): - if not (hasattr(obj_or_session, '__versioned__') or getattr(obj_or_session, '__table_args__', {}).get("versioned", None)): + version_info = 
self.__get_versioned_info(obj_or_session) + if not version_info: raise ClassNotVersioned(obj_or_session.__class__.__name__) - excluded = getattr(obj_or_session, "__versioned__", obj_or_session.__table_args__["versioned"]).get('exclude', []) + excluded = version_info.get('exclude', []) return bool( set([ column.name @@ -220,9 +222,22 @@ def is_modified(self, obj_or_session): return any( self.is_modified(entity) or entity in obj_or_session.deleted for entity in obj_or_session - if hasattr(entity, '__versioned__') or getattr(obj_or_session, '__table_args__', {}).get("versioned", None) + if self.__get_versioned_info(entity) ) + def __get_versioned_info(self, entity): + v_args = getattr(entity, '__versioned__', None) + if v_args: + return v_args + table_args = getattr(entity, '__table_args__', None) + if not table_args: + return None + if isinstance(table_args, Sequence): + table_args = next((x for x in iter(table_args) if isinstance(x, dict)), None) + if not table_args: + return None + return table_args.get("info", {}).get("versioned", None) + def before_flush(self, session, flush_context, instances): if session.transaction in self._marked_transactions: return