Skip to content

Commit

Permalink
AIP-72: Port task success overtime to the Supervisor
Browse files Browse the repository at this point in the history
This PR ports the overtime feature on `LocalTaskJob` (added in #39890) to the Supervisor.
It allows the Supervisor to terminate the Task process when it exceeds the configured success overtime threshold, which is useful when we add a Listener to the Task process.

closes #44356

Also added `TaskState` to update state and send end_date from task process to the supervisor.
  • Loading branch information
kaxil committed Dec 3, 2024
1 parent a242ff6 commit 4be85f7
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 17 deletions.
2 changes: 2 additions & 0 deletions task_sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@

from __future__ import annotations

from datetime import datetime
from typing import Annotated, Literal, Union

from pydantic import BaseModel, ConfigDict, Field
Expand Down Expand Up @@ -101,6 +102,7 @@ class TaskState(BaseModel):
"""

state: TerminalTIState
end_date: datetime | None = None
type: Literal["TaskState"] = "TaskState"


Expand Down
37 changes: 30 additions & 7 deletions task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
GetVariable,
GetXCom,
StartupDetails,
TaskState,
ToSupervisor,
)

Expand Down Expand Up @@ -265,9 +266,9 @@ class WatchedSubprocess:
client: Client

_process: psutil.Process
_exit_code: int | None = None
_terminal_state: str | None = None
_final_state: str | None = None
_exit_code: int | None = attrs.field(default=None, init=False)
_terminal_state: str | None = attrs.field(default=None, init=False)
_final_state: str | None = attrs.field(default=None, init=False)

_last_successful_heartbeat: float = attrs.field(default=0, init=False)
_last_heartbeat_attempt: float = attrs.field(default=0, init=False)
Expand All @@ -277,6 +278,13 @@ class WatchedSubprocess:
# does not hang around forever.
failed_heartbeats: int = attrs.field(default=0, init=False)

# Maximum possible time (in seconds) that task will have for execution of auxiliary processes
# like listeners after task is marked as success.
# TODO: This should come from airflow.cfg: [core] task_success_overtime
task_success_overtime_threshold: float = attrs.field(default=20.0, init=False)
_overtime: float = attrs.field(default=0.0, init=False)
_task_end_datetime: datetime | None = attrs.field(default=None, init=False)

selector: selectors.BaseSelector = attrs.field(factory=selectors.DefaultSelector)

procs: ClassVar[weakref.WeakValueDictionary[int, WatchedSubprocess]] = weakref.WeakValueDictionary()
Expand Down Expand Up @@ -500,6 +508,20 @@ def _monitor_subprocess(self):

self._send_heartbeat_if_needed()

self._handle_task_overtime_if_needed()

def _handle_task_overtime_if_needed(self):
    """
    Terminate the task process once it has outlived its success overtime allowance.

    Nothing happens unless the recorded terminal state is ``SUCCESS``. Otherwise the
    seconds elapsed since the reported task end date are stored in ``_overtime`` and,
    when that value is greater than ``task_success_overtime_threshold``, a warning is
    logged and the process is sent ``SIGTERM`` via ``kill(..., force=True)``.
    """
    if self._terminal_state == TerminalTIState.SUCCESS:
        current_time = datetime.now(tz=timezone.utc)
        # Without a reported end date we measure from "now", which yields zero overtime.
        ended_at = self._task_end_datetime or current_time
        elapsed = current_time - ended_at
        self._overtime = elapsed.total_seconds()

        if self._overtime > self.task_success_overtime_threshold:
            log.warning("Task success overtime reached; terminating process", ti_id=self.ti_id)
            self.kill(signal.SIGTERM, force=True)

def _service_subprocess(self, max_wait_time: float, raise_on_timeout: bool = False):
"""
Service subprocess events by processing socket activity and checking for process exit.
Expand Down Expand Up @@ -631,9 +653,11 @@ def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, N
log.exception("Unable to decode message", line=line)
continue

# if isinstance(msg, TaskState):
# self._terminal_state = msg.state
if isinstance(msg, GetConnection):
resp = None
if isinstance(msg, TaskState):
self._terminal_state = msg.state
self._task_end_datetime = msg.end_date
elif isinstance(msg, GetConnection):
conn = self.client.connections.get(msg.conn_id)
resp = conn.model_dump_json(exclude_unset=True).encode()
elif isinstance(msg, GetVariable):
Expand All @@ -645,7 +669,6 @@ def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, N
elif isinstance(msg, DeferTask):
self._terminal_state = IntermediateTIState.DEFERRED
self.client.task_instances.defer(self.ti_id, msg)
resp = None
else:
log.error("Unhandled request", msg=msg)
continue
Expand Down
14 changes: 10 additions & 4 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,17 @@

import os
import sys
from datetime import datetime, timezone
from io import FileIO
from typing import TYPE_CHECKING, TextIO

import attrs
import structlog
from pydantic import ConfigDict, TypeAdapter

from airflow.sdk.api.datamodels._generated import TaskInstance
from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.execution_time.comms import DeferTask, StartupDetails, ToSupervisor, ToTask
from airflow.sdk.execution_time.comms import DeferTask, StartupDetails, TaskState, ToSupervisor, ToTask

if TYPE_CHECKING:
from structlog.typing import FilteringBoundLogger as Logger
Expand Down Expand Up @@ -158,11 +159,14 @@ def run(ti: RuntimeTaskInstance, log: Logger):
if TYPE_CHECKING:
assert ti.task is not None
assert isinstance(ti.task, BaseOperator)

msg: ToSupervisor | None = None
try:
# TODO: pre execute etc.
# TODO next_method to support resuming from deferred
# TODO: Get a real context object
ti.task.execute({"task_instance": ti}) # type: ignore[attr-defined]
msg = TaskState(state=TerminalTIState.SUCCESS, end_date=datetime.now(tz=timezone.utc))
except TaskDeferred as defer:
classpath, trigger_kwargs = defer.trigger.serialize()
next_method = defer.method_name
Expand All @@ -173,9 +177,8 @@ def run(ti: RuntimeTaskInstance, log: Logger):
next_method=next_method,
trigger_timeout=timeout,
)
SUPERVISOR_COMMS.send_request(msg=msg, log=log)
except AirflowSkipException:
...
msg = TaskState(state=TerminalTIState.SKIPPED)
except AirflowRescheduleException:
...
except (AirflowFailException, AirflowSensorTimeout):
Expand All @@ -189,6 +192,9 @@ def run(ti: RuntimeTaskInstance, log: Logger):
# TODO: Handle TI handle failure
raise

