Skip to content

Commit

Permalink
Create new databases from the ORM (#24156)
Browse files Browse the repository at this point in the history
This PR opens up to creating new databases from the ORM instead of going through the migration files.

`airflow db init` creates the new db.

Co-authored-by: Ash Berlin-Taylor <ash_github@firemirror.com>
  • Loading branch information
ephraimbuddy and ashb committed Aug 1, 2022
1 parent b90fc14 commit 5588c3f
Show file tree
Hide file tree
Showing 9 changed files with 121 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,10 @@ def upgrade():

def downgrade():
"""Unapply Change default ``pool_slots`` to ``1``"""
conn = op.get_bind()
if conn.dialect.name == 'mssql':
# DB created from ORM doesn't set a server_default here and MSSQL fails while trying to drop
# the non existent server_default. We ignore it for MSSQL
return
with op.batch_alter_table("task_instance", schema=None) as batch_op:
batch_op.alter_column("pool_slots", existing_type=sa.Integer, nullable=True, server_default=None)
23 changes: 20 additions & 3 deletions airflow/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,26 @@

SQL_ALCHEMY_SCHEMA = conf.get("database", "SQL_ALCHEMY_SCHEMA")

metadata = (
None if not SQL_ALCHEMY_SCHEMA or SQL_ALCHEMY_SCHEMA.isspace() else MetaData(schema=SQL_ALCHEMY_SCHEMA)
)
# For more information about what the tokens in the naming convention
# below mean, see:
# https://docs.sqlalchemy.org/en/14/core/metadata.html#sqlalchemy.schema.MetaData.params.naming_convention
naming_convention = {
"ix": "idx_%(column_0_N_label)s",
"uq": "%(table_name)s_%(column_0_N_name)s_uq",
"ck": "ck_%(table_name)s_%(constraint_name)s",
"fk": "%(table_name)s_%(column_0_name)s_fkey",
"pk": "%(table_name)s_pkey",
}


def _get_schema():
if not SQL_ALCHEMY_SCHEMA or SQL_ALCHEMY_SCHEMA.isspace():
return None
return SQL_ALCHEMY_SCHEMA


metadata = MetaData(schema=_get_schema(), naming_convention=naming_convention)

Base: Any = declarative_base(metadata=metadata)

ID_LEN = 250
Expand Down
8 changes: 4 additions & 4 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
from airflow.configuration import conf
from airflow.exceptions import AirflowDagInconsistent, AirflowException, DuplicateTaskIdFound, TaskNotFound
from airflow.models.abstractoperator import AbstractOperator
from airflow.models.base import ID_LEN, Base
from airflow.models.base import Base, StringID
from airflow.models.dagbag import DagBag
from airflow.models.dagcode import DagCode
from airflow.models.dagpickle import DagPickle
Expand Down Expand Up @@ -2788,7 +2788,7 @@ class DagTag(Base):
__tablename__ = "dag_tag"
name = Column(String(TAG_MAX_LEN), primary_key=True)
dag_id = Column(
String(ID_LEN),
StringID(),
ForeignKey('dag.dag_id', name='dag_tag_dag_id_fkey', ondelete='CASCADE'),
primary_key=True,
)
Expand All @@ -2804,8 +2804,8 @@ class DagModel(Base):
"""
These items are stored in the database for state related information
"""
dag_id = Column(String(ID_LEN), primary_key=True)
root_dag_id = Column(String(ID_LEN))
dag_id = Column(StringID(), primary_key=True)
root_dag_id = Column(StringID())
# A DAG can be paused from the UI / DB
# Set this default value of is_paused based on a configuration value!
is_paused_at_creation = conf.getboolean('core', 'dags_are_paused_at_creation')
Expand Down
10 changes: 9 additions & 1 deletion airflow/models/renderedtifields.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from typing import Optional

import sqlalchemy_jsonfield
from sqlalchemy import Column, ForeignKeyConstraint, Integer, and_, not_, text, tuple_
from sqlalchemy import Column, ForeignKeyConstraint, Integer, PrimaryKeyConstraint, and_, not_, text, tuple_
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import Session, relationship

Expand All @@ -46,6 +46,14 @@ class RenderedTaskInstanceFields(Base):
k8s_pod_yaml = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=True)

