Skip to content

Commit 068a445

Browse files
committed
feat: add async openai tracer
1 parent 1fbc1c4 commit 068a445

File tree

3 files changed

+310
-10
lines changed

3 files changed

+310
-10
lines changed

src/openlayer/lib/__init__.py

+12
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,18 @@ def trace_openai(client):
3939
return openai_tracer.trace_openai(client)
4040

4141

def trace_async_openai(client):
    """Trace chat completions made with an async OpenAI client.

    Parameters
    ----------
    client : Union[openai.AsyncOpenAI, openai.AsyncAzureOpenAI]
        The async client whose chat completions should be traced.

    Returns
    -------
    The same client, patched by ``async_openai_tracer.trace_async_openai``.

    Raises
    ------
    ValueError
        If ``client`` is not an ``AsyncOpenAI`` or ``AsyncAzureOpenAI`` instance.
    """
    # pylint: disable=import-outside-toplevel
    import openai

    from .integrations import async_openai_tracer

    if not isinstance(client, (openai.AsyncOpenAI, openai.AsyncAzureOpenAI)):
        # The error message must mention the *async* client types: the sync
        # OpenAI client is rejected by this check and is handled by
        # `trace_openai` instead.
        raise ValueError(
            "Invalid client. Please provide an AsyncOpenAI or AsyncAzureOpenAI client."
        )
    return async_openai_tracer.trace_async_openai(client)
4254
def trace_openai_assistant_thread_run(client, run):
4355
"""Trace OpenAI Assistant thread run."""
4456
# pylint: disable=import-outside-toplevel
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
"""Module with methods used to trace async OpenAI / Azure OpenAI LLMs."""
2+
3+
import json
4+
import logging
5+
import time
6+
from functools import wraps
7+
from typing import Any, Dict, Iterator, Optional, Union
8+
9+
import openai
10+
11+
from .openai_tracer import (
12+
get_model_parameters,
13+
create_trace_args,
14+
add_to_trace,
15+
parse_non_streaming_output_data,
16+
)
17+
18+
logger = logging.getLogger(__name__)
19+
20+
21+
def trace_async_openai(
    client: Union[openai.AsyncOpenAI, openai.AsyncAzureOpenAI],
) -> Union[openai.AsyncOpenAI, openai.AsyncAzureOpenAI]:
    """Wrap an AsyncOpenAI / AsyncAzureOpenAI client so chat completions are traced.

    Every call to ``chat.completions.create`` records:
    - start_time / end_time / latency of the request,
    - tokens, prompt_tokens, completion_tokens,
    - the model used and its configured parameters,
    - the raw model output and the inputs that produced it,
    - extra metadata (e.g. time to first token when streaming).

    Parameters
    ----------
    client : Union[openai.AsyncOpenAI, openai.AsyncAzureOpenAI]
        The async client whose ``chat.completions.create`` gets patched.

    Returns
    -------
    Union[openai.AsyncOpenAI, openai.AsyncAzureOpenAI]
        The same client instance, patched in place.
    """
    azure = isinstance(client, openai.AsyncAzureOpenAI)
    original_create = client.chat.completions.create

    @wraps(original_create)
    async def traced_create_func(*args, **kwargs):
        # `inference_id` is an Openlayer-only extra; pop it so the OpenAI
        # SDK never sees it.
        inference_id = kwargs.pop("inference_id", None)

        # Dispatch on the `stream` flag; both handlers receive the original,
        # unpatched create method.
        handler = (
            handle_async_streaming_create
            if kwargs.get("stream", False)
            else handle_async_non_streaming_create
        )
        return await handler(
            *args,
            **kwargs,
            create_func=original_create,
            inference_id=inference_id,
            is_azure_openai=azure,
        )

    client.chat.completions.create = traced_create_func
    return client
78+
async def handle_async_streaming_create(
    create_func: callable,
    *args,
    is_azure_openai: bool = False,
    inference_id: Optional[str] = None,
    **kwargs,
) -> Iterator[Any]:
    """Handles the create method when streaming is enabled.

    Parameters
    ----------
    create_func : callable
        The create method to handle.
    is_azure_openai : bool, optional
        Whether the client is an Azure OpenAI client, by default False
    inference_id : Optional[str], optional
        A user-generated inference id, by default None

    Returns
    -------
    Iterator[Any]
        An async generator that yields the chunks of the completion.
    """
    chunks = await create_func(*args, **kwargs)
    # BUG FIX: `stream_async_chunks` is an async *generator* function.
    # Calling it returns the async generator synchronously; it must NOT be
    # awaited (awaiting an async generator raises TypeError at runtime).
    return stream_async_chunks(
        chunks=chunks,
        kwargs=kwargs,
        inference_id=inference_id,
        is_azure_openai=is_azure_openai,
    )
