diff --git a/src/magentic/chat_model/litellm_chat_model.py b/src/magentic/chat_model/litellm_chat_model.py index aa561ed5..927d588c 100644 --- a/src/magentic/chat_model/litellm_chat_model.py +++ b/src/magentic/chat_model/litellm_chat_model.py @@ -107,6 +107,7 @@ def __init__( model: str, *, api_base: str | None = None, + extra_headers: dict[str, str] | None = None, max_tokens: int | None = None, metadata: dict[str, Any] | None = None, temperature: float | None = None, @@ -114,6 +115,7 @@ def __init__( ): self._model = model self._api_base = api_base + self._extra_headers = extra_headers self._max_tokens = max_tokens self._metadata = metadata self._temperature = temperature @@ -127,6 +129,10 @@ def model(self) -> str: def api_base(self) -> str | None: return self._api_base + @property + def extra_headers(self) -> dict[str, str] | None: + return self._extra_headers + @property def max_tokens(self) -> int | None: return self._max_tokens @@ -176,6 +182,7 @@ def complete( messages=[message_to_openai_message(m) for m in messages], api_base=self.api_base, custom_llm_provider=self.custom_llm_provider, + extra_headers=self.extra_headers, max_tokens=self.max_tokens, metadata=self.metadata, stop=stop, @@ -216,6 +223,7 @@ async def acomplete( messages=[message_to_openai_message(m) for m in messages], api_base=self.api_base, custom_llm_provider=self.custom_llm_provider, + extra_headers=self.extra_headers, max_tokens=self.max_tokens, metadata=self.metadata, stop=stop, diff --git a/tests/chat_model/cassettes/test_litellm_chat_model/test_litellm_chat_model_extra_headers.yaml b/tests/chat_model/cassettes/test_litellm_chat_model/test_litellm_chat_model_extra_headers.yaml new file mode 100644 index 00000000..7f1cacd7 --- /dev/null +++ b/tests/chat_model/cassettes/test_litellm_chat_model/test_litellm_chat_model_extra_headers.yaml @@ -0,0 +1,131 @@ +interactions: +- request: + body: '{"messages": [{"role": "user", "content": "Say hello!"}], "model": "gpt-4o", + "stream": true}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '92' + content-type: + - application/json + host: + - api.openai.com + my-extra-header: + - foo + user-agent: + - OpenAI/Python 1.59.3 + x-stainless-arch: + - arm64 + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - MacOS + x-stainless-package-version: + - 1.59.3 + x-stainless-raw-response: + - 'true' + x-stainless-retry-count: + - '0' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.10.15 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: 'data: {"id":"chatcmpl-B4PbRwwnGeGvQA5rvpj5ujPOkyrWc","object":"chat.completion.chunk","created":1740391429,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_eb9dce56a8","choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":null,"finish_reason":null}]} + + + data: {"id":"chatcmpl-B4PbRwwnGeGvQA5rvpj5ujPOkyrWc","object":"chat.completion.chunk","created":1740391429,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_eb9dce56a8","choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}]} + + + data: {"id":"chatcmpl-B4PbRwwnGeGvQA5rvpj5ujPOkyrWc","object":"chat.completion.chunk","created":1740391429,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_eb9dce56a8","choices":[{"index":0,"delta":{"content":"!"},"logprobs":null,"finish_reason":null}]} + + + data: {"id":"chatcmpl-B4PbRwwnGeGvQA5rvpj5ujPOkyrWc","object":"chat.completion.chunk","created":1740391429,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_eb9dce56a8","choices":[{"index":0,"delta":{"content":" + How"},"logprobs":null,"finish_reason":null}]} + + + data: {"id":"chatcmpl-B4PbRwwnGeGvQA5rvpj5ujPOkyrWc","object":"chat.completion.chunk","created":1740391429,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_eb9dce56a8","choices":[{"index":0,"delta":{"content":" + can"},"logprobs":null,"finish_reason":null}]} + + + data: {"id":"chatcmpl-B4PbRwwnGeGvQA5rvpj5ujPOkyrWc","object":"chat.completion.chunk","created":1740391429,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_eb9dce56a8","choices":[{"index":0,"delta":{"content":" + I"},"logprobs":null,"finish_reason":null}]} + + + data: {"id":"chatcmpl-B4PbRwwnGeGvQA5rvpj5ujPOkyrWc","object":"chat.completion.chunk","created":1740391429,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_eb9dce56a8","choices":[{"index":0,"delta":{"content":" + assist"},"logprobs":null,"finish_reason":null}]} + + + data: {"id":"chatcmpl-B4PbRwwnGeGvQA5rvpj5ujPOkyrWc","object":"chat.completion.chunk","created":1740391429,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_eb9dce56a8","choices":[{"index":0,"delta":{"content":" + you"},"logprobs":null,"finish_reason":null}]} + + + data: {"id":"chatcmpl-B4PbRwwnGeGvQA5rvpj5ujPOkyrWc","object":"chat.completion.chunk","created":1740391429,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_eb9dce56a8","choices":[{"index":0,"delta":{"content":" + today"},"logprobs":null,"finish_reason":null}]} + + + data: {"id":"chatcmpl-B4PbRwwnGeGvQA5rvpj5ujPOkyrWc","object":"chat.completion.chunk","created":1740391429,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_eb9dce56a8","choices":[{"index":0,"delta":{"content":"?"},"logprobs":null,"finish_reason":null}]} + + + data: {"id":"chatcmpl-B4PbRwwnGeGvQA5rvpj5ujPOkyrWc","object":"chat.completion.chunk","created":1740391429,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_eb9dce56a8","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]} + + + data: [DONE] + + + ' + headers: + CF-RAY: + - 916ea0c2c9bdcb94-LAX + Connection: + - keep-alive + Content-Type: + - text/event-stream; charset=utf-8 + Date: + - Mon, 24 Feb 2025 10:03:49 GMT + Server: + - cloudflare + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + cf-cache-status: + - DYNAMIC + openai-processing-ms: + - '325' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-ratelimit-limit-requests: + - '500' + x-ratelimit-limit-tokens: + - '30000' + x-ratelimit-remaining-requests: + - '499' + x-ratelimit-remaining-tokens: + - '29980' + x-ratelimit-reset-requests: + - 120ms + x-ratelimit-reset-tokens: + - 40ms + x-request-id: + - req_c61d69127580a9d164ff2a701a16b787 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/chat_model/test_litellm_chat_model.py b/tests/chat_model/test_litellm_chat_model.py index 34fc2e5f..4c2ce2cf 100644 --- a/tests/chat_model/test_litellm_chat_model.py +++ b/tests/chat_model/test_litellm_chat_model.py @@ -49,6 +49,20 @@ def _add_call_to_list(kwargs, completion_response, start_time, end_time): litellm.success_callback = original_success_callback +@pytest.mark.litellm_ollama +def test_litellm_chat_model_extra_headers(litellm_success_callback_calls): + """Test that provided extra_headers is passed to the litellm success callback.""" + chat_model = LitellmChatModel("gpt-4o", extra_headers={"my-extra-header": "foo"}) + assert chat_model.extra_headers == {"my-extra-header": "foo"} + chat_model.complete(messages=[UserMessage("Say hello!")]) + # There are multiple callback calls due to streaming + # Take the last one because the first is occasionally from another test + callback_call = litellm_success_callback_calls[-1] + assert callback_call["kwargs"]["optional_params"]["extra_headers"] == { + "my-extra-header": "foo" + } + + @pytest.mark.litellm_openai def test_litellm_chat_model_metadata(litellm_success_callback_calls): """Test that provided metadata is passed to the litellm success callback."""