Skip to content

Commit

Permalink
Fix: Fixes empty input in the langchain stream calls (#538) (#545)
Browse files Browse the repository at this point in the history
Co-authored-by: Noble Varghese <noblekvarghese96@gmail.com>
  • Loading branch information
hassiebp and noble-varghese authored Apr 8, 2024
1 parent a95e071 commit 3f795cd
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 10 deletions.
14 changes: 8 additions & 6 deletions langfuse/callback/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,6 @@ def on_chain_start(
"input": inputs,
"version": self.version,
}

if parent_run_id is None:
if self.root_span is None:
self.runs[run_id] = self.trace.span(**content)
Expand Down Expand Up @@ -326,8 +325,9 @@ def on_chain_end(
self.runs[run_id] = self.runs[run_id].end(
output=outputs, version=self.version
)

self._update_trace_and_remove_state(run_id, parent_run_id, outputs)
self._update_trace_and_remove_state(
run_id, parent_run_id, outputs, input=kwargs.get("inputs")
)
except Exception as e:
self.log.exception(e)

Expand All @@ -350,7 +350,9 @@ def on_chain_error(
version=self.version,
)

self._update_trace_and_remove_state(run_id, parent_run_id, error)
self._update_trace_and_remove_state(
run_id, parent_run_id, error, input=kwargs.get("inputs")
)

except Exception as e:
self.log.exception(e)
Expand Down Expand Up @@ -732,7 +734,7 @@ def _report_error(self, error: dict):
)

def _update_trace_and_remove_state(
self, run_id: str, parent_run_id: Optional[str], output: any
self, run_id: str, parent_run_id: Optional[str], output: any, **kwargs: Any
):
"""Update the trace with the output of the current run. Called at every finish callback event."""
if (
Expand All @@ -742,7 +744,7 @@ def _update_trace_and_remove_state(
and self.trace.id
== str(run_id) # The trace was generated by langchain and not by the user
):
self.trace = self.trace.update(output=output)
self.trace = self.trace.update(output=output, **kwargs)
del self.runs[run_id]

def _convert_message_to_dict(self, message: BaseMessage) -> Dict[str, Any]:
Expand Down
14 changes: 10 additions & 4 deletions tests/test_langchain_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,6 @@ async def test_chains_abatch_in_completions_models(model_name):
trace = api.trace.get(callback.get_trace_id())
generationList = list(filter(lambda o: o.type == "GENERATION", trace.observations))
assert len(generationList) != 0

assert len(trace.observations) == 4
for generation in generationList:
assert trace.name == name
Expand Down Expand Up @@ -627,7 +626,7 @@ async def test_chains_ainvoke_chat_models(model_name):
Introduction: This is an engaging introduction for the blog post on the topic above:"""
)
chain = prompt1 | model | StrOutputParser()
_ = chain.invoke(
res = await chain.ainvoke(
{"topic": "The Impact of Climate Change"},
config={"callbacks": [callback]},
)
Expand All @@ -641,6 +640,8 @@ async def test_chains_ainvoke_chat_models(model_name):

assert len(trace.observations) == 4
assert trace.name == name
assert trace.input == {"topic": "The Impact of Climate Change"}
assert trace.output == res
for generation in generationList:
assert generation.model == model_name
assert generation.input is not None
Expand Down Expand Up @@ -674,7 +675,7 @@ async def test_chains_ainvoke_completions_models(model_name):
Introduction: This is an engaging introduction for the blog post on the topic above:"""
)
chain = prompt1 | model | StrOutputParser()
_ = chain.invoke(
res = await chain.ainvoke(
{"topic": "The Impact of Climate Change"},
config={"callbacks": [callback]},
)
Expand All @@ -687,7 +688,8 @@ async def test_chains_ainvoke_completions_models(model_name):
assert len(generationList) != 0

generation = generationList[0]

assert trace.input == {"topic": "The Impact of Climate Change"}
assert trace.output == res
assert len(trace.observations) == 4
assert trace.name == name
assert generation.model == model_name
Expand Down Expand Up @@ -738,6 +740,8 @@ async def test_chains_astream_chat_models(model_name):

generation = generationList[0]

assert trace.input == {"topic": "The Impact of Climate Change"}
assert trace.output == "".join(response_str)
assert len(response_str) > 1 # To check there are more than one chunk.
assert len(trace.observations) == 4
assert trace.name == name
Expand Down Expand Up @@ -791,6 +795,8 @@ async def test_chains_astream_completions_models(model_name):

generation = generationList[0]

assert trace.input == {"topic": "The Impact of Climate Change"}
assert trace.output == "".join(response_str)
assert len(response_str) > 1 # To check there are more than one chunk.
assert len(trace.observations) == 4
assert trace.name == name
Expand Down

0 comments on commit 3f795cd

Please sign in to comment.