if msg:
SUPERVISOR_COMMS.send_request(msg=msg, log=log)


def finalize(log: Logger): ...

Expand Down
104 changes: 102 additions & 2 deletions task_sdk/tests/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

from airflow.sdk.api import client as sdk_client
from airflow.sdk.api.client import ServerResponseError
from airflow.sdk.api.datamodels._generated import TaskInstance
from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
from airflow.sdk.execution_time.comms import (
ConnectionResult,
DeferTask,
Expand Down Expand Up @@ -306,7 +306,7 @@ def test_supervise_handles_deferred_task(self, test_dags_dir, captured_logs, tim
"logger": "supervisor",
} in captured_logs

def test_supervisor_handles_already_running_task(self):
def test_supervisor_handles_already_running_task(self, mocker):
"""Test that Supervisor prevents starting a Task Instance that is already running."""
ti = TaskInstance(id=uuid7(), task_id="b", dag_id="c", run_id="d", try_number=1)

Expand All @@ -328,6 +328,10 @@ def handle_request(request: httpx.Request) -> httpx.Response:

client = make_client(transport=httpx.MockTransport(handle_request))

# Patch the kill method so we can assert it was called with the correct signal
# without waiting for escalation_delay seconds
mock_kill = mocker.patch("airflow.sdk.execution_time.supervisor.WatchedSubprocess.kill")

with pytest.raises(ServerResponseError, match="Server returned error") as err:
WatchedSubprocess.start(path=os.devnull, ti=ti, client=client)

Expand All @@ -337,6 +341,7 @@ def handle_request(request: httpx.Request) -> httpx.Response:
"message": "TI was not in a state where it could be marked as running",
"previous_state": "running",
}
mock_kill.assert_called_once_with(signal.SIGKILL)

@pytest.mark.parametrize("captured_logs", [logging.ERROR], indirect=True, ids=["log_level=error"])
def test_state_conflict_on_heartbeat(self, captured_logs, monkeypatch, mocker):
Expand Down Expand Up @@ -478,6 +483,101 @@ def test_heartbeat_failures_handling(self, monkeypatch, mocker, captured_logs, t
"timestamp": mocker.ANY,
} in captured_logs

