Use ti key lkp table for pre-migration dangling checks #23494

Closed · wants to merge 2 commits
122 changes: 79 additions & 43 deletions airflow/utils/db.py
@@ -26,7 +26,22 @@
from tempfile import gettempdir
from typing import TYPE_CHECKING, Callable, Iterable, List, Optional, Tuple, Union

from sqlalchemy import Table, and_, column, exc, func, inspect, or_, select, table, text, tuple_
import sqlalchemy.exc
from sqlalchemy import (
Column,
String,
Table,
and_,
column,
exc,
func,
inspect,
or_,
select,
table,
text,
tuple_,
)
from sqlalchemy.orm.session import Session

import airflow
@@ -37,6 +52,7 @@
from airflow.jobs.base_job import BaseJob # noqa: F401
from airflow.models import ( # noqa: F401
DAG,
ID_LEN,
XCOM_RETURN_KEY,
Base,
BaseOperator,
@@ -59,12 +75,14 @@
)

# We need to add this model manually to get reset working well
from airflow.models.base import COLLATION_ARGS
from airflow.models.serialized_dag import SerializedDagModel # noqa: F401
from airflow.models.tasklog import LogTemplate
from airflow.utils import helpers

# TODO: remove create_session once we decide to break backward compatibility
from airflow.utils.session import NEW_SESSION, create_session, provide_session # noqa: F401
from airflow.utils.sqlalchemy import UtcDateTime
from airflow.version import version

if TYPE_CHECKING:
@@ -1063,7 +1081,6 @@ def _move_dangling_data_to_new_table(
session=session,
)
session.commit()

target_table = source_table.to_metadata(source_table.metadata, name=target_table_name)
log.debug("checking whether rows were moved for table %s", target_table_name)
moved_rows_exist_query = select([1]).select_from(target_table).limit(1)
@@ -1075,6 +1092,7 @@
# no bad rows were found; drop moved rows table.
target_table.drop(bind=session.get_bind(), checkfirst=True)
else:
# purge the bad rows
log.debug("rows moved; purging from %s", source_table.name)
if dialect_name == 'sqlite':
pk_cols = source_table.primary_key.columns
@@ -1088,6 +1106,7 @@
)
log.debug(delete.compile())
session.execute(delete)

session.commit()

log.debug("exiting move function")
@@ -1109,7 +1128,7 @@ def _dangling_against_dag_run(session, source_table, dag_run):
)


def _dangling_against_task_instance(session, source_table, dag_run, task_instance):
def _dangling_against_task_instance(session, source_table, ti_lkp_table):
"""
Given a source table, we generate a subquery that will return 1 for every row that
has a valid task instance (and associated dagrun).
@@ -1120,35 +1139,13 @@ def _dangling_against_task_instance(session, source_table, dag_run, task_instance):
query logic depending on which revision the database is at.

"""
if 'run_id' not in task_instance.c:
# db is < 2.2.0
dr_join_cond = and_(
source_table.c.dag_id == dag_run.c.dag_id,
source_table.c.execution_date == dag_run.c.execution_date,
)
ti_join_cond = and_(
dag_run.c.dag_id == task_instance.c.dag_id,
dag_run.c.execution_date == task_instance.c.execution_date,
source_table.c.task_id == task_instance.c.task_id,
)
else:
# db is 2.2.0 <= version < 2.3.0
dr_join_cond = and_(
source_table.c.dag_id == dag_run.c.dag_id,
source_table.c.execution_date == dag_run.c.execution_date,
)
ti_join_cond = and_(
dag_run.c.dag_id == task_instance.c.dag_id,
dag_run.c.run_id == task_instance.c.run_id,
source_table.c.task_id == task_instance.c.task_id,
)

return (
session.query(*[c.label(c.name) for c in source_table.c])
.join(dag_run, dr_join_cond, isouter=True)
.join(task_instance, ti_join_cond, isouter=True)
.filter(or_(task_instance.c.dag_id.is_(None), dag_run.c.dag_id.is_(None)))
where_clause = and_(
source_table.c.dag_id == ti_lkp_table.c.dag_id,
source_table.c.task_id == ti_lkp_table.c.task_id,
source_table.c.execution_date == ti_lkp_table.c.execution_date,
)
exists_subquery = session.query(text('1')).select_from(ti_lkp_table).filter(where_clause)
return exists_subquery
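After this change the helper no longer needs to know the schema version; it just builds a correlated subquery against the lookup table, and the caller negates it with NOT EXISTS (see check_bad_references below). A self-contained sketch of the query shape it produces, using toy stand-in tables (assuming SQLAlchemy 1.4; all names here are illustrative):

    from sqlalchemy import Column, DateTime, MetaData, String, Table, and_, create_engine, text
    from sqlalchemy.orm import Session

    engine = create_engine("sqlite://")
    md = MetaData()
    src = Table(
        "task_fail", md,
        Column("task_id", String(250)),
        Column("dag_id", String(250)),
        Column("execution_date", DateTime),
    )
    lkp = Table(
        "ti_lkp", md,
        Column("task_id", String(250)),
        Column("dag_id", String(250)),
        Column("execution_date", DateTime),
    )
    md.create_all(engine)

    session = Session(engine)
    where_clause = and_(
        src.c.dag_id == lkp.c.dag_id,
        src.c.task_id == lkp.c.task_id,
        src.c.execution_date == lkp.c.execution_date,
    )
    subq = session.query(text("1")).select_from(lkp).filter(where_clause)
    # dangling rows are those with no matching task-instance key
    dangling = session.query(*[c.label(c.name) for c in src.c]).filter(~subq.exists())
    print(dangling.statement.compile(engine))  # SELECT ... WHERE NOT EXISTS (SELECT 1 ...)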


def _move_duplicate_data_to_new_table(
@@ -1203,6 +1200,43 @@ def _move_duplicate_data_to_new_table(
session.execute(delete)


def _create_ti_key_lkp_table(session, table_name) -> Table:
"""Creates lkp table for all valid TI keys"""
tmp_table = Table(
table_name,
Base.metadata,
Column('task_id', String(ID_LEN, **COLLATION_ARGS), primary_key=True),
Column('dag_id', String(ID_LEN, **COLLATION_ARGS), primary_key=True),
Column('execution_date', UtcDateTime, primary_key=True),
)
tmp_table.drop(bind=settings.engine, checkfirst=True)

log.debug("creating TI key lkp table")
Base.metadata.create_all(settings.engine, tables=[tmp_table])
log.debug("inserting TI key lkp table")
session.commit()
try:
# post 2.2
session.execute(
f"insert into {table_name} "
"select ti.task_id, ti.dag_id, dr.execution_date "
"from task_instance ti "
"join dag_run dr on dr.dag_id = ti.dag_id "
" and dr.run_id = ti.run_id "
)
except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError):
# pre-2.2: task_instance has no run_id column, so the insert above fails
# (OperationalError on MySQL/SQLite, ProgrammingError on Postgres); roll
# back the failed statement before retrying with the execution_date join
session.rollback()
session.execute(
f"insert into {table_name} "
"select ti.task_id, ti.dag_id, dr.execution_date "
"from task_instance ti "
"join dag_run dr on dr.dag_id = ti.dag_id "
" and dr.execution_date = ti.execution_date "
" and dr.run_id is not null"
)
return tmp_table
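The try/except above is schema-version detection by probing: on a 2.2+ database the run_id join succeeds, while on an older database task_instance has no run_id column, so the first INSERT fails and the execution_date join runs instead. Note that the error class for a missing column varies by backend (OperationalError on MySQL and SQLite, ProgrammingError on Postgres), which is why both are caught above. A toy demonstration of the same probe-by-exception pattern on SQLite (assuming SQLAlchemy 1.4):

    import sqlalchemy.exc
    from sqlalchemy import create_engine, text

    engine = create_engine("sqlite://")
    with engine.connect() as conn:
        conn.execute(text("CREATE TABLE ti (task_id TEXT, execution_date TEXT)"))
        try:
            # "new schema" probe: fails because ti.run_id does not exist
            conn.execute(text("SELECT run_id FROM ti"))
        except sqlalchemy.exc.OperationalError:
            # fall back to the "old schema" query
            conn.execute(text("SELECT execution_date FROM ti"))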


def check_bad_references(session: Session) -> Iterable[str]:
"""
Starting in Airflow 2.2, we began a process of replacing `execution_date` with `run_id`
@@ -1255,8 +1289,10 @@ class BadReferenceConfig:
return

existing_table_names = set(inspect(session.get_bind()).get_table_names())
errored = False

ti_lkp_table = _create_ti_key_lkp_table(session=session, table_name='_airflow_tmp_ti_key_lkp')

session.commit()
for model, change_version, bad_ref_cfg in models_list:
log.debug("checking model %s", model.__tablename__)
# We can't use the model here since it may differ from the db state due to
@@ -1269,12 +1305,15 @@ class BadReferenceConfig:
if "run_id" in source_table.columns:
continue

func_kwargs = {x: metadata.tables[x] for x in bad_ref_cfg.join_tables}
bad_rows_query = bad_ref_cfg.bad_rows_func(session, source_table, **func_kwargs)

session.commit()
bad_rows_subquery = bad_ref_cfg.exists_func(session, source_table, ti_lkp_table=ti_lkp_table)
dangling_table_name = _format_airflow_moved_table_name(source_table.name, change_version, 'dangling')
select_list = [x.label(x.name) for x in source_table.c]
log.debug(bad_rows_subquery.selectable.compile())
invalid_rows_query = session.query(*select_list).filter(~bad_rows_subquery.exists())
session.commit()
if dangling_table_name in existing_table_names:
invalid_row_count = bad_rows_query.count()
invalid_row_count = invalid_rows_query.count()
if invalid_row_count <= 0:
continue
else:
@@ -1284,21 +1323,18 @@
invalid_count=invalid_row_count,
reason=f"without a corresponding {bad_ref_cfg.ref_table} row",
)
errored = True
continue
continue

log.debug("moving data for table %s", source_table.name)
_move_dangling_data_to_new_table(
session,
source_table,
bad_rows_query,
invalid_rows_query,
dangling_table_name,
)

if errored:
session.rollback()
else:
session.commit()
session.commit()
ti_lkp_table.drop(bind=settings.engine, checkfirst=True)
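check_bad_references is a generator of human-readable error strings. A hedged sketch of how a caller might consume it, refusing to migrate while dangling rows remain (the wrapper name below is illustrative, not Airflow's):

    from airflow.utils.session import create_session

    def run_pre_migration_checks() -> None:
        with create_session() as session:
            errors = list(check_bad_references(session))
        if errors:
            raise SystemExit("\n".join(errors))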


@provide_session