From 83d07aac9d97c65cfeec597fcd3f1c7eea7ebd1a Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Wed, 26 Feb 2025 16:39:02 +0800 Subject: [PATCH] Sync asset and alias property changes into db I missed updating group and extra into the database when rewriting the sync process. (Before the rewrite, the asset table was always wiped and repopulated.) This makes sure any property changes are written into the database. --- airflow/dag_processing/collection.py | 10 ++++- airflow/models/dag.py | 4 +- tests/dag_processing/test_collection.py | 55 ++++++++++++++++++++++++- 3 files changed, 63 insertions(+), 6 deletions(-) diff --git a/airflow/dag_processing/collection.py b/airflow/dag_processing/collection.py index 733595cca0fa7..cdf759e3906c2 100644 --- a/airflow/dag_processing/collection.py +++ b/airflow/dag_processing/collection.py @@ -583,7 +583,7 @@ def collect(cls, dags: dict[str, MaybeSerializedDAG]) -> Self: ) return coll - def add_assets(self, *, session: Session) -> dict[tuple[str, str], AssetModel]: + def sync_assets(self, *, session: Session) -> dict[tuple[str, str], AssetModel]: # Optimization: skip all database calls if no assets were collected. if not self.assets: return {} @@ -593,6 +593,10 @@ def add_assets(self, *, session: Session) -> dict[tuple[str, str], AssetModel]: select(AssetModel).where(tuple_(AssetModel.name, AssetModel.uri).in_(self.assets)) ) } + for key, model in orm_assets.items(): + asset = self.assets[key] + model.group = asset.group + model.extra = asset.extra orm_assets.update( ((model.name, model.uri), model) for model in asset_manager.create_assets( @@ -602,7 +606,7 @@ def add_assets(self, *, session: Session) -> dict[tuple[str, str], AssetModel]: ) return orm_assets - def add_asset_aliases(self, *, session: Session) -> dict[str, AssetAliasModel]: + def sync_asset_aliases(self, *, session: Session) -> dict[str, AssetAliasModel]: # Optimization: skip all database calls if no asset aliases were collected. 
if not self.asset_aliases: return {} @@ -612,6 +616,8 @@ def add_asset_aliases(self, *, session: Session) -> dict[str, AssetAliasModel]: select(AssetAliasModel).where(AssetAliasModel.name.in_(self.asset_aliases)) ) } + for name, model in orm_aliases.items(): + model.group = self.asset_aliases[name].group orm_aliases.update( (model.name, model) for model in asset_manager.create_asset_aliases( diff --git a/airflow/models/dag.py b/airflow/models/dag.py index dd372b7b6c2a1..16cbf28d64706 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -1876,8 +1876,8 @@ def bulk_write_to_db( asset_op = AssetModelOperation.collect(dag_op.dags) - orm_assets = asset_op.add_assets(session=session) - orm_asset_aliases = asset_op.add_asset_aliases(session=session) + orm_assets = asset_op.sync_assets(session=session) + orm_asset_aliases = asset_op.sync_asset_aliases(session=session) session.flush() # This populates id so we can create fks in later calls. orm_dags = dag_op.find_orm_dags(session=session) # Refetch so relationship is up to date. 
diff --git a/tests/dag_processing/test_collection.py b/tests/dag_processing/test_collection.py index e3620e473b0a6..7e7e9d3d8c339 100644 --- a/tests/dag_processing/test_collection.py +++ b/tests/dag_processing/test_collection.py @@ -50,7 +50,7 @@ from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.standard.operators.empty import EmptyOperator from airflow.providers.standard.triggers.temporal import TimeDeltaTrigger -from airflow.sdk.definitions.asset import Asset, AssetWatcher +from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetWatcher from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG from airflow.utils import timezone as tz from airflow.utils.session import create_session @@ -149,7 +149,7 @@ def test_add_asset_trigger_references(self, is_active, is_paused, expected_num_t dag.is_active = is_active dag.is_paused = is_paused - orm_assets = asset_op.add_assets(session=session) + orm_assets = asset_op.sync_assets(session=session) # Create AssetActive objects from assets. It is usually done in the scheduler for asset in orm_assets.values(): session.add(AssetActive.for_asset(asset)) @@ -162,6 +162,57 @@ def test_add_asset_trigger_references(self, is_active, is_paused, expected_num_t assert session.query(Trigger).count() == expected_num_triggers assert session.query(asset_trigger_association_table).count() == expected_num_triggers + def test_change_asset_property_sync_group(self, dag_maker, session): + asset = Asset("myasset", group="old_group") + with dag_maker(schedule=[asset]) as dag: + EmptyOperator(task_id="mytask") + + asset_op = AssetModelOperation.collect({dag.dag_id: dag}) + orm_assets = asset_op.sync_assets(session=session) + assert len(orm_assets) == 1 + assert next(iter(orm_assets.values())).group == "old_group" + + # Parser should pick up group change. 
+ asset.group = "new_group" + asset_op = AssetModelOperation.collect({dag.dag_id: dag}) + orm_assets = asset_op.sync_assets(session=session) + assert len(orm_assets) == 1 + assert next(iter(orm_assets.values())).group == "new_group" + + def test_change_asset_property_sync_extra(self, dag_maker, session): + asset = Asset("myasset", extra={"foo": "old"}) + with dag_maker(schedule=asset) as dag: + EmptyOperator(task_id="mytask") + + asset_op = AssetModelOperation.collect({dag.dag_id: dag}) + orm_assets = asset_op.sync_assets(session=session) + assert len(orm_assets) == 1 + assert next(iter(orm_assets.values())).extra == {"foo": "old"} + + # Parser should pick up extra change. + asset.extra = {"foo": "new"} + asset_op = AssetModelOperation.collect({dag.dag_id: dag}) + orm_assets = asset_op.sync_assets(session=session) + assert len(orm_assets) == 1 + assert next(iter(orm_assets.values())).extra == {"foo": "new"} + + def test_change_asset_alias_property_sync_group(self, dag_maker, session): + alias = AssetAlias("myalias", group="old_group") + with dag_maker(schedule=alias) as dag: + EmptyOperator(task_id="mytask") + + asset_op = AssetModelOperation.collect({dag.dag_id: dag}) + orm_aliases = asset_op.sync_asset_aliases(session=session) + assert len(orm_aliases) == 1 + assert next(iter(orm_aliases.values())).group == "old_group" + + # Parser should pick up group change. + alias.group = "new_group" + asset_op = AssetModelOperation.collect({dag.dag_id: dag}) + orm_aliases = asset_op.sync_asset_aliases(session=session) + assert len(orm_aliases) == 1 + assert next(iter(orm_aliases.values())).group == "new_group" + @pytest.mark.db_test class TestUpdateDagParsingResults: