fix: adding trace i/o in langfuse openai integration #532
Merged · 10 commits · Apr 5, 2024
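What this PR does: when the OpenAI integration creates the Langfuse trace itself, it now records the prompt (and metadata) as the trace input and, once the response or stream completes, writes the completion back as the trace output. When the caller supplies an existing trace_id, the generation is treated as nested and the caller's trace input/output are left untouched. A minimal usage sketch of both cases, modeled on the tests in this PR (assumes the v1 OpenAI SDK and Langfuse credentials in the environment):

from langfuse import Langfuse
from langfuse.openai import openai  # drop-in, Langfuse-instrumented OpenAI module

# Case 1: no trace_id -- the integration creates the trace and now sets
# trace.input to the prompt and trace.output to the completion.
completion = openai.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "1 + 1 = "}],
)

# Case 2: an existing trace -- pass trace_id and the generation nests under
# it; the trace's own input/output are not overwritten.
langfuse = Langfuse()
trace = langfuse.trace(name="docs-retrieval", input="My custom input")
completion = openai.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "1 + 1 = "}],
    trace_id=trace.id,
)

openai.flush_langfuse()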
63 changes: 48 additions & 15 deletions langfuse/openai.py
@@ -247,16 +247,6 @@ def _get_langfuse_data_from_kwargs(
     if parent_observation_id is not None and trace_id is None:
         raise ValueError("parent_observation_id requires trace_id to be set")
 
-    if trace_id:
-        langfuse.trace(id=trace_id, session_id=session_id, user_id=user_id, tags=tags)
-    else:
-        trace_id = (
-            decorator_context_trace_id
-            or langfuse.trace(
-                session_id=session_id, user_id=user_id, tags=tags, name=name
-            ).id
-        )
-
     metadata = kwargs.get("metadata", {})
 
     if metadata is not None and not isinstance(metadata, dict):
@@ -271,6 +261,23 @@ def _get_langfuse_data_from_kwargs(
     elif resource.type == "chat":
         prompt = _extract_chat_prompt(kwargs)
 
+    is_nested_trace = False
+    if trace_id:
+        is_nested_trace = True
+        langfuse.trace(id=trace_id, session_id=session_id, user_id=user_id, tags=tags)
+    else:
+        trace_id = (
+            decorator_context_trace_id
+            or langfuse.trace(
+                session_id=session_id,
+                user_id=user_id,
+                tags=tags,
+                name=name,
+                input=prompt,
+                metadata=metadata,
+            ).id
+        )
+
     modelParameters = {
         "temperature": kwargs.get("temperature", 1),
         "max_tokens": kwargs.get("max_tokens", float("inf")),  # casing?
@@ -289,14 +296,15 @@ def _get_langfuse_data_from_kwargs(
         "input": prompt,
         "model_parameters": modelParameters,
         "model": model,
-    }
+    }, is_nested_trace
 
 
 def _get_langfuse_data_from_sync_streaming_response(
     resource: OpenAiDefinition,
     response,
     generation: StatefulGenerationClient,
     langfuse: Langfuse,
+    is_nested_trace,
 ):
     responses = []
     for i in response:
@@ -307,6 +315,10 @@ def _get_langfuse_data_from_sync_streaming_response(
         resource, responses
     )
 
+    # Avoid updating the trace when the trace_id was supplied by the user.
+    if not is_nested_trace:
+        langfuse.trace(id=generation.trace_id, output=completion)
+
     _create_langfuse_update(completion, generation, completion_start_time, model=model)
 
 
@@ -315,6 +327,7 @@ async def _get_langfuse_data_from_async_streaming_response(
     response,
     generation: StatefulGenerationClient,
     langfuse: Langfuse,
+    is_nested_trace,
 ):
     responses = []
     async for i in response:
@@ -325,6 +338,10 @@ async def _get_langfuse_data_from_async_streaming_response(
         resource, responses
     )
 
+    # Avoid updating the trace when the trace_id was supplied by the user.
+    if not is_nested_trace:
+        langfuse.trace(id=generation.trace_id, output=completion)
+
     _create_langfuse_update(completion, generation, completion_start_time, model=model)
 
 
@@ -463,7 +480,7 @@ def _wrap(open_ai_resource: OpenAiDefinition, initialize, wrapped, args, kwargs)
     start_time = _get_timestamp()
     arg_extractor = OpenAiArgsExtractor(*args, **kwargs)
 
-    generation = _get_langfuse_data_from_kwargs(
+    generation, is_nested_trace = _get_langfuse_data_from_kwargs(
         open_ai_resource, new_langfuse, start_time, arg_extractor.get_langfuse_args()
     )
     generation = new_langfuse.generation(**generation)
@@ -472,7 +489,11 @@ def _wrap(open_ai_resource: OpenAiDefinition, initialize, wrapped, args, kwargs)
 
         if _is_streaming_response(openai_response):
             return _get_langfuse_data_from_sync_streaming_response(
-                open_ai_resource, openai_response, generation, new_langfuse
+                open_ai_resource,
+                openai_response,
+                generation,
+                new_langfuse,
+                is_nested_trace,
             )
 
         else:
@@ -484,6 +505,10 @@ def _wrap(open_ai_resource: OpenAiDefinition, initialize, wrapped, args, kwargs)
                 model=model, output=completion, end_time=_get_timestamp(), usage=usage
             )
 
+            # Avoid updating the trace when the trace_id was supplied by the user.
+            if not is_nested_trace:
+                new_langfuse.trace(id=generation.trace_id, output=completion)
+
         return openai_response
     except Exception as ex:
         log.warning(ex)
@@ -505,7 +530,7 @@ async def _wrap_async(
     start_time = _get_timestamp()
     arg_extractor = OpenAiArgsExtractor(*args, **kwargs)
 
-    generation = _get_langfuse_data_from_kwargs(
+    generation, is_nested_trace = _get_langfuse_data_from_kwargs(
Inline review comment (Contributor Author):
Was wondering if we could return an is_nested_trace bool from the existing function, which would make the implementation easier to review and also possibly reduce the touch points in the code. If we change, we change only inside the function and the rest should fall into place.
(A sketch of this pattern follows this file's diff.)
         open_ai_resource, new_langfuse, start_time, arg_extractor.get_langfuse_args()
     )
     generation = new_langfuse.generation(**generation)
@@ -514,7 +539,11 @@ async def _wrap_async(
 
         if _is_streaming_response(openai_response):
             return _get_langfuse_data_from_async_streaming_response(
-                open_ai_resource, openai_response, generation, new_langfuse
+                open_ai_resource,
+                openai_response,
+                generation,
+                new_langfuse,
+                is_nested_trace,
             )
 
         else:
@@ -528,6 +557,10 @@ async def _wrap_async(
                 end_time=_get_timestamp(),
                 usage=usage,
             )
+            # Avoid updating the trace when the trace_id was supplied by the user.
+            if not is_nested_trace:
+                new_langfuse.trace(id=generation.trace_id, output=completion)
+
         return openai_response
     except Exception as ex:
         model = kwargs.get("model", None)
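The refactor the inline comment asks for is what the diff implements: _get_langfuse_data_from_kwargs computes the flag next to the trace-resolution logic and returns it alongside the generation kwargs, so each wrapper only unpacks one extra value. A rough standalone sketch of the pattern (the function name and placeholder trace id are illustrative, not the library's API):

from typing import Any, Dict, Optional, Tuple

def get_generation_data(
    trace_id: Optional[str], prompt: Any
) -> Tuple[Dict[str, Any], bool]:
    # The trace is "nested" exactly when the caller supplied a trace_id;
    # in that case the integration must not overwrite the trace's I/O.
    is_nested_trace = trace_id is not None
    if not is_nested_trace:
        trace_id = "new-trace-id"  # stand-in for langfuse.trace(...).id
    return {"trace_id": trace_id, "input": prompt}, is_nested_trace

generation_kwargs, is_nested_trace = get_generation_data(None, "1 + 1 = ")
assert is_nested_trace is False  # the integration owns this trace's output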
85 changes: 83 additions & 2 deletions tests/test_openai.py
@@ -65,6 +65,11 @@ def test_openai_chat_completion():
     assert "2" in generation.data[0].output["content"]
     assert generation.data[0].output["role"] == "assistant"
 
+    trace = api.trace.get(generation.data[0].trace_id)
+    assert trace.input == [{"role": "user", "content": "1 + 1 = "}]
+    assert trace.output["content"] == completion.choices[0].message.content
+    assert trace.output["role"] == completion.choices[0].message.role
+
 
 def test_openai_chat_completion_stream():
     api = get_api()
@@ -79,8 +84,9 @@ def test_openai_chat_completion_stream():
     )
 
     assert _is_streaming_response(completion)
+    chat_content = ""
     for i in completion:
-        print(i)
+        chat_content += i.choices[0].delta.content or ""
 
     openai.flush_langfuse()
 
@@ -110,6 +116,10 @@ def test_openai_chat_completion_stream():
     assert isinstance(generation.data[0].output, str) is True
     assert generation.data[0].completion_start_time is not None
 
+    trace = api.trace.get(generation.data[0].trace_id)
+    assert trace.input == [{"role": "user", "content": "1 + 1 = "}]
+    assert trace.output == chat_content
+
 
 def test_openai_chat_completion_stream_fail():
     api = get_api()
@@ -156,6 +166,10 @@ def test_openai_chat_completion_stream_fail():
 
     openai.api_key = os.environ["OPENAI_API_KEY"]
 
+    trace = api.trace.get(generation.data[0].trace_id)
+    assert trace.input == [{"role": "user", "content": "1 + 1 = "}]
+    assert trace.output is None
+
 
 def test_openai_chat_completion_with_trace():
     api = get_api()
@@ -372,6 +386,10 @@ def test_openai_completion():
     assert generation.data[0].usage.total is not None
     assert generation.data[0].output == "2\n\n1 + 2 = 3\n\n2 + 3 = "
 
+    trace = api.trace.get(generation.data[0].trace_id)
+    assert trace.input == "1 + 1 = "
+    assert trace.output == completion.choices[0].text
+
 
 def test_openai_completion_stream():
     api = get_api()
@@ -386,8 +404,9 @@ def test_openai_completion_stream():
     )
 
     assert _is_streaming_response(completion)
+    content = ""
     for i in completion:
-        print(i)
+        content += i.choices[0].text
 
     openai.flush_langfuse()
 
@@ -416,6 +435,10 @@ def test_openai_completion_stream():
     assert generation.data[0].output == "2\n\n1 + 2 = 3\n\n2 + 3 = "
     assert generation.data[0].completion_start_time is not None
 
+    trace = api.trace.get(generation.data[0].trace_id)
+    assert trace.input == "1 + 1 = "
+    assert trace.output == content
+
 
 def test_openai_completion_fail():
     api = get_api()
@@ -921,3 +944,61 @@ def test_image_filter_url():
             ],
         }
     ]
+
+
+def test_openai_with_existing_trace_id():
+    langfuse = Langfuse()
+    trace = langfuse.trace(
+        name="docs-retrieval",
+        user_id="user__935d7d1d-8625-4ef4-8651-544613e7bd22",
+        metadata={
+            "email": "user@langfuse.com",
+        },
+        tags=["production"],
+        output="This is a standard output",
+        input="My custom input",
+    )
+
+    langfuse.flush()
+
+    api = get_api()
+    generation_name = create_uuid()
+    completion = chat_func(
+        name=generation_name,
+        model="gpt-3.5-turbo",
+        messages=[{"role": "user", "content": "1 + 1 = "}],
+        temperature=0,
+        metadata={"someKey": "someResponse"},
+        trace_id=trace.id,
+    )
+
+    openai.flush_langfuse()
+
+    generation = api.observations.get_many(name=generation_name, type="GENERATION")
+
+    assert len(generation.data) != 0
+    assert generation.data[0].name == generation_name
+    assert generation.data[0].metadata == {"someKey": "someResponse"}
+    assert len(completion.choices) != 0
+    assert generation.data[0].input == [{"content": "1 + 1 = ", "role": "user"}]
+    assert generation.data[0].type == "GENERATION"
+    assert generation.data[0].model == "gpt-3.5-turbo-0125"
+    assert generation.data[0].start_time is not None
+    assert generation.data[0].end_time is not None
+    assert generation.data[0].start_time < generation.data[0].end_time
+    assert generation.data[0].model_parameters == {
+        "temperature": 0,
+        "top_p": 1,
+        "frequency_penalty": 0,
+        "max_tokens": "inf",
+        "presence_penalty": 0,
+    }
+    assert generation.data[0].usage.input is not None
+    assert generation.data[0].usage.output is not None
+    assert generation.data[0].usage.total is not None
+    assert "2" in generation.data[0].output["content"]
+    assert generation.data[0].output["role"] == "assistant"
+
+    trace = api.trace.get(generation.data[0].trace_id)
+    assert trace.output == "This is a standard output"
+    assert trace.input == "My custom input"