diff --git a/agents-api/agents_api/clients/temporal.py b/agents-api/agents_api/clients/temporal.py index 8904ce01e..85f94069e 100644 --- a/agents-api/agents_api/clients/temporal.py +++ b/agents-api/agents_api/clients/temporal.py @@ -12,7 +12,7 @@ from temporalio.runtime import PrometheusConfig, Runtime, TelemetryConfig from ..autogen.openapi_model import TransitionTarget -from ..common.interceptors import offload_if_large +from ..common.interceptors import CustomClientInterceptor, offload_if_large from ..common.protocol.tasks import ExecutionInput from ..common.retry_policies import DEFAULT_RETRY_POLICY from ..env import ( @@ -49,6 +49,7 @@ async def get_client( worker_url, namespace=namespace, tls=tls_config, + interceptors=[CustomClientInterceptor()], data_converter=data_converter, api_key=temporal_api_key or None, rpc_metadata=rpc_metadata, diff --git a/agents-api/agents_api/common/interceptors.py b/agents-api/agents_api/common/interceptors.py index 66c69a26e..6cc73dd58 100644 --- a/agents-api/agents_api/common/interceptors.py +++ b/agents-api/agents_api/common/interceptors.py @@ -13,6 +13,14 @@ from temporalio import workflow from temporalio.activity import _CompleteAsyncError as CompleteAsyncError +from temporalio.client import ( + Interceptor as ClientInterceptor, +) +from temporalio.client import ( + OutboundInterceptor, + StartWorkflowInput, + WorkflowHandle, +) from temporalio.exceptions import ActivityError, ApplicationError, FailureError, TemporalError from temporalio.service import RPCError from temporalio.worker import ( @@ -285,7 +293,7 @@ def init(self, outbound: WorkflowOutboundInterceptor) -> None: To add a custom outbound interceptor, wrap the given interceptor before sending to the next ``init`` call. """ - self.next.init(CustomOutboundInterceptor(outbound)) + self.next.init(CustomWorkflowOutboundInterceptor(outbound)) @offload_to_blob_store async def execute_workflow(self, input: ExecuteWorkflowInput) -> Any: @@ -298,13 +306,14 @@ async def execute_workflow(self, input: ExecuteWorkflowInput) -> Any: ) -class CustomOutboundInterceptor(WorkflowOutboundInterceptor): +class CustomWorkflowOutboundInterceptor(WorkflowOutboundInterceptor): """ Custom outbound interceptor for Temporal workflows. """ - @offload_to_blob_store + # @offload_to_blob_store def start_activity(self, input: StartActivityInput) -> ActivityHandle: + input.args = [offload_if_large(arg) for arg in input.args] return handle_execution_with_errors_sync( super().start_activity, input, @@ -324,8 +333,9 @@ def start_local_activity(self, input: StartLocalActivityInput) -> ActivityHandle input, ) - @offload_to_blob_store + # @offload_to_blob_store async def start_child_workflow(self, input: StartChildWorkflowInput) -> ChildWorkflowHandle: + input.args = [offload_if_large(arg) for arg in input.args] return await handle_execution_with_errors( super().start_child_workflow, input, @@ -352,3 +362,29 @@ def workflow_interceptor_class( Returns the custom workflow interceptor class. """ return CustomWorkflowInterceptor + + +class CustomClientInterceptor(ClientInterceptor): + """ + Custom interceptor for Temporal. + """ + + def intercept_client(self, next: OutboundInterceptor) -> OutboundInterceptor: + return CustomOutboundInterceptor(super().intercept_client(next)) + + +class CustomOutboundInterceptor(OutboundInterceptor): + """ + Custom outbound interceptor for Temporal workflows. + """ + + # @offload_to_blob_store + async def start_workflow(self, input: StartWorkflowInput) -> WorkflowHandle[Any, Any]: + """ + interceptor for outbound workflow calls + """ + input.args = [offload_if_large(arg) for arg in input.args] + return await handle_execution_with_errors( + super().start_workflow, + input, + )