Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Add trigger tests #32

Merged
merged 1 commit into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions codespell-ignore-words.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
assertIn
asend
153 changes: 151 additions & 2 deletions tests/triggers/test_anyscale_triggers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import asyncio
import unittest
from unittest.mock import MagicMock, PropertyMock, patch
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch

from airflow.exceptions import AirflowNotFoundException
from anyscale.job.models import JobConfig, JobState, JobStatus
from airflow.triggers.base import TriggerEvent
from anyscale.job.models import JobConfig, JobRunStatus, JobState, JobStatus
from anyscale.service.models import ServiceState

from anyscale_provider.triggers.anyscale import AnyscaleJobTrigger, AnyscaleServiceTrigger
Expand Down Expand Up @@ -141,6 +142,77 @@ async def test_anyscale_run_trigger(self, mocked_sleep, mocked_get_job_logs, moc
self.assertEqual(result.payload["message"], "Job 1234 completed with status JobState.SUCCEEDED.")
self.assertEqual(result.payload["job_id"], "1234")

@patch("anyscale_provider.triggers.anyscale.AnyscaleHook.get_job_status")
@patch("anyscale_provider.triggers.anyscale.AnyscaleJobTrigger._is_terminal_state")
def test_run_success(self, mock_terminal_state, mock_hook):
trigger = AnyscaleJobTrigger(conn_id="test_conn", job_id="test_job", poll_interval=1, fetch_logs=False)
mock_terminal_state.return_value = True
mock_hook.return_value = JobStatus(
id="test_job", state=JobState.SUCCEEDED, name="", config=JobConfig(entrypoint="122"), runs=[]
)

async def run_test():
generator = trigger.run()
result = await generator.asend(None)
assert result == TriggerEvent(
{"state": "SUCCEEDED", "message": "Job test_job completed with state SUCCEEDED.", "job_id": "test_job"}
)

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(run_test())

@patch("anyscale_provider.triggers.anyscale.asyncio.get_event_loop")
@patch("anyscale_provider.triggers.anyscale.AnyscaleHook.get_job_status")
@patch("anyscale_provider.triggers.anyscale.AnyscaleJobTrigger._is_terminal_state")
def test_run_success_fetch_log(self, mock_terminal_state, mock_hook, mock_asyncio_loop):
trigger = AnyscaleJobTrigger(conn_id="test_conn", job_id="test_job", poll_interval=1, fetch_logs=True)
mock_terminal_state.return_value = True
mock_hook.return_value = JobStatus(
id="test_job",
state=JobState.SUCCEEDED,
name="",
config=JobConfig(entrypoint="122"),
runs=[JobRunStatus(name="test", state="SUCCEEDED")],
)
mock_loop = AsyncMock()
mock_asyncio_loop.return_value = mock_loop
mock_loop.run_in_executor.side_effect = "hello\n"

async def run_test():
generator = trigger.run()
result = await generator.asend(None)
assert result == TriggerEvent(
{"state": "SUCCEEDED", "message": "Job test_job completed with state SUCCEEDED.", "job_id": "test_job"}
)
mock_asyncio_loop.assert_called_once()
mock_loop.run_in_executor.assert_called_once()
mock_loop.run_in_executor.return_value = []

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(run_test())

@patch("anyscale_provider.triggers.anyscale.AnyscaleHook.get_job_status")
@patch("anyscale_provider.triggers.anyscale.AnyscaleJobTrigger._is_terminal_state")
def test_run_error(self, mock_terminal_state, mock_hook):
trigger = AnyscaleJobTrigger(conn_id="test_conn", job_id="test_job", poll_interval=1, fetch_logs=False)
mock_terminal_state.return_value = True
mock_hook.return_value = JobStatus(
id="test_job", state=JobState.FAILED, name="", config=JobConfig(entrypoint="122"), runs=[]
)

async def run_test():
generator = trigger.run()
result = await generator.asend(None)
assert result == TriggerEvent(
{"state": "FAILED", "message": "Job test_job completed with state FAILED.", "job_id": "test_job"}
)

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(run_test())


class TestAnyscaleServiceTrigger(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -285,6 +357,83 @@ def test_get_current_status_canary_100_percent(self, mock_get_service_status):
# Ensure the mock was called correctly
mock_get_service_status.assert_called_once_with("AstroService")

@patch("anyscale_provider.triggers.anyscale.AnyscaleServiceTrigger._get_current_state")
@patch("anyscale_provider.triggers.anyscale.AnyscaleServiceTrigger._check_current_state")
def test_run_success(self, mock_check_current_state, mock_get_current_state):
trigger = AnyscaleServiceTrigger(
conn_id="default_conn",
service_name="AstroService",
expected_state=ServiceState.RUNNING,
canary_percent=100.0,
)
mock_check_current_state.return_value = False
mock_get_current_state.return_value = ServiceState.RUNNING

async def run_test():
generator = trigger.run()
result = await generator.asend(None)
assert result == TriggerEvent(
{
"state": ServiceState.RUNNING,
"message": "Service deployment succeeded",
"service_name": "AstroService",
}
)

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(run_test())

@patch("anyscale_provider.triggers.anyscale.AnyscaleServiceTrigger._get_current_state")
@patch("anyscale_provider.triggers.anyscale.AnyscaleServiceTrigger._check_current_state")
def test_run_failure(self, mock_check_current_state, mock_get_current_state):
trigger = AnyscaleServiceTrigger(
conn_id="default_conn",
service_name="AstroService",
expected_state=ServiceState.RUNNING,
canary_percent=100.0,
)
mock_check_current_state.return_value = False
mock_get_current_state.return_value = ServiceState.UNKNOWN

async def run_test():
generator = trigger.run()
result = await generator.asend(None)
assert result == TriggerEvent(
{
"state": ServiceState.SYSTEM_FAILURE,
"message": "Service AstroService entered an unexpected state: UNKNOWN",
"service_name": "AstroService",
}
)

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(run_test())

@patch("anyscale_provider.triggers.anyscale.AnyscaleServiceTrigger._get_current_state")
@patch("anyscale_provider.triggers.anyscale.AnyscaleServiceTrigger._check_current_state")
def test_run_service_exception(self, mock_check_current_state, mock_get_current_state):
trigger = AnyscaleServiceTrigger(
conn_id="default_conn",
service_name="AstroService",
expected_state=ServiceState.RUNNING,
canary_percent=100.0,
)
mock_check_current_state.return_value = False
mock_get_current_state.side_effect = Exception("Unknown error")

async def run_test():
generator = trigger.run()
result = await generator.asend(None)
assert result == TriggerEvent(
{"state": ServiceState.SYSTEM_FAILURE, "message": "Unknown error", "service_name": "AstroService"}
)

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(run_test())


if __name__ == "__main__":
unittest.main()
Loading