109+
110+
async def stream_async_chunks(
    chunks: Iterator[Any],
    kwargs: Dict[str, Any],
    is_azure_openai: bool = False,
    inference_id: Optional[str] = None,
):
    """Streams the chunks of the completion and traces the completion.

    Re-yields every chunk to the caller unchanged; once the stream ends (or
    fails), the accumulated completion is sent to the Openlayer trace from
    the ``finally`` block, so tracing happens even on error.
    """
    collected_output_data = []
    # Accumulates a (legacy) function call or the first tool call, streamed
    # piecewise across chunks.
    collected_function_call = {
        "name": "",
        "arguments": "",
    }
    raw_outputs = []
    start_time = time.time()
    end_time = None
    first_token_time = None
    num_of_completion_tokens = None
    latency = None
    try:
        i = 0
        async for chunk in chunks:
            raw_outputs.append(chunk.model_dump())
            if i == 0:
                first_token_time = time.time()
            if i > 0:
                # Approximates completion tokens as one per chunk.
                # NOTE(review): stays None for single-chunk streams — confirm
                # that is intended.
                num_of_completion_tokens = i + 1
            i += 1

            delta = chunk.choices[0].delta

            if delta.content:
                # Plain text content.
                collected_output_data.append(delta.content)
            elif delta.function_call:
                # Legacy function-calling API: name/arguments arrive in pieces.
                if delta.function_call.name:
                    collected_function_call["name"] += delta.function_call.name
                if delta.function_call.arguments:
                    collected_function_call["arguments"] += (
                        delta.function_call.arguments
                    )
            elif delta.tool_calls:
                # Tool-calling API; only the first tool call is collected.
                if delta.tool_calls[0].function.name:
                    collected_function_call["name"] += delta.tool_calls[0].function.name
                if delta.tool_calls[0].function.arguments:
                    collected_function_call["arguments"] += delta.tool_calls[
                        0
                    ].function.arguments

            # Pass the chunk through to the consumer untouched.
            yield chunk
        end_time = time.time()
        latency = (end_time - start_time) * 1000
    # pylint: disable=broad-except
    except Exception as e:
        logger.error("Failed yield chunk. %s", e)
    finally:
        # Try to add step to the trace; tracing is best-effort and must never
        # raise into the consumer of the stream.
        try:
            collected_output_data = [
                message for message in collected_output_data if message is not None
            ]
            if collected_output_data:
                output_data = "".join(collected_output_data)
            else:
                # No text content: assume a function/tool call was streamed and
                # parse its accumulated JSON arguments.
                collected_function_call["arguments"] = json.loads(
                    collected_function_call["arguments"]
                )
                output_data = collected_function_call

            trace_args = create_trace_args(
                end_time=end_time,
                inputs={"prompt": kwargs["messages"]},
                output=output_data,
                latency=latency,
                # prompt_tokens are unknown when streaming, hence 0.
                tokens=num_of_completion_tokens,
                prompt_tokens=0,
                completion_tokens=num_of_completion_tokens,
                model=kwargs.get("model"),
                model_parameters=get_model_parameters(kwargs),
                raw_output=raw_outputs,
                id=inference_id,
                metadata={
                    # Milliseconds from the request until the first chunk.
                    "timeToFirstToken": (
                        (first_token_time - start_time) * 1000
                        if first_token_time
                        else None
                    )
                },
            )
            add_to_trace(
                **trace_args,
                is_azure_openai=is_azure_openai,
            )

        # pylint: disable=broad-except
        except Exception as e:
            logger.error(
                "Failed to trace the create chat completion request with Openlayer. %s",
                e,
            )
209+
210+
async def handle_async_non_streaming_create(
    create_func: callable,
    *args,
    is_azure_openai: bool = False,
    inference_id: Optional[str] = None,
    **kwargs,
) -> "openai.types.chat.chat_completion.ChatCompletion":
    """Awaits a non-streaming chat completion and records it in the trace.

    Parameters
    ----------
    create_func : callable
        The create method to handle.
    is_azure_openai : bool, optional
        Whether the client is an Azure OpenAI client, by default False
    inference_id : Optional[str], optional
        A user-generated inference id, by default None

    Returns
    -------
    openai.types.chat.chat_completion.ChatCompletion
        The chat completion response, returned to the caller unchanged.
    """
    requested_at = time.time()
    response = await create_func(*args, **kwargs)
    completed_at = time.time()

    # Tracing is best-effort: a failure here must never break the caller.
    try:
        add_to_trace(
            is_azure_openai=is_azure_openai,
            **create_trace_args(
                end_time=completed_at,
                inputs={"prompt": kwargs["messages"]},
                output=parse_non_streaming_output_data(response),
                latency=(completed_at - requested_at) * 1000,
                tokens=response.usage.total_tokens,
                prompt_tokens=response.usage.prompt_tokens,
                completion_tokens=response.usage.completion_tokens,
                model=response.model,
                model_parameters=get_model_parameters(kwargs),
                raw_output=response.model_dump(),
                id=inference_id,
            ),
        )
    # pylint: disable=broad-except
    except Exception as e:
        logger.error(
            "Failed to trace the create chat completion request with Openlayer. %s", e
        )

    return response

