From 08bbf89225b9b87521044fa04a9666690792a4cb Mon Sep 17 00:00:00 2001 From: Jens Scheffler <95105677+jscheffl@users.noreply.github.com> Date: Fri, 1 Nov 2024 11:10:08 +0100 Subject: [PATCH] FIX: Don't raise a warning in ExecutorSafeguard when execute is called from an extended operator (#42849) (#43577) * refactor: Don't raise a warning when execute is called from an extended operator, as this should always be allowed. * refactored: Fixed import of test_utils in test_dag_run --------- Co-authored-by: David Blain (cherry picked from commit 95c46ec135349c8e8d3150d16f18ab65f8240f3e) Co-authored-by: David Blain (cherry picked from commit 2f29c57ce58dc423128a3657bbc4c3bd5bbb3de0) --- airflow/models/baseoperator.py | 11 ++++++++++- tests/models/test_baseoperatormeta.py | 24 +++++++++++++++++++++++- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 9da4b619d7fd4..449678860f80b 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -34,6 +34,7 @@ import warnings from datetime import datetime, timedelta from functools import total_ordering, wraps +from threading import local from types import FunctionType from typing import ( TYPE_CHECKING, @@ -391,6 +392,8 @@ class ExecutorSafeguard: """ test_mode = conf.getboolean("core", "unit_test_mode") + _sentinel = local() + _sentinel.callers = {} @classmethod def decorator(cls, func): @@ -398,7 +401,13 @@ def decorator(cls, func): def wrapper(self, *args, **kwargs): from airflow.decorators.base import DecoratedOperator - sentinel = kwargs.pop(f"{self.__class__.__name__}__sentinel", None) + sentinel_key = f"{self.__class__.__name__}__sentinel" + sentinel = kwargs.pop(sentinel_key, None) + + if sentinel: + cls._sentinel.callers[sentinel_key] = sentinel + else: + sentinel = cls._sentinel.callers.pop(f"{func.__qualname__.split('.')[0]}__sentinel", None) if not cls.test_mode and not sentinel == _sentinel and not isinstance(self, DecoratedOperator): message = f"{self.__class__.__name__}.{func.__name__} cannot be called outside TaskInstance!" diff --git a/tests/models/test_baseoperatormeta.py b/tests/models/test_baseoperatormeta.py index 6c6567b23899e..5244e86b2c386 100644 --- a/tests/models/test_baseoperatormeta.py +++ b/tests/models/test_baseoperatormeta.py @@ -40,6 +40,11 @@ def execute(self, context: Context) -> Any: return f"Hello {self.owner}!" +class ExtendedHelloWorldOperator(HelloWorldOperator): + def execute(self, context: Context) -> Any: + return super().execute(context) + + class TestExecutorSafeguard: def setup_method(self): ExecutorSafeguard.test_mode = False @@ -49,12 +54,29 @@ def teardown_method(self, method): @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode @pytest.mark.db_test - def test_executor_when_classic_operator_called_from_dag(self, dag_maker): + @patch.object(HelloWorldOperator, "log") + def test_executor_when_classic_operator_called_from_dag(self, mock_log, dag_maker): with dag_maker() as dag: HelloWorldOperator(task_id="hello_operator") dag_run = dag.test() assert dag_run.state == DagRunState.SUCCESS + mock_log.warning.assert_not_called() + + @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode + @pytest.mark.db_test + @patch.object(HelloWorldOperator, "log") + def test_executor_when_extended_classic_operator_called_from_dag( + self, + mock_log, + dag_maker, + ): + with dag_maker() as dag: + ExtendedHelloWorldOperator(task_id="hello_operator") + + dag_run = dag.test() + assert dag_run.state == DagRunState.SUCCESS + mock_log.warning.assert_not_called() @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode @pytest.mark.parametrize(