__table_args__ = (
PrimaryKeyConstraint(
"dag_id",
"task_id",
"run_id",
"map_index",
name='rendered_task_instance_fields_pkey',
mssql_clustered=True,
),
ForeignKeyConstraint(
[dag_id, task_id, run_id, map_index],
[
Expand Down
5 changes: 4 additions & 1 deletion airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
ForeignKeyConstraint,
Index,
Integer,
PrimaryKeyConstraint,
String,
and_,
false,
Expand Down Expand Up @@ -428,7 +429,6 @@ class TaskInstance(Base, LoggingMixin):
"""

__tablename__ = "task_instance"

task_id = Column(StringID(), primary_key=True, nullable=False)
dag_id = Column(StringID(), primary_key=True, nullable=False)
run_id = Column(StringID(), primary_key=True, nullable=False)
Expand Down Expand Up @@ -480,6 +480,9 @@ class TaskInstance(Base, LoggingMixin):
Index('ti_pool', pool, state, priority_weight),
Index('ti_job_id', job_id),
Index('ti_trigger_id', trigger_id),
PrimaryKeyConstraint(
"dag_id", "task_id", "run_id", "map_index", name='task_instance_pkey', mssql_clustered=True
),
ForeignKeyConstraint(
[trigger_id],
['trigger.id'],
Expand Down
14 changes: 13 additions & 1 deletion airflow/models/taskreschedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import datetime
from typing import TYPE_CHECKING

from sqlalchemy import Column, ForeignKeyConstraint, Index, Integer, String, asc, desc, text
from sqlalchemy import Column, ForeignKeyConstraint, Index, Integer, String, asc, desc, event, text
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import relationship

Expand Down Expand Up @@ -134,3 +134,15 @@ def find_for_task_instance(task_instance, session=None, try_number=None):
return TaskReschedule.query_for_task_instance(
task_instance, session=session, try_number=try_number
).all()


@event.listens_for(TaskReschedule.__table__, "before_create")
def add_ondelete_for_mssql(table, conn, **kw):
if conn.dialect.name != "mssql":
return

for constraint in table.constraints:
if constraint.name != "task_reschedule_dr_fkey":
continue
constraint.ondelete = 'NO ACTION'
return
14 changes: 13 additions & 1 deletion airflow/models/xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,16 @@
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Type, Union, cast, overload

import pendulum
from sqlalchemy import Column, ForeignKeyConstraint, Index, Integer, LargeBinary, String, text
from sqlalchemy import (
Column,
ForeignKeyConstraint,
Index,
Integer,
LargeBinary,
PrimaryKeyConstraint,
String,
text,
)
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import Query, Session, reconstructor, relationship
from sqlalchemy.orm.exc import NoResultFound
Expand Down Expand Up @@ -73,6 +82,9 @@ class BaseXCom(Base, LoggingMixin):
# separately, and enforce uniqueness with DagRun.id instead.
Index("idx_xcom_key", key),
Index("idx_xcom_task_instance", dag_id, task_id, run_id, map_index),
PrimaryKeyConstraint(
"dag_run_id", "task_id", "map_index", "key", name="xcom_pkey", mssql_clustered=True
),
ForeignKeyConstraint(
[dag_id, task_id, run_id, map_index],
[
Expand Down
52 changes: 43 additions & 9 deletions airflow/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,19 +661,45 @@ def create_default_connections(session: Session = NEW_SESSION):
)


@provide_session
def initdb(session: Session = NEW_SESSION):
"""Initialize Airflow database."""
upgradedb(session=session)
def _create_db_from_orm(session):
from alembic import command
from flask import Flask
from flask_sqlalchemy import SQLAlchemy

if conf.getboolean('database', 'LOAD_DEFAULT_CONNECTIONS'):
create_default_connections(session=session)
from airflow.models import Base
from airflow.www.fab_security.sqla.models import Model
from airflow.www.session import AirflowDatabaseSessionInterface

def _create_flask_session_tbl():
flask_app = Flask(__name__)
flask_app.config['SQLALCHEMY_DATABASE_URI'] = conf.get('database', 'SQL_ALCHEMY_CONN')
db = SQLAlchemy(flask_app)
AirflowDatabaseSessionInterface(app=flask_app, db=db, table='session', key_prefix='')
db.create_all()

with create_global_lock(session=session, lock=DBLocks.MIGRATIONS):
Base.metadata.create_all(settings.engine)
Model.metadata.create_all(settings.engine)
_create_flask_session_tbl()
# stamp the migration head
config = _get_alembic_config()
command.stamp(config, "head")

from flask_appbuilder.models.sqla import Base

Base.metadata.create_all(settings.engine)
@provide_session
def initdb(session: Session = NEW_SESSION, load_connections: bool = True):
"""Initialize Airflow database."""
db_exists = _get_current_revision(session)
if db_exists:
upgradedb(session=session)
else:
_create_db_from_orm(session=session)
# Load default connections
if conf.getboolean('database', 'LOAD_DEFAULT_CONNECTIONS') and load_connections:
create_default_connections(session=session)
# Add default pool & sync log_template
add_default_pool_if_not_exists()
synchronize_log_template()


def _get_alembic_config():
Expand Down Expand Up @@ -1487,6 +1513,11 @@ def upgradedb(
if errors_seen:
exit(1)

if not to_revision and not _get_current_revision(session=session):
# Don't load default connections
# New DB; initialize and exit
initdb(session=session, load_connections=False)
return
with create_global_lock(session=session, lock=DBLocks.MIGRATIONS):
log.info("Creating tables")
command.upgrade(config, revision=to_revision or 'heads')
Expand Down Expand Up @@ -1699,7 +1730,10 @@ def compare_type(context, inspected_column, metadata_column, inspected_type, met

if isinstance(inspected_type, mysql.VARCHAR) and isinstance(metadata_type, String):
# This is a hack to get around MySQL VARCHAR collation
# not being possible to change from utf8_bin to utf8mb3_bin
# not being possible to change from utf8_bin to utf8mb3_bin.
# We only make sure lengths are the same
if inspected_type.length != metadata_type.length:
return True
return False
return None

Expand Down
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,20 @@ def pytest_addoption(parser):


def initial_db_init():
from flask import Flask

from airflow.configuration import conf
from airflow.utils import db
from airflow.www.app import sync_appbuilder_roles
from airflow.www.extensions.init_appbuilder import init_appbuilder

db.resetdb()
db.bootstrap_dagbag()
# minimal app to add roles
flask_app = Flask(__name__)
flask_app.config['SQLALCHEMY_DATABASE_URI'] = conf.get('database', 'SQL_ALCHEMY_CONN')
init_appbuilder(flask_app)
sync_appbuilder_roles(flask_app)


@pytest.fixture(autouse=True, scope="session")
Expand Down

0 comments on commit 5588c3f

Please sign in to comment.