Skip to content

Commit

Permalink
feat: added x_headers to response_metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikelarg committed Jan 31, 2025
1 parent 6676b2b commit 4c84e4f
Show file tree
Hide file tree
Showing 6 changed files with 728 additions and 450 deletions.
33 changes: 26 additions & 7 deletions libs/gigachat/langchain_gigachat/chat_models/gigachat.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,9 @@ def _create_chat_result(self, response: Any) -> ChatResult:
generations = []
for res in response.choices:
message = _convert_dict_to_message(res.message)
x_headers = response.x_headers if response.x_headers else {}
if "x-request-id" in x_headers:
message.id = response.x_headers["x-request-id"]
finish_reason = res.finish_reason
self._check_finish_reason(finish_reason)
gen = ChatGeneration(
Expand All @@ -410,6 +413,7 @@ def _create_chat_result(self, response: Any) -> ChatResult:
llm_output = {
"token_usage": response.usage.dict(),
"model_name": response.model,
"x_headers": x_headers,
}
return ChatResult(generations=generations, llm_output=llm_output)

Expand Down Expand Up @@ -477,6 +481,7 @@ def _stream(
payload = self._build_payload(messages, **kwargs)
message_content = ""

first_chunk = True
for chunk_d in self._client.stream(payload):
chunk = {}
if not isinstance(chunk_d, dict):
Expand All @@ -492,14 +497,20 @@ def _stream(
if trim_content_to_stop_sequence(message_content, stop):
return
chunk_m = _convert_delta_to_message_chunk(choice["delta"], AIMessageChunk)
x_headers = chunk.get("x_headers")
x_headers = x_headers if isinstance(x_headers, dict) else {}
if "x-request-id" in x_headers:
chunk_m.id = x_headers["x-request-id"]

finish_reason = choice.get("finish_reason")
self._check_finish_reason(finish_reason)

generation_info = (
dict(finish_reason=finish_reason) if finish_reason is not None else None
)

generation_info = {}
if finish_reason:
generation_info["finish_reason"] = finish_reason
if first_chunk:
generation_info["x_headers"] = x_headers
first_chunk = False
if run_manager:
run_manager.on_llm_new_token(content)

Expand All @@ -514,6 +525,7 @@ async def _astream(
) -> AsyncIterator[ChatGenerationChunk]:
payload = self._build_payload(messages, **kwargs)
message_content = ""
first_chunk = True

async for chunk_d in self._client.astream(payload):
chunk = {}
Expand All @@ -530,13 +542,20 @@ async def _astream(
if trim_content_to_stop_sequence(message_content, stop):
return
chunk_m = _convert_delta_to_message_chunk(choice["delta"], AIMessageChunk)
x_headers = chunk.get("x_headers")
x_headers = x_headers if isinstance(x_headers, dict) else {}
if isinstance(x_headers, dict) and "x-request-id" in x_headers:
chunk_m.id = x_headers["x-request-id"]

finish_reason = choice.get("finish_reason")
self._check_finish_reason(finish_reason)

generation_info = (
dict(finish_reason=finish_reason) if finish_reason is not None else None
)
generation_info = {}
if finish_reason:
generation_info["finish_reason"] = finish_reason
if first_chunk:
generation_info["x_headers"] = x_headers
first_chunk = False
if run_manager:
await run_manager.on_llm_new_token(content)

Expand Down
6 changes: 3 additions & 3 deletions libs/gigachat/langchain_gigachat/tools/giga_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
FewShotExamples = Optional[List[Dict[str, Any]]]


class GigaBaseTool(BaseTool):
class GigaBaseTool(BaseTool): # type: ignore[override]
"""Interface of GigaChat tools with additional properties, that GigaChat supports"""

return_schema: Annotated[Optional[TypeBaseModel], SkipValidation()] = None
Expand All @@ -42,7 +42,7 @@ class GigaBaseTool(BaseTool):
"""Few-shot examples to help the model understand how to use the tool."""


class GigaTool(GigaBaseTool, Tool):
class GigaTool(GigaBaseTool, Tool): # type: ignore[override]
pass


Expand Down Expand Up @@ -72,7 +72,7 @@ def _filter_schema_args(func: Callable) -> list[str]:
return filter_args


class GigaStructuredTool(GigaBaseTool, StructuredTool):
class GigaStructuredTool(GigaBaseTool, StructuredTool): # type: ignore[override]
@classmethod
def from_function(
cls,
Expand Down
Loading

0 comments on commit 4c84e4f

Please # to comment.