@pytest.mark.parametrize(
    ["terminal_state", "task_end_datetime", "overtime_threshold", "expected_kill"],
    [
        # "Now" is frozen at tz.datetime(2024, 12, 1, 10, 10, 20) for every case below.
        # Ended 5 seconds ago, threshold 10s, but no terminal state has been recorded.
        pytest.param(None, tz.datetime(2024, 12, 1, 10, 10, 15), 10, False, id="no_terminal_state"),
        # Well past the threshold, but the terminal state is not SUCCESS, so no kill is expected.
        pytest.param(
            TerminalTIState.SKIPPED, tz.datetime(2024, 12, 1, 10, 10, 0), 1, False, id="non_success_state"
        ),
        # SUCCESS, ended 5 seconds ago, threshold 10s: still inside the allowance.
        pytest.param(
            TerminalTIState.SUCCESS, tz.datetime(2024, 12, 1, 10, 10, 15), 10, False, id="below_threshold"
        ),
        # SUCCESS, ended 10 seconds ago, threshold 9s: allowance exceeded.
        pytest.param(
            TerminalTIState.SUCCESS, tz.datetime(2024, 12, 1, 10, 10, 10), 9, True, id="above_threshold"
        ),
        # SUCCESS with no end datetime reported, threshold 20s: no kill is expected.
        pytest.param(TerminalTIState.SUCCESS, None, 20, False, id="task_end_datetime_none"),
    ],
)
def test_overtime_handling(
    self,
    mocker,
    terminal_state,
    task_end_datetime,
    overtime_threshold,
    expected_kill,
    time_machine,
):
    """Check whether the supervisor kills the process for each state/end-date/threshold combination."""
    # Only the calls matter here, not real log output or a real signal being delivered,
    # so both the module logger and the class-level `kill` are replaced with mocks.
    log_spy = mocker.patch("airflow.sdk.execution_time.supervisor.log")
    kill_spy = mocker.patch("airflow.sdk.execution_time.supervisor.WatchedSubprocess.kill")

    proc = WatchedSubprocess(
        ti_id=TI_ID,
        pid=12345,
        stdin=mocker.Mock(),
        process=mocker.Mock(),
        client=mocker.Mock(),
    )

    # Freeze the clock so the elapsed time since `task_end_datetime` is deterministic.
    frozen_now = tz.datetime(2024, 12, 1, 10, 10, 20)
    time_machine.move_to(frozen_now, tick=False)

    # Put the supervisor into the scenario under test.
    proc._terminal_state = terminal_state
    proc._task_end_datetime = task_end_datetime
    proc.task_success_overtime_threshold = overtime_threshold

    # `wait` is what triggers the overtime handling in this test.
    proc.wait()

    if not expected_kill:
        kill_spy.assert_not_called()
        log_spy.warning.assert_not_called()
        return

    kill_spy.assert_called_once_with(signal.SIGTERM, force=True)
    log_spy.warning.assert_called_once_with(
        "Task success overtime reached; terminating process",
        ti_id=TI_ID,
    )


class TestWatchedSubprocessKill:
@pytest.fixture
Expand Down
20 changes: 16 additions & 4 deletions task_sdk/tests/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
from uuid6 import uuid7

from airflow.sdk import DAG, BaseOperator
from airflow.sdk.api.datamodels._generated import TaskInstance
from airflow.sdk.execution_time.comms import DeferTask, StartupDetails
from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
from airflow.sdk.execution_time.comms import DeferTask, StartupDetails, TaskState
from airflow.sdk.execution_time.task_runner import CommsDecoder, parse, run
from airflow.utils import timezone

Expand Down Expand Up @@ -78,7 +78,7 @@ def test_parse(test_dags_dir: Path):
assert isinstance(ti.task.dag, DAG)


def test_run_basic(test_dags_dir: Path):
def test_run_basic(test_dags_dir: Path, time_machine):
"""Test running a basic task."""
what = StartupDetails(
ti=TaskInstance(id=uuid7(), task_id="hello", dag_id="super_basic_run", run_id="c", try_number=1),
Expand All @@ -87,7 +87,19 @@ def test_run_basic(test_dags_dir: Path):
)

ti = parse(what)
run(ti, log=mock.MagicMock())

instant = timezone.datetime(2024, 12, 3, 10, 0)
time_machine.move_to(instant, tick=False)

with mock.patch(
"airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
) as mock_supervisor_comms:
mock_supervisor_comms.send_request = mock.Mock()
run(ti, log=mock.MagicMock())

mock_supervisor_comms.send_request.assert_called_once_with(
msg=TaskState(state=TerminalTIState.SUCCESS, end_date=instant), log=mock.ANY
)


def test_run_deferred_basic(test_dags_dir: Path, time_machine):
Expand Down

0 comments on commit 4be85f7

Please sign in to comment.