diff --git a/agents-api/agents_api/activities/demo.py b/agents-api/agents_api/activities/demo.py index c883050c1..797ef6c90 100644 --- a/agents-api/agents_api/activities/demo.py +++ b/agents-api/agents_api/activities/demo.py @@ -1,6 +1,17 @@ from temporalio import activity +from ..env import testing + -@activity.defn async def demo_activity(a: int, b: int) -> int: + # Should throw an error if testing is not enabled + raise Exception("This should not be called in production") + + +async def mock_demo_activity(a: int, b: int) -> int: return a + b + + +demo_activity = activity.defn(name="demo_activity")( + demo_activity if not testing else mock_demo_activity +) diff --git a/agents-api/agents_api/activities/embed_docs.py b/agents-api/agents_api/activities/embed_docs.py index 9d606a041..1dd3793b0 100644 --- a/agents-api/agents_api/activities/embed_docs.py +++ b/agents-api/agents_api/activities/embed_docs.py @@ -3,11 +3,11 @@ from ..clients import embed as embedder from ..clients.cozo import get_cozo_client +from ..env import testing from ..models.docs.embed_snippets import embed_snippets as embed_snippets_query from .types import EmbedDocsPayload -@activity.defn @beartype async def embed_docs(payload: EmbedDocsPayload, cozo_client=None) -> None: indices, snippets = list(zip(*enumerate(payload.content))) @@ -30,3 +30,13 @@ async def embed_docs(payload: EmbedDocsPayload, cozo_client=None) -> None: embeddings=embeddings, client=cozo_client or get_cozo_client(), ) + + +async def mock_embed_docs(payload: EmbedDocsPayload, cozo_client=None) -> None: + # Does nothing + return None + + +embed_docs = activity.defn(name="embed_docs")( + embed_docs if not testing else mock_embed_docs +) diff --git a/agents-api/agents_api/activities/task_steps/evaluate_step.py b/agents-api/agents_api/activities/task_steps/evaluate_step.py index 6f6630f4a..14f4f3f12 100644 --- a/agents-api/agents_api/activities/task_steps/evaluate_step.py +++ b/agents-api/agents_api/activities/task_steps/evaluate_step.py @@ -9,9 +9,9 @@ StepContext, StepOutcome, ) +from ...env import testing -@activity.defn @beartype async def evaluate_step( context: StepContext[EvaluateStep], @@ -20,3 +20,12 @@ async def evaluate_step( output = simple_eval_dict(exprs, values=context.model_dump()) return StepOutcome(output=output) + + +# Note: This is here just for clarity. We could have just imported evaluate_step directly +# They do the same thing, so we dont need to mock the evaluate_step function +mock_evaluate_step = evaluate_step + +evaluate_step = activity.defn(name="evaluate_step")( + evaluate_step if not testing else mock_evaluate_step +) diff --git a/agents-api/agents_api/activities/task_steps/transition_step.py b/agents-api/agents_api/activities/task_steps/transition_step.py index b06c19657..1642075f4 100644 --- a/agents-api/agents_api/activities/task_steps/transition_step.py +++ b/agents-api/agents_api/activities/task_steps/transition_step.py @@ -1,18 +1,14 @@ from beartype import beartype from temporalio import activity -from ...autogen.openapi_model import ( - CreateTransitionRequest, -) -from ...common.protocol.tasks import ( - StepContext, -) +from ...autogen.openapi_model import CreateTransitionRequest +from ...common.protocol.tasks import StepContext +from ...env import testing from ...models.execution.create_execution_transition import ( create_execution_transition as create_execution_transition_query, ) -@activity.defn @beartype async def transition_step( context: StepContext, @@ -34,3 +30,16 @@ async def transition_step( data=transition_info, update_execution_status=True, ) + + +async def mock_transition_step( + context: StepContext, + transition_info: CreateTransitionRequest, +) -> None: + # Does nothing + return None + + +transition_step = activity.defn(name="transition_step")( + transition_step if not testing else mock_transition_step +) diff --git a/agents-api/agents_api/activities/task_steps/yield_step.py b/agents-api/agents_api/activities/task_steps/yield_step.py index f94c8715d..b339b83d1 100644 --- a/agents-api/agents_api/activities/task_steps/yield_step.py +++ b/agents-api/agents_api/activities/task_steps/yield_step.py @@ -5,17 +5,12 @@ from agents_api.autogen.Executions import TransitionTarget -from ...autogen.openapi_model import ( - YieldStep, -) -from ...common.protocol.tasks import ( - StepContext, - StepOutcome, -) +from ...autogen.openapi_model import YieldStep +from ...common.protocol.tasks import StepContext, StepOutcome +from ...env import testing from .utils import simple_eval_dict -@activity.defn @beartype async def yield_step(context: StepContext[YieldStep]) -> StepOutcome[dict[str, Any]]: all_workflows = context.execution_input.task.workflows @@ -36,3 +31,12 @@ async def yield_step(context: StepContext[YieldStep]) -> StepOutcome[dict[str, A ) return StepOutcome(output=arguments, transition_to=("step", transition_target)) + + +# Note: This is here just for clarity. We could have just imported yield_step directly +# They do the same thing, so we dont need to mock the yield_step function +mock_yield_step = yield_step + +yield_step = activity.defn(name="yield_step")( + yield_step if not testing else mock_yield_step +) diff --git a/agents-api/agents_api/env.py b/agents-api/agents_api/env.py index 38ba2101c..afdc81e1e 100644 --- a/agents-api/agents_api/env.py +++ b/agents-api/agents_api/env.py @@ -85,10 +85,18 @@ temporal_worker_url=temporal_worker_url, temporal_namespace=temporal_namespace, embedding_model_id=embedding_model_id, + testing=testing, ) -if debug: +if debug or testing: # Print the loaded environment variables for debugging purposes. print("Environment variables:") pprint(environment) print() + + # Yell if testing is enabled + print("@" * 80) + print( + f"@@@ Running in {'testing' if testing else 'debug'} mode. This should not be enabled in production. @@@" + ) + print("@" * 80) diff --git a/agents-api/agents_api/routers/docs/create_doc.py b/agents-api/agents_api/routers/docs/create_doc.py index c7e5c6bd9..dd41620ae 100644 --- a/agents-api/agents_api/routers/docs/create_doc.py +++ b/agents-api/agents_api/routers/docs/create_doc.py @@ -29,11 +29,6 @@ async def run_embed_docs_task( client = client or (await temporal.get_client()) - # TODO: Remove this conditional once we have a way to run workflows in - # a test environment. - if testing: - return None - embed_payload = EmbedDocsPayload( developer_id=developer_id, doc_id=doc_id, @@ -49,7 +44,10 @@ async def run_embed_docs_task( id=str(job_id), ) - background_tasks.add_task(handle.result) + # TODO: Remove this conditional once we have a way to run workflows in + # a test environment. + if not testing: + background_tasks.add_task(handle.result) return handle diff --git a/agents-api/agents_api/routers/tasks/create_task_execution.py b/agents-api/agents_api/routers/tasks/create_task_execution.py index 5c40d4927..988af074f 100644 --- a/agents-api/agents_api/routers/tasks/create_task_execution.py +++ b/agents-api/agents_api/routers/tasks/create_task_execution.py @@ -1,37 +1,89 @@ import logging from typing import Annotated -from uuid import uuid4 +from uuid import UUID, uuid4 -from fastapi import Depends, HTTPException, status +from beartype import beartype +from fastapi import BackgroundTasks, Depends, HTTPException, status from jsonschema import validate from jsonschema.exceptions import ValidationError from pycozo.client import QueryException from pydantic import UUID4 from starlette.status import HTTP_201_CREATED +from temporalio.client import WorkflowHandle -from agents_api.autogen.openapi_model import ( +from ...autogen.Executions import Execution +from ...autogen.openapi_model import ( CreateExecutionRequest, ResourceCreatedResponse, UpdateExecutionRequest, ) -from agents_api.clients.temporal import run_task_execution_workflow -from agents_api.dependencies.developer_id import get_developer_id -from agents_api.models.execution.create_execution import ( +from ...clients.temporal import run_task_execution_workflow +from ...dependencies.developer_id import get_developer_id +from ...models.execution.create_execution import ( create_execution as create_execution_query, ) -from agents_api.models.execution.create_temporal_lookup import create_temporal_lookup -from agents_api.models.execution.prepare_execution_input import prepare_execution_input -from agents_api.models.execution.update_execution import ( +from ...models.execution.create_temporal_lookup import create_temporal_lookup +from ...models.execution.prepare_execution_input import prepare_execution_input +from ...models.execution.update_execution import ( update_execution as update_execution_query, ) -from agents_api.models.task.get_task import get_task as get_task_query - +from ...models.task.get_task import get_task as get_task_query from .router import router logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) +@beartype +async def start_execution( + *, + developer_id: UUID, + task_id: UUID, + data: CreateExecutionRequest, + client=None, +) -> tuple[Execution, WorkflowHandle]: + execution_id = uuid4() + + execution = create_execution_query( + developer_id=developer_id, + task_id=task_id, + execution_id=execution_id, + data=data, + client=client, + ) + + execution_input = prepare_execution_input( + developer_id=developer_id, + task_id=task_id, + execution_id=execution_id, + client=client, + ) + + try: + handle = await run_task_execution_workflow( + execution_input=execution_input, + job_id=uuid4(), + ) + + except Exception as e: + logger.exception(e) + + update_execution_query( + developer_id=developer_id, + task_id=task_id, + execution_id=execution_id, + data=UpdateExecutionRequest(status="failed"), + client=client, + ) + + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Execution creation failed", + ) from e + + return execution, handle + + @router.post( "/tasks/{task_id}/executions", status_code=HTTP_201_CREATED, @@ -41,6 +93,7 @@ async def create_task_execution( task_id: UUID4, data: CreateExecutionRequest, x_developer_id: Annotated[UUID4, Depends(get_developer_id)], + background_tasks: BackgroundTasks, ) -> ResourceCreatedResponse: try: task = get_task_query(task_id=task_id, developer_id=x_developer_id) @@ -60,44 +113,18 @@ async def create_task_execution( raise - execution_id = uuid4() - execution = create_execution_query( + execution, handle = await start_execution( developer_id=x_developer_id, task_id=task_id, - execution_id=execution_id, data=data, ) - execution_input = prepare_execution_input( + background_tasks.add_task( + create_temporal_lookup, + # developer_id=x_developer_id, task_id=task_id, - execution_id=execution_id, - ) - - try: - handle = await run_task_execution_workflow( - execution_input=execution_input, - job_id=uuid4(), - ) - except Exception as e: - logger.exception(e) - - update_execution_query( - developer_id=x_developer_id, - task_id=task_id, - execution_id=execution_id, - data=UpdateExecutionRequest(status="failed"), - ) - - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Task creation failed", - ) - - create_temporal_lookup( - developer_id=x_developer_id, - task_id=task_id, - execution_id=execution_id, + execution_id=execution.id, workflow_handle=handle, ) diff --git a/agents-api/tests/test_activities.py b/agents-api/tests/test_activities.py index d658019d0..98dfc97b5 100644 --- a/agents-api/tests/test_activities.py +++ b/agents-api/tests/test_activities.py @@ -10,26 +10,18 @@ from .fixtures import ( cozo_client, - patch_embed_acompletion, test_developer_id, test_doc, ) from .utils import patch_testing_temporal -# from agents_api.activities.truncation import get_extra_entries -# from agents_api.autogen.openapi_model import Role -# from agents_api.common.protocol.entries import Entry - @test("activity: call direct embed_docs") async def _( cozo_client=cozo_client, developer_id=test_developer_id, doc=test_doc, - mocks=patch_embed_acompletion, ): - (embed, _) = mocks - title = "title" content = ["content 1"] include_title = True @@ -46,8 +38,6 @@ async def _( cozo_client, ) - embed.assert_called_once() - @test("activity: call demo workflow via temporal client") async def _(): diff --git a/agents-api/tests/test_execution_workflow.py b/agents-api/tests/test_execution_workflow.py new file mode 100644 index 000000000..254841075 --- /dev/null +++ b/agents-api/tests/test_execution_workflow.py @@ -0,0 +1,27 @@ +# Tests for task queries + +from ward import test + +from agents_api.autogen.openapi_model import CreateExecutionRequest +from agents_api.routers.tasks.create_task_execution import start_execution + +from .fixtures import cozo_client, test_developer_id, test_task +from .utils import patch_testing_temporal + + +@test("workflow: create task execution") +async def _(client=cozo_client, developer_id=test_developer_id, task=test_task): + data = CreateExecutionRequest(input={"test": "input"}) + + async with patch_testing_temporal() as (_, mock_run_task_execution_workflow): + execution, handle = await start_execution( + developer_id=developer_id, + task_id=task.id, + data=data, + client=client, + ) + + assert handle is not None + assert execution.task_id == task.id + assert execution.input == data.input + mock_run_task_execution_workflow.assert_called_once()