Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sync asset and alias property changes into db #47090

Merged
merged 1 commit into from
Feb 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions airflow/dag_processing/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ def collect(cls, dags: dict[str, MaybeSerializedDAG]) -> Self:
)
return coll

def add_assets(self, *, session: Session) -> dict[tuple[str, str], AssetModel]:
def sync_assets(self, *, session: Session) -> dict[tuple[str, str], AssetModel]:
# Optimization: skip all database calls if no assets were collected.
if not self.assets:
return {}
Expand All @@ -593,6 +593,10 @@ def add_assets(self, *, session: Session) -> dict[tuple[str, str], AssetModel]:
select(AssetModel).where(tuple_(AssetModel.name, AssetModel.uri).in_(self.assets))
)
}
for key, model in orm_assets.items():
asset = self.assets[key]
model.group = asset.group
model.extra = asset.extra
orm_assets.update(
((model.name, model.uri), model)
for model in asset_manager.create_assets(
Expand All @@ -602,7 +606,7 @@ def add_assets(self, *, session: Session) -> dict[tuple[str, str], AssetModel]:
)
return orm_assets

def add_asset_aliases(self, *, session: Session) -> dict[str, AssetAliasModel]:
def sync_asset_aliases(self, *, session: Session) -> dict[str, AssetAliasModel]:
# Optimization: skip all database calls if no asset aliases were collected.
if not self.asset_aliases:
return {}
Expand All @@ -612,6 +616,8 @@ def add_asset_aliases(self, *, session: Session) -> dict[str, AssetAliasModel]:
select(AssetAliasModel).where(AssetAliasModel.name.in_(self.asset_aliases))
)
}
for name, model in orm_aliases.items():
model.group = self.asset_aliases[name].group
orm_aliases.update(
(model.name, model)
for model in asset_manager.create_asset_aliases(
Expand Down
4 changes: 2 additions & 2 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -1876,8 +1876,8 @@ def bulk_write_to_db(

asset_op = AssetModelOperation.collect(dag_op.dags)

orm_assets = asset_op.add_assets(session=session)
orm_asset_aliases = asset_op.add_asset_aliases(session=session)
orm_assets = asset_op.sync_assets(session=session)
orm_asset_aliases = asset_op.sync_asset_aliases(session=session)
session.flush() # This populates id so we can create fks in later calls.

orm_dags = dag_op.find_orm_dags(session=session) # Refetch so relationship is up to date.
Expand Down
55 changes: 53 additions & 2 deletions tests/dag_processing/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
from airflow.models.serialized_dag import SerializedDagModel
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.providers.standard.triggers.temporal import TimeDeltaTrigger
from airflow.sdk.definitions.asset import Asset, AssetWatcher
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetWatcher
from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG
from airflow.utils import timezone as tz
from airflow.utils.session import create_session
Expand Down Expand Up @@ -149,7 +149,7 @@ def test_add_asset_trigger_references(self, is_active, is_paused, expected_num_t
dag.is_active = is_active
dag.is_paused = is_paused

orm_assets = asset_op.add_assets(session=session)
orm_assets = asset_op.sync_assets(session=session)
# Create AssetActive objects from assets. It is usually done in the scheduler
for asset in orm_assets.values():
session.add(AssetActive.for_asset(asset))
Expand All @@ -162,6 +162,57 @@ def test_add_asset_trigger_references(self, is_active, is_paused, expected_num_t
assert session.query(Trigger).count() == expected_num_triggers
assert session.query(asset_trigger_association_table).count() == expected_num_triggers

def test_change_asset_property_sync_group(self, dag_maker, session):
    """Check that a second ``sync_assets`` call reflects a changed asset group.

    The asset is first synced with ``group="old_group"``; its ``group`` is then
    changed on the definition object and the DAG is collected and synced again.
    """
    tracked = Asset("myasset", group="old_group")
    with dag_maker(schedule=[tracked]) as dag:
        EmptyOperator(task_id="mytask")

    # First pass: the model should carry the group the asset was declared with.
    first_pass = AssetModelOperation.collect({dag.dag_id: dag}).sync_assets(session=session)
    assert len(first_pass) == 1
    (initial_model,) = first_pass.values()
    assert initial_model.group == "old_group"

    # Parser should pick up group change.
    tracked.group = "new_group"
    second_pass = AssetModelOperation.collect({dag.dag_id: dag}).sync_assets(session=session)
    assert len(second_pass) == 1
    (updated_model,) = second_pass.values()
    assert updated_model.group == "new_group"

def test_change_asset_property_sync_extra(self, dag_maker, session):
    """Check that a second ``sync_assets`` call reflects a changed asset ``extra``.

    The asset is first synced with ``extra={"foo": "old"}``; its ``extra`` is then
    replaced on the definition object and the DAG is collected and synced again.
    """
    tracked = Asset("myasset", extra={"foo": "old"})
    with dag_maker(schedule=tracked) as dag:
        EmptyOperator(task_id="mytask")

    # First pass: the model should carry the extra the asset was declared with.
    first_pass = AssetModelOperation.collect({dag.dag_id: dag}).sync_assets(session=session)
    assert len(first_pass) == 1
    (initial_model,) = first_pass.values()
    assert initial_model.extra == {"foo": "old"}

    # Parser should pick up extra change.
    tracked.extra = {"foo": "new"}
    second_pass = AssetModelOperation.collect({dag.dag_id: dag}).sync_assets(session=session)
    assert len(second_pass) == 1
    (updated_model,) = second_pass.values()
    assert updated_model.extra == {"foo": "new"}

def test_change_asset_alias_property_sync_group(self, dag_maker, session):
    """Check that a second ``sync_asset_aliases`` call reflects a changed alias group.

    The alias is first synced with ``group="old_group"``; its ``group`` is then
    changed on the definition object and the DAG is collected and synced again.
    """
    tracked_alias = AssetAlias("myalias", group="old_group")
    with dag_maker(schedule=tracked_alias) as dag:
        EmptyOperator(task_id="mytask")

    # First pass: the model should carry the group the alias was declared with.
    first_pass = AssetModelOperation.collect({dag.dag_id: dag}).sync_asset_aliases(session=session)
    assert len(first_pass) == 1
    (initial_model,) = first_pass.values()
    assert initial_model.group == "old_group"

    # Parser should pick up group change.
    tracked_alias.group = "new_group"
    second_pass = AssetModelOperation.collect({dag.dag_id: dag}).sync_asset_aliases(session=session)
    assert len(second_pass) == 1
    (updated_model,) = second_pass.values()
    assert updated_model.group == "new_group"


@pytest.mark.db_test
class TestUpdateDagParsingResults:
Expand Down