src/openlayer/lib/integrations/openai_tracer.py

+34-10
Original file line numberDiff line numberDiff line change
@@ -137,12 +137,16 @@ def stream_chunks(
137137
if delta.function_call.name:
138138
collected_function_call["name"] += delta.function_call.name
139139
if delta.function_call.arguments:
140-
collected_function_call["arguments"] += delta.function_call.arguments
140+
collected_function_call["arguments"] += (
141+
delta.function_call.arguments
142+
)
141143
elif delta.tool_calls:
142144
if delta.tool_calls[0].function.name:
143145
collected_function_call["name"] += delta.tool_calls[0].function.name
144146
if delta.tool_calls[0].function.arguments:
145-
collected_function_call["arguments"] += delta.tool_calls[0].function.arguments
147+
collected_function_call["arguments"] += delta.tool_calls[
148+
0
149+
].function.arguments
146150

147151
yield chunk
148152
end_time = time.time()
@@ -153,11 +157,15 @@ def stream_chunks(
153157
finally:
154158
# Try to add step to the trace
155159
try:
156-
collected_output_data = [message for message in collected_output_data if message is not None]
160+
collected_output_data = [
161+
message for message in collected_output_data if message is not None
162+
]
157163
if collected_output_data:
158164
output_data = "".join(collected_output_data)
159165
else:
160-
collected_function_call["arguments"] = json.loads(collected_function_call["arguments"])
166+
collected_function_call["arguments"] = json.loads(
167+
collected_function_call["arguments"]
168+
)
161169
output_data = collected_function_call
162170

163171
trace_args = create_trace_args(
@@ -172,7 +180,13 @@ def stream_chunks(
172180
model_parameters=get_model_parameters(kwargs),
173181
raw_output=raw_outputs,
174182
id=inference_id,
175-
metadata={"timeToFirstToken": ((first_token_time - start_time) * 1000 if first_token_time else None)},
183+
metadata={
184+
"timeToFirstToken": (
185+
(first_token_time - start_time) * 1000
186+
if first_token_time
187+
else None
188+
)
189+
},
176190
)
177191
add_to_trace(
178192
**trace_args,
@@ -240,8 +254,12 @@ def create_trace_args(
240254
def add_to_trace(is_azure_openai: bool = False, **kwargs) -> None:
    """Add a chat completion step to the trace.

    Parameters
    ----------
    is_azure_openai : bool, optional
        Whether the completion came from an Azure OpenAI client; selects the
        step name/provider recorded in the trace, by default False.
    """
    if is_azure_openai:
        tracer.add_chat_completion_step_to_trace(
            **kwargs, name="Azure OpenAI Chat Completion", provider="Azure"
        )
        # BUG FIX: without this return, Azure completions fell through and
        # were added to the trace a second time as an "OpenAI" step.
        return
    tracer.add_chat_completion_step_to_trace(
        **kwargs, name="OpenAI Chat Completion", provider="OpenAI"
    )

246264

247265
def handle_non_streaming_create(
@@ -294,7 +312,9 @@ def handle_non_streaming_create(
294312
)
295313
# pylint: disable=broad-except
296314
except Exception as e:
297-
logger.error("Failed to trace the create chat completion request with Openlayer. %s", e)
315+
logger.error(
316+
"Failed to trace the create chat completion request with Openlayer. %s", e
317+
)
298318

299319
return response
300320

@@ -336,7 +356,9 @@ def parse_non_streaming_output_data(
336356

337357

338358
# --------------------------- OpenAI Assistants API -------------------------- #
339-
def trace_openai_assistant_thread_run(client: openai.OpenAI, run: "openai.types.beta.threads.run.Run") -> None:
359+
def trace_openai_assistant_thread_run(
360+
client: openai.OpenAI, run: "openai.types.beta.threads.run.Run"
361+
) -> None:
340362
"""Trace a run from an OpenAI assistant.
341363
342364
Once the run is completed, the thread data is published to Openlayer,
@@ -353,7 +375,9 @@ def trace_openai_assistant_thread_run(client: openai.OpenAI, run: "openai.types.
353375
metadata = _extract_run_metadata(run)
354376

355377
# Convert thread to prompt
356-
messages = client.beta.threads.messages.list(thread_id=run.thread_id, order="asc")
378+
messages = client.beta.threads.messages.list(
379+
thread_id=run.thread_id, order="asc"
380+
)
357381
prompt = _thread_messages_to_prompt(messages)
358382

359383
# Add step to the trace

0 commit comments

Comments
 (0)