From 03ce6897d375ecd6e7afb06e7b012258c960b1e4 Mon Sep 17 00:00:00 2001 From: Willy Douhard Date: Sun, 23 Mar 2025 14:49:57 +0100 Subject: [PATCH 1/2] feat: add openai-agents tracing processor --- examples/attachment.py | 2 +- examples/langchain_toolcall.py | 13 +- examples/langchain_variable.py | 5 +- examples/llamaindex.py | 7 +- examples/llamaindex_workflow.py | 13 +- examples/main.py | 2 +- examples/multimodal.py | 7 +- examples/openai_agents.py | 27 ++ literalai/api/__init__.py | 2 +- literalai/api/helpers/attachment_helpers.py | 3 +- literalai/api/helpers/prompt_helpers.py | 6 +- literalai/api/helpers/score_helpers.py | 7 +- literalai/api/helpers/step_helpers.py | 5 +- literalai/api/helpers/thread_helpers.py | 5 +- literalai/api/helpers/user_helpers.py | 5 +- literalai/cache/prompt_helpers.py | 2 +- literalai/callback/langchain_callback.py | 2 + literalai/callback/openai_agents_processor.py | 360 ++++++++++++++++++ literalai/client.py | 13 + literalai/helper.py | 9 +- .../llamaindex/event_handler.py | 37 +- .../llamaindex/span_handler.py | 10 +- literalai/observability/filter.py | 4 +- literalai/observability/generation.py | 8 +- tests/e2e/test_llamaindex.py | 3 +- tests/e2e/test_openai.py | 2 +- tests/unit/test_cache.py | 45 ++- 27 files changed, 497 insertions(+), 107 deletions(-) create mode 100644 examples/openai_agents.py create mode 100644 literalai/callback/openai_agents_processor.py diff --git a/examples/attachment.py b/examples/attachment.py index d0743f64..a3454fb0 100644 --- a/examples/attachment.py +++ b/examples/attachment.py @@ -32,7 +32,7 @@ async def main(): "url": "https://api.github.com/repos/chainlit/chainlit", "mime": "application/json", "metadata": {"test": "test"}, - } + }, ) print(attachment.to_dict()) diff --git a/examples/langchain_toolcall.py b/examples/langchain_toolcall.py index f9fc60fd..c3bd5e49 100644 --- a/examples/langchain_toolcall.py +++ b/examples/langchain_toolcall.py @@ -1,15 +1,12 @@ -from literalai import LiteralClient - -from langchain_openai import ChatOpenAI # type: ignore +from dotenv import load_dotenv +from langchain.agents import AgentExecutor, create_tool_calling_agent +from langchain.agents.agent import BaseSingleActionAgent from langchain_community.tools.tavily_search import TavilySearchResults - -from langchain.agents import create_tool_calling_agent -from langchain.agents import AgentExecutor from langchain_core.messages import AIMessage, HumanMessage from langchain_core.runnables.config import RunnableConfig -from langchain.agents.agent import BaseSingleActionAgent +from langchain_openai import ChatOpenAI # type: ignore -from dotenv import load_dotenv +from literalai import LiteralClient # Add OPENAI_API_KEY and TAVILY_API_KEY for this example. 
load_dotenv() diff --git a/examples/langchain_variable.py b/examples/langchain_variable.py index d0da94cd..8ea7de68 100644 --- a/examples/langchain_variable.py +++ b/examples/langchain_variable.py @@ -1,8 +1,7 @@ +from dotenv import load_dotenv from langchain.chat_models import init_chat_model -from literalai import LiteralClient - -from dotenv import load_dotenv +from literalai import LiteralClient load_dotenv() diff --git a/examples/llamaindex.py b/examples/llamaindex.py index b1f8d88b..bc41d947 100644 --- a/examples/llamaindex.py +++ b/examples/llamaindex.py @@ -1,6 +1,7 @@ -from literalai import LiteralClient -from llama_index.core import Document, VectorStoreIndex from dotenv import load_dotenv +from llama_index.core import Document, VectorStoreIndex + +from literalai import LiteralClient load_dotenv() @@ -14,7 +15,7 @@ questions = [ "Tell me about LLMs", "How do you fine-tune a neural network ?", - "What is RAG ?" + "What is RAG ?", ] # No context, create a Thread (it will be named after the first user query) diff --git a/examples/llamaindex_workflow.py b/examples/llamaindex_workflow.py index c580e40b..e5e55e47 100644 --- a/examples/llamaindex_workflow.py +++ b/examples/llamaindex_workflow.py @@ -1,12 +1,8 @@ import asyncio -from llama_index.core.workflow import ( - Event, - StartEvent, - StopEvent, - Workflow, - step, -) + +from llama_index.core.workflow import Event, StartEvent, StopEvent, Workflow, step from llama_index.llms.openai import OpenAI + from literalai.client import LiteralClient lai_client = LiteralClient() @@ -16,7 +12,8 @@ class JokeEvent(Event): joke: str -class RewriteJoke(Event): + +class RewriteJoke(Event): joke: str diff --git a/examples/main.py b/examples/main.py index ab528f80..2a2a387b 100644 --- a/examples/main.py +++ b/examples/main.py @@ -24,7 +24,7 @@ def get_completion(welcome_message, text): { "role": "system", "content": "Tell an inspiring quote to the user, mentioning their name. Be extremely supportive while " - "keeping it short. Write one sentence per line.", + "keeping it short. 
Write one sentence per line.", }, { "role": "assistant", diff --git a/examples/multimodal.py b/examples/multimodal.py index f8e4a545..c0051306 100644 --- a/examples/multimodal.py +++ b/examples/multimodal.py @@ -1,12 +1,11 @@ import base64 -import requests # type: ignore import time -from literalai import LiteralClient -from openai import OpenAI - +import requests # type: ignore from dotenv import load_dotenv +from openai import OpenAI +from literalai import LiteralClient load_dotenv() diff --git a/examples/openai_agents.py b/examples/openai_agents.py new file mode 100644 index 00000000..3bd802ac --- /dev/null +++ b/examples/openai_agents.py @@ -0,0 +1,27 @@ +import asyncio + +from agents import Agent, Runner, set_trace_processors, trace +from dotenv import load_dotenv + +from literalai import LiteralClient + +load_dotenv() + +client = LiteralClient() + + +async def main(): + agent = Agent(name="Joke generator", instructions="Tell funny jokes.") + + with trace("Joke workflow"): + first_result = await Runner.run(agent, "Tell me a joke") + second_result = await Runner.run( + agent, f"Rate this joke: {first_result.final_output}" + ) + print(f"Joke: {first_result.final_output}") + print(f"Rating: {second_result.final_output}") + + +if __name__ == "__main__": + set_trace_processors([client.openai_agents_tracing_processor()]) + asyncio.run(main()) diff --git a/literalai/api/__init__.py b/literalai/api/__init__.py index 39f00332..46c766d0 100644 --- a/literalai/api/__init__.py +++ b/literalai/api/__init__.py @@ -1,4 +1,4 @@ -from literalai.api.synchronous import LiteralAPI from literalai.api.asynchronous import AsyncLiteralAPI +from literalai.api.synchronous import LiteralAPI __all__ = ["LiteralAPI", "AsyncLiteralAPI"] diff --git a/literalai/api/helpers/attachment_helpers.py b/literalai/api/helpers/attachment_helpers.py index cb02772e..24d20196 100644 --- a/literalai/api/helpers/attachment_helpers.py +++ b/literalai/api/helpers/attachment_helpers.py @@ -1,9 +1,8 @@ import mimetypes from typing import Dict, Optional, TypedDict, Union -from literalai.observability.step import Attachment - from literalai.api.helpers import gql +from literalai.observability.step import Attachment def create_attachment_helper( diff --git a/literalai/api/helpers/prompt_helpers.py b/literalai/api/helpers/prompt_helpers.py index 00a98816..5b4760e3 100644 --- a/literalai/api/helpers/prompt_helpers.py +++ b/literalai/api/helpers/prompt_helpers.py @@ -1,18 +1,16 @@ import logging -from typing import TYPE_CHECKING, Optional, TypedDict, Callable +from typing import TYPE_CHECKING, Callable, Optional, TypedDict +from literalai.cache.prompt_helpers import put_prompt from literalai.observability.generation import GenerationMessage from literalai.prompt_engineering.prompt import Prompt, ProviderSettings -from literalai.cache.prompt_helpers import put_prompt - if TYPE_CHECKING: from literalai.api import LiteralAPI from literalai.cache.shared_cache import SharedCache from literalai.api.helpers import gql - logger = logging.getLogger(__name__) diff --git a/literalai/api/helpers/score_helpers.py b/literalai/api/helpers/score_helpers.py index ee6bd9de..0d993dbe 100644 --- a/literalai/api/helpers/score_helpers.py +++ b/literalai/api/helpers/score_helpers.py @@ -1,11 +1,10 @@ import math from typing import Any, Dict, List, Optional, TypedDict -from literalai.observability.filter import scores_filters, scores_order_by -from literalai.my_types import PaginatedResponse -from literalai.observability.step import ScoreType, ScoreDict, Score - 
from literalai.api.helpers import gql +from literalai.my_types import PaginatedResponse +from literalai.observability.filter import scores_filters, scores_order_by +from literalai.observability.step import Score, ScoreDict, ScoreType def get_scores_helper( diff --git a/literalai/api/helpers/step_helpers.py b/literalai/api/helpers/step_helpers.py index 9d7f8c13..886e3dbd 100644 --- a/literalai/api/helpers/step_helpers.py +++ b/literalai/api/helpers/step_helpers.py @@ -1,11 +1,10 @@ from typing import Any, Dict, List, Optional, Union -from literalai.observability.filter import steps_filters, steps_order_by +from literalai.api.helpers import gql from literalai.my_types import PaginatedResponse +from literalai.observability.filter import steps_filters, steps_order_by from literalai.observability.step import Step, StepDict, StepType -from literalai.api.helpers import gql - def create_step_helper( thread_id: Optional[str] = None, diff --git a/literalai/api/helpers/thread_helpers.py b/literalai/api/helpers/thread_helpers.py index f4c3b591..98e07adc 100644 --- a/literalai/api/helpers/thread_helpers.py +++ b/literalai/api/helpers/thread_helpers.py @@ -1,12 +1,11 @@ from typing import Any, Dict, List, Optional -from literalai.observability.filter import threads_filters, threads_order_by +from literalai.api.helpers import gql from literalai.my_types import PaginatedResponse +from literalai.observability.filter import threads_filters, threads_order_by from literalai.observability.step import StepType from literalai.observability.thread import Thread -from literalai.api.helpers import gql - def get_threads_helper( first: Optional[int] = None, diff --git a/literalai/api/helpers/user_helpers.py b/literalai/api/helpers/user_helpers.py index ad533d2f..bd6ba22e 100644 --- a/literalai/api/helpers/user_helpers.py +++ b/literalai/api/helpers/user_helpers.py @@ -1,9 +1,8 @@ from typing import Any, Dict, Optional -from literalai.observability.filter import users_filters -from literalai.my_types import PaginatedResponse, User - from literalai.api.helpers import gql +from literalai.my_types import PaginatedResponse, User +from literalai.observability.filter import users_filters def get_users_helper( diff --git a/literalai/cache/prompt_helpers.py b/literalai/cache/prompt_helpers.py index 56646f96..d28034b9 100644 --- a/literalai/cache/prompt_helpers.py +++ b/literalai/cache/prompt_helpers.py @@ -1,5 +1,5 @@ -from literalai.prompt_engineering.prompt import Prompt from literalai.cache.shared_cache import SharedCache +from literalai.prompt_engineering.prompt import Prompt def put_prompt(cache: SharedCache, prompt: Prompt): diff --git a/literalai/callback/langchain_callback.py b/literalai/callback/langchain_callback.py index 55897595..119c3406 100644 --- a/literalai/callback/langchain_callback.py +++ b/literalai/callback/langchain_callback.py @@ -193,6 +193,8 @@ def _build_llm_settings( ) model_keys = ["azure_deployment", "deployment_name", "model", "model_name"] model = next((settings[k] for k in model_keys if k in settings), None) + if isinstance(model, str): + model = model.replace("models/", "") tools = None if "functions" in settings: tools = [ diff --git a/literalai/callback/openai_agents_processor.py b/literalai/callback/openai_agents_processor.py new file mode 100644 index 00000000..430fce11 --- /dev/null +++ b/literalai/callback/openai_agents_processor.py @@ -0,0 +1,360 @@ +from __future__ import annotations + +import logging +from importlib.metadata import version +from typing import TYPE_CHECKING, Any, Dict, 
Optional
+
+from literalai.context import active_thread_var
+from literalai.helper import ensure_values_serializable, force_dict
+from literalai.observability.generation import ChatGeneration
+from literalai.observability.step import Step, StepType
+
+if TYPE_CHECKING:
+    from literalai.client import LiteralClient
+
+
+logger = logging.getLogger(__name__)
+
+
+def get_openai_agents_tracing_processor():
+    try:
+        version("openai-agents")
+    except Exception:
+        raise Exception(
+            "Please install openai-agents to use the agents tracing processor. "
+            "You can install it with `pip install openai-agents`."
+        )
+
+    from agents.tracing import Span, Trace, TracingProcessor
+    from agents.tracing.span_data import (
+        AgentSpanData,
+        CustomSpanData,
+        FunctionSpanData,
+        GenerationSpanData,
+        GuardrailSpanData,
+        HandoffSpanData,
+        ResponseSpanData,
+        SpanData,
+    )
+
+    def _get_span_name(obj: Span) -> str:
+        if hasattr(data := obj.span_data, "name") and isinstance(
+            name := data.name, str
+        ):
+            return name
+        if isinstance(obj.span_data, HandoffSpanData) and obj.span_data.to_agent:
+            return f"handoff to {obj.span_data.to_agent}"
+        return obj.span_data.type
+
+    def _get_span_type(obj: SpanData) -> StepType:
+        if isinstance(obj, AgentSpanData):
+            return "run"
+        if isinstance(obj, FunctionSpanData):
+            return "tool"
+        if isinstance(obj, GenerationSpanData):
+            return "llm"
+        if isinstance(obj, ResponseSpanData):
+            return "llm"
+        if isinstance(obj, HandoffSpanData):
+            return "tool"
+        if isinstance(obj, CustomSpanData):
+            return "undefined"
+        if isinstance(obj, GuardrailSpanData):
+            return "undefined"
+        return "undefined"
+
+    def _extract_function_span_data(
+        span_data: FunctionSpanData,
+    ) -> Dict[str, Any]:
+        return {
+            "inputs": force_dict(ensure_values_serializable(span_data.input)),
+            "outputs": force_dict(ensure_values_serializable(span_data.output)),
+        }
+
+    def _extract_generation_span_data(span_data: GenerationSpanData) -> Dict[str, Any]:
+        """Extract data from a generation span."""
+
+        generation = ChatGeneration(
+            provider=getattr(span_data, "provider", "unknown"),
+            model=getattr(span_data, "model", None),
+            settings=getattr(span_data, "model_config", None),
+            token_count=(span_data.usage or {}).get("total_tokens"),
+            input_token_count=(span_data.usage or {}).get("prompt_tokens"),
+            output_token_count=(span_data.usage or {}).get("completion_tokens"),
+            messages=span_data.input,
+        )
+
+        return {"generation": generation}
+
+    def _extract_response_span_data(span_data: ResponseSpanData) -> Dict[str, Any]:
+        """Extract data from a response span."""
+        data: Dict[str, Any] = {}
+
+        generation = ChatGeneration(provider="openai")
+
+        metadata = {}
+
+        if span_data.input is not None:
+            generation.messages = span_data.input
+
+        if span_data.response is not None:
+            metadata["instructions"] = span_data.response.instructions
+            response = span_data.response.model_dump(exclude_none=True, mode="json")
+            output = response.pop("output", [])
+            if output:
+                generation.message_completion = {
+                    "role": "assistant",
+                    "content": span_data.response.output_text,
+                }
+
+            if usage := response.pop("usage", None):
+                if "output_tokens" in usage:
+                    generation.output_token_count = usage.pop("output_tokens")
+                if "input_tokens" in usage:
+                    generation.input_token_count = usage.pop("input_tokens")
+                if "total_tokens" in usage:
+                    generation.token_count = usage.pop("total_tokens")
+
+                metadata["usage"] = usage
+
+            generation.settings = {
+                k: v
+                for k, v in response.items()
+                if k
+                in (
+                    "max_output_tokens",
+                    "model",
+                    "parallel_tool_calls",
+                    "reasoning",
+ "temperature", + "text", + "tool_choice", + "tools", + "top_p", + "truncation", + ) + } + generation.model = generation.settings.get("model") + generation.tools = generation.settings.pop("tools", []) + data["generation"] = generation + data["metadata"] = metadata + + return data + + def _extract_agent_span_data(span_data: AgentSpanData) -> Dict[str, Any]: + """Extract data from an agent span.""" + breakpoint() + return { + "inputs": force_dict( + ensure_values_serializable(getattr(span_data, "input", None)) + ), + "outputs": force_dict( + ensure_values_serializable(getattr(span_data, "output", None)) + ), + "invocation_params": { + "tools": getattr(span_data, "tools", []), + "handoffs": getattr(span_data, "handoffs", []), + }, + "metadata": { + "output_type": getattr(span_data, "output_type", None), + "type": "agent", + }, + } + + def _extract_handoff_span_data(span_data: HandoffSpanData) -> Dict[str, Any]: + """Extract data from a handoff span.""" + return { + "inputs": { + "from_agent": getattr(span_data, "from_agent", None), + "to_agent": getattr(span_data, "to_agent", None), + "content": getattr(span_data, "content", None), + }, + "metadata": { + "type": "handoff", + }, + } + + def _extract_guardrail_span_data(span_data: GuardrailSpanData) -> Dict[str, Any]: + """Extract data from a guardrail span.""" + return { + "inputs": force_dict( + ensure_values_serializable(getattr(span_data, "input", None)) + ), + "outputs": force_dict( + ensure_values_serializable(getattr(span_data, "output", None)) + ), + "metadata": { + "triggered": getattr(span_data, "triggered", False), + "type": "guardrail", + }, + } + + def _extract_custom_span_data(span_data: CustomSpanData) -> Dict[str, Any]: + """Extract data from a custom span.""" + data = {"metadata": {"type": "custom"}} + + if hasattr(span_data, "data"): + if isinstance(span_data.data, dict): + data["metadata"].update(span_data.data) + else: + data["metadata"]["data"] = str(span_data.data) + + return data + + def _extract_span_data(span: Span[Any]) -> Dict[str, Any]: + """Extract data from a span based on its type.""" + data: Dict[str, Any] = {} + + if isinstance(span.span_data, FunctionSpanData): + data.update(_extract_function_span_data(span.span_data)) + elif isinstance(span.span_data, GenerationSpanData): + data.update(_extract_generation_span_data(span.span_data)) + elif isinstance(span.span_data, ResponseSpanData): + data.update(_extract_response_span_data(span.span_data)) + elif isinstance(span.span_data, AgentSpanData): + data.update(_extract_agent_span_data(span.span_data)) + elif isinstance(span.span_data, HandoffSpanData): + data.update(_extract_handoff_span_data(span.span_data)) + elif isinstance(span.span_data, GuardrailSpanData): + data.update(_extract_guardrail_span_data(span.span_data)) + elif isinstance(span.span_data, CustomSpanData): + data.update(_extract_custom_span_data(span.span_data)) + + return data + + class AgentsTracingProcessor(TracingProcessor): + """Processor for sending agent traces to LiteralAI.""" + + def __init__( + self, + literal_client: LiteralClient, + thread_id: Optional[str] = None, + include_metadata: bool = True, + ) -> None: + """ + Initialize the LiteralAI tracing processor. + + Args: + literal_client: The LiteralAI client to use for sending traces. + thread_id: Optional thread ID to associate with the traces. + include_metadata: Whether to include metadata in the traces. 
+ """ + self.client = literal_client + self.thread_id = thread_id + self.include_metadata = include_metadata + self._steps: Dict[str, Step] = {} + self._root_steps: Dict[str, Step] = {} + + def on_trace_start(self, trace: Trace) -> None: + """Called when a trace is started. + + Args: + trace: The trace that started. + """ + thread_id = self.thread_id + if not thread_id and (active_thread := active_thread_var.get()): + thread_id = active_thread.id + + root_step = self.client.start_step( + name=trace.name, + type="run", + thread_id=thread_id, + ) + + trace_dict = trace.export() or {} + metadata = trace_dict.get("metadata") or {} + + root_step.metadata = metadata + + self._root_steps[trace.trace_id] = root_step + + def on_trace_end(self, trace: Trace) -> None: + """Called when a trace is finished. + + Args: + trace: The trace that started. + """ + if root_step := self._root_steps.pop(trace.trace_id, None): + root_step.end() + self.client.event_processor.flush() + + def on_span_start(self, span: Span[Any]) -> None: + """Called when a span is started. + + Args: + span: The span that started. + """ + if not span.started_at: + return + + parent_step = None + if span.parent_id and span.parent_id in self._steps: + parent_step = self._steps[span.parent_id] + elif span.trace_id in self._root_steps: + parent_step = self._root_steps[span.trace_id] + + span_name = _get_span_name(span) + span_type = _get_span_type(span.span_data) + + thread_id = self.thread_id + if not thread_id and (active_thread := active_thread_var.get()): + thread_id = active_thread.id + + step = self.client.start_step( + name=span_name, + type=span_type, # type: ignore + parent_id=parent_step.id if parent_step else None, + thread_id=thread_id, + ) + + self._steps[span.span_id] = step + + def on_span_end(self, span: Span[Any]) -> None: + """Called when a span is finished. Should not block or raise exceptions. + + Args: + span: The span that finished. 
+ """ + if not (step := self._steps.pop(span.span_id, None)): + return + + try: + # Extract all data from the span using our extraction helpers + extracted_data = _extract_span_data(span) + + # Set inputs on the step if available + if "inputs" in extracted_data: + step.input = extracted_data["inputs"] + + # Set outputs on the step if available + if "outputs" in extracted_data: + step.output = extracted_data["outputs"] + + if "generation" in extracted_data: + step.generation = extracted_data["generation"] + + # Add metadata from extracted data + if "metadata" in extracted_data: + if step.metadata is None: + step.metadata = {} + step.metadata.update(extracted_data["metadata"]) + + # Handle errors + if span.error: + step.error = span.error.message + + except Exception as e: + logger.error(f"Error processing span: {e}") + step.error = str(e) + finally: + step.end() + + def force_flush(self) -> None: + """Forces an immediate flush of all queued spans/traces.""" + self.client.event_processor.flush() + + def shutdown(self) -> None: + """Called when the application stops.""" + self.client.flush_and_stop() + + return AgentsTracingProcessor diff --git a/literalai/client.py b/literalai/client.py index 586b41db..3d06ea7d 100644 --- a/literalai/client.py +++ b/literalai/client.py @@ -9,6 +9,9 @@ from literalai.api import AsyncLiteralAPI, LiteralAPI from literalai.callback.langchain_callback import get_langchain_callback +from literalai.callback.openai_agents_processor import ( + get_openai_agents_tracing_processor, +) from literalai.context import active_root_run_var, active_steps_var, active_thread_var from literalai.environment import EnvContextManager, env_decorator from literalai.evaluation.experiment_item_run import ( @@ -161,6 +164,16 @@ def langchain_callback( **kwargs, ) + def openai_agents_tracing_processor( + self, + **kwargs: Any, + ): + tracing_processor = get_openai_agents_tracing_processor() + return tracing_processor( + self.to_sync(), + **kwargs, + ) + def thread( self, original_function=None, diff --git a/literalai/helper.py b/literalai/helper.py index e8af97b2..eca6a49f 100644 --- a/literalai/helper.py +++ b/literalai/helper.py @@ -14,8 +14,7 @@ def ensure_values_serializable(data): if isinstance(data, BaseModel): return filter_none_values(data.model_dump()) - - if isinstance(data, dict): + elif isinstance(data, dict): return {key: ensure_values_serializable(value) for key, value in data.items()} elif isinstance(data, list): return [ensure_values_serializable(item) for item in data] @@ -29,6 +28,12 @@ def ensure_values_serializable(data): return str(data) # Fallback: convert other types to string +def force_dict(data, default_key="content"): + if not isinstance(data, dict): + return {default_key: data} + return data + + def utc_now(): dt = datetime.utcnow() return dt.isoformat() + "Z" diff --git a/literalai/instrumentation/llamaindex/event_handler.py b/literalai/instrumentation/llamaindex/event_handler.py index 29f09b0a..bd953d94 100644 --- a/literalai/instrumentation/llamaindex/event_handler.py +++ b/literalai/instrumentation/llamaindex/event_handler.py @@ -1,47 +1,38 @@ -import uuid import logging +import uuid from typing import TYPE_CHECKING, Dict, List, Optional, Union, cast +from llama_index.core.base.llms.types import ChatMessage, MessageRole +from llama_index.core.base.response.schema import Response, StreamingResponse from llama_index.core.instrumentation.event_handlers import BaseEventHandler from llama_index.core.instrumentation.events import BaseEvent -from pydantic import 
PrivateAttr - -from literalai.instrumentation.llamaindex.span_handler import LiteralSpanHandler -from literalai.context import active_thread_var - from llama_index.core.instrumentation.events.agent import ( - AgentChatWithStepStartEvent, AgentChatWithStepEndEvent, - AgentRunStepStartEvent, + AgentChatWithStepStartEvent, AgentRunStepEndEvent, + AgentRunStepStartEvent, ) from llama_index.core.instrumentation.events.embedding import ( - EmbeddingStartEvent, EmbeddingEndEvent, + EmbeddingStartEvent, ) - -from llama_index.core.instrumentation.events.query import QueryEndEvent, QueryStartEvent -from llama_index.core.instrumentation.events.retrieval import ( - RetrievalEndEvent, - RetrievalStartEvent, -) - -from llama_index.core.base.llms.types import MessageRole, ChatMessage -from llama_index.core.base.response.schema import Response, StreamingResponse - from llama_index.core.instrumentation.events.llm import ( LLMChatEndEvent, LLMChatStartEvent, ) - -from llama_index.core.instrumentation.events.synthesis import ( - SynthesizeEndEvent, +from llama_index.core.instrumentation.events.query import QueryEndEvent, QueryStartEvent +from llama_index.core.instrumentation.events.retrieval import ( + RetrievalEndEvent, + RetrievalStartEvent, ) - +from llama_index.core.instrumentation.events.synthesis import SynthesizeEndEvent from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode from openai.types import CompletionUsage from openai.types.chat import ChatCompletion, ChatCompletionChunk +from pydantic import PrivateAttr +from literalai.context import active_thread_var +from literalai.instrumentation.llamaindex.span_handler import LiteralSpanHandler from literalai.observability.generation import ( ChatGeneration, GenerationMessage, diff --git a/literalai/instrumentation/llamaindex/span_handler.py b/literalai/instrumentation/llamaindex/span_handler.py index 7c5c9f15..3fc7bec3 100644 --- a/literalai/instrumentation/llamaindex/span_handler.py +++ b/literalai/instrumentation/llamaindex/span_handler.py @@ -1,9 +1,11 @@ -from typing_extensions import TypedDict -from llama_index.core.instrumentation.span_handlers.base import BaseSpanHandler -from llama_index.core.instrumentation.span import SimpleSpan +import uuid from typing import Any, Dict, Optional + +from llama_index.core.instrumentation.span import SimpleSpan +from llama_index.core.instrumentation.span_handlers.base import BaseSpanHandler from llama_index.core.query_engine import RetrieverQueryEngine -import uuid +from typing_extensions import TypedDict + from literalai.context import active_thread_var literalai_uuid_namespace = uuid.UUID("05f6b2b5-a912-47bd-958f-98a9c4496322") diff --git a/literalai/observability/filter.py b/literalai/observability/filter.py index 00752bfd..2cea60f4 100644 --- a/literalai/observability/filter.py +++ b/literalai/observability/filter.py @@ -1,7 +1,7 @@ -from typing_extensions import TypedDict - from typing import Any, Generic, List, Literal, Optional, TypeVar, Union +from typing_extensions import TypedDict + Field = TypeVar("Field") Operators = TypeVar("Operators") Value = TypeVar("Value") diff --git a/literalai/observability/generation.py b/literalai/observability/generation.py index ed8c1a1c..acf91485 100644 --- a/literalai/observability/generation.py +++ b/literalai/observability/generation.py @@ -1,13 +1,11 @@ -from enum import unique, Enum -from typing import Optional, Union, List, Dict, Literal +from enum import Enum, unique +from typing import Dict, List, Literal, Optional, Union from pydantic import 
Field from pydantic.dataclasses import dataclass - from typing_extensions import TypedDict -from literalai.my_types import TextContent, ImageUrlContent, Utils - +from literalai.my_types import ImageUrlContent, TextContent, Utils GenerationMessageRole = Literal["user", "assistant", "tool", "function", "system"] diff --git a/tests/e2e/test_llamaindex.py b/tests/e2e/test_llamaindex.py index bee26524..29a70f41 100644 --- a/tests/e2e/test_llamaindex.py +++ b/tests/e2e/test_llamaindex.py @@ -2,11 +2,10 @@ import urllib.parse import pytest +from dotenv import load_dotenv from literalai import LiteralClient -from dotenv import load_dotenv - load_dotenv() diff --git a/tests/e2e/test_openai.py b/tests/e2e/test_openai.py index 92d8c3c7..db374316 100644 --- a/tests/e2e/test_openai.py +++ b/tests/e2e/test_openai.py @@ -7,7 +7,7 @@ from pytest_httpx import HTTPXMock from literalai import LiteralClient -from literalai.observability.generation import CompletionGeneration, ChatGeneration +from literalai.observability.generation import ChatGeneration, CompletionGeneration @pytest.fixture diff --git a/tests/unit/test_cache.py b/tests/unit/test_cache.py index ccf751c9..405e662b 100644 --- a/tests/unit/test_cache.py +++ b/tests/unit/test_cache.py @@ -1,15 +1,16 @@ import pytest -from literalai.prompt_engineering.prompt import Prompt from literalai.api import LiteralAPI -from literalai.cache.shared_cache import SharedCache from literalai.cache.prompt_helpers import put_prompt +from literalai.cache.shared_cache import SharedCache +from literalai.prompt_engineering.prompt import Prompt + def default_prompt(id: str = "1", name: str = "test", version: int = 1) -> Prompt: return Prompt( api=LiteralAPI(), id=id, - name=name, + name=name, version=version, created_at="", updated_at="", @@ -21,73 +22,79 @@ def default_prompt(id: str = "1", name: str = "test", version: int = 1) -> Promp provider="", settings={}, variables=[], - variables_default_values=None + variables_default_values=None, ) + def test_singleton_instance(): """Test that SharedCache maintains singleton pattern""" cache1 = SharedCache() cache2 = SharedCache() assert cache1 is cache2 - + + def test_get_empty_cache(): """Test getting from empty cache returns None""" cache = SharedCache() - cache.clear() - + cache.clear() + assert cache.get_cache() == {} + def test_put_and_get_prompt_by_id_by_name_version_by_name(): """Test storing and retrieving prompt by ID by name-version by name""" cache = SharedCache() cache.clear() - + prompt = default_prompt() put_prompt(cache, prompt) - + retrieved_by_id = cache.get(id="1") assert retrieved_by_id is prompt - + retrieved_by_name_version = cache.get(name="test", version=1) assert retrieved_by_name_version is prompt - + retrieved_by_name = cache.get(name="test") assert retrieved_by_name is prompt + def test_clear_cache(): """Test clearing the cache""" cache = SharedCache() prompt = default_prompt() put_prompt(cache, prompt) - + cache.clear() assert cache.get_cache() == {} + def test_update_existing_prompt(): """Test updating an existing prompt""" cache = SharedCache() cache.clear() - + prompt1 = default_prompt() prompt2 = default_prompt(id="1", version=2) - + cache.put_prompt(prompt1) cache.put_prompt(prompt2) - + retrieved = cache.get(id="1") assert retrieved is prompt2 assert retrieved.version == 2 + def test_error_handling(): """Test error handling for invalid inputs""" cache = SharedCache() cache.clear() - + assert cache.get_cache() == {} assert cache.get(key="") is None - + with pytest.raises(TypeError): 
cache.get(5) # type: ignore - + with pytest.raises(TypeError): - cache.put(5, "test") # type: ignore \ No newline at end of file + cache.put(5, "test") # type: ignore From a7c28bdc32322c0ef8e6c0cba530b2fc6ece0eb3 Mon Sep 17 00:00:00 2001 From: Willy Douhard Date: Sun, 23 Mar 2025 14:58:59 +0100 Subject: [PATCH 2/2] fix: ci --- mypy.ini | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mypy.ini b/mypy.ini index 7e4bc038..e2b2811a 100644 --- a/mypy.ini +++ b/mypy.ini @@ -7,6 +7,9 @@ ignore_missing_imports = True [mypy-chevron.*] ignore_missing_imports = True +[mypy-agents.*] +ignore_missing_imports = True + [mypy-langchain_community.*] ignore_missing_imports = True