From c753ca295d72d4e3dd74b9131d3ca4c47899cd96 Mon Sep 17 00:00:00 2001 From: Shahar Epstein <60007259+shahar1@users.noreply.github.com> Date: Thu, 24 Oct 2024 22:54:27 +0300 Subject: [PATCH] Prevent using trigger_rule="always" in a dynamic mapped task --- airflow/utils/task_group.py | 22 +++++++++++++--- .../dynamic-task-mapping.rst | 5 ++++ tests/decorators/test_task_group.py | 25 ++++++++++++++++++- 3 files changed, 47 insertions(+), 5 deletions(-) diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py index 69a5d015bd426c..c32fd347c0fbcc 100644 --- a/airflow/utils/task_group.py +++ b/airflow/utils/task_group.py @@ -37,6 +37,7 @@ from airflow.models.taskmixin import DAGNode from airflow.serialization.enums import DagAttributeTypes from airflow.utils.helpers import validate_group_key, validate_instance_args +from airflow.utils.trigger_rule import TriggerRule if TYPE_CHECKING: from sqlalchemy.orm import Session @@ -220,10 +221,15 @@ def parent_group(self) -> TaskGroup | None: def __iter__(self): for child in self.children.values(): - if isinstance(child, TaskGroup): - yield from child - else: - yield child + yield from self._iter_child(child) + + @staticmethod + def _iter_child(child): + """Iterate over the children of this TaskGroup.""" + if isinstance(child, TaskGroup): + yield from child + else: + yield child def add(self, task: DAGNode) -> DAGNode: """ @@ -593,6 +599,14 @@ def __init__(self, *, expand_input: ExpandInput, **kwargs: Any) -> None: super().__init__(**kwargs) self._expand_input = expand_input + def __iter__(self): + from airflow.models.abstractoperator import AbstractOperator + + for child in self.children.values(): + if isinstance(child, AbstractOperator) and child.trigger_rule == TriggerRule.ALWAYS: + raise ValueError("Tasks in a mapped task group cannot have trigger_rule set to 'ALWAYS'") + yield from self._iter_child(child) + def iter_mapped_dependencies(self) -> Iterator[Operator]: """Upstream dependencies that provide XComs used by this mapped task group.""" from airflow.models.xcom_arg import XComArg diff --git a/docs/apache-airflow/authoring-and-scheduling/dynamic-task-mapping.rst b/docs/apache-airflow/authoring-and-scheduling/dynamic-task-mapping.rst index fd7d570785434f..df74038fd2c05e 100644 --- a/docs/apache-airflow/authoring-and-scheduling/dynamic-task-mapping.rst +++ b/docs/apache-airflow/authoring-and-scheduling/dynamic-task-mapping.rst @@ -84,6 +84,11 @@ The grid view also provides visibility into your mapped tasks in the details pan Although we show a "reduce" task here (``sum_it``) you don't have to have one, the mapped tasks will still be executed even if they have no downstream tasks. +.. warning:: ``TriggerRule.ALWAYS`` cannot be utilized in expanded tasks + + Assigning ``trigger_rule=TriggerRule.ALWAYS`` in expanded tasks is forbidden, as expanded parameters will be undefined with the task's immediate execution. + This is enforced at the time of the DAG parsing, and will raise an error if you try to use it. + Task-generated Mapping ---------------------- diff --git a/tests/decorators/test_task_group.py b/tests/decorators/test_task_group.py index 6120f94af3ac7a..2dab23ca38fc7b 100644 --- a/tests/decorators/test_task_group.py +++ b/tests/decorators/test_task_group.py @@ -22,10 +22,11 @@ import pendulum import pytest -from airflow.decorators import dag, task_group +from airflow.decorators import dag, task, task_group from airflow.models.expandinput import DictOfListsExpandInput, ListOfDictsExpandInput, MappedArgument from airflow.operators.empty import EmptyOperator from airflow.utils.task_group import MappedTaskGroup +from airflow.utils.trigger_rule import TriggerRule def test_task_group_with_overridden_kwargs(): @@ -133,6 +134,28 @@ def tg(): assert str(ctx.value) == "no arguments to expand against" +@pytest.mark.db_test +def test_expand_fail_trigger_rule_always(dag_maker, session): + @dag(schedule=None, start_date=pendulum.datetime(2022, 1, 1)) + def pipeline(): + @task + def get_param(): + return ["a", "b", "c"] + + @task(trigger_rule=TriggerRule.ALWAYS) + def t1(param): + return param + + @task_group() + def tg(param): + t1(param) + + with pytest.raises( + ValueError, match="Tasks in a mapped task group cannot have trigger_rule set to 'ALWAYS'" + ): + tg.expand(param=get_param()) + + def test_expand_create_mapped(): saved = {}