From 882409cd12bf78bedc599748a8cb9631da2fe175 Mon Sep 17 00:00:00 2001 From: Prathamesh Date: Sat, 3 May 2025 14:34:44 +0530 Subject: [PATCH 1/3] Add Voyage Rerank functionality and related utilities - Introduced VoyageRerankConfig class for handling rerank requests to the Voyage API. - Added VoyageError class for error handling specific to Voyage operations. - Updated ProviderConfigManager to include support for the Voyage provider. - Refactored existing code to utilize the new VoyageError class. - Created tests for Voyage rerank functionality to ensure proper request and response handling. - Updated rerank API to support Voyage as a custom LLM provider. --- litellm/__init__.py | 1 + litellm/llms/voyage/common_utils.py | 26 +++ .../llms/voyage/embedding/transformation.py | 21 +-- litellm/llms/voyage/rerank/transformation.py | 155 ++++++++++++++++++ litellm/rerank_api/main.py | 26 ++- litellm/types/rerank.py | 2 + litellm/utils.py | 2 + .../llms/voyage/rerank/test_transformation.py | 80 +++++++++ 8 files changed, 292 insertions(+), 21 deletions(-) create mode 100644 litellm/llms/voyage/common_utils.py create mode 100644 litellm/llms/voyage/rerank/transformation.py create mode 100644 tests/litellm/llms/voyage/rerank/test_transformation.py diff --git a/litellm/__init__.py b/litellm/__init__.py index a015eceb0abb..cacb60621b60 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -953,6 +953,7 @@ def add_known_models(): from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig from .llms.groq.chat.transformation import GroqChatConfig from .llms.voyage.embedding.transformation import VoyageEmbeddingConfig +from .llms.voyage.rerank.transformation import VoyageRerankConfig from .llms.infinity.embedding.transformation import InfinityEmbeddingConfig from .llms.azure_ai.chat.transformation import AzureAIStudioConfig from .llms.mistral.mistral_chat_transformation import MistralConfig diff --git a/litellm/llms/voyage/common_utils.py b/litellm/llms/voyage/common_utils.py new file mode 100644 index 000000000000..87edfbcd8806 --- /dev/null +++ b/litellm/llms/voyage/common_utils.py @@ -0,0 +1,26 @@ +from litellm.llms.base_llm.chat.transformation import BaseLLMException +from typing import Union + +import httpx + + +class VoyageError(BaseLLMException): + def __init__( + self, + status_code: int, + message: str, + headers: Union[dict, httpx.Headers] = {}, + ): + self.status_code = status_code + self.message = message + self.request = httpx.Request( + method="POST", url="https://api.voyageai.com/v1/embeddings" + ) + self.response = httpx.Response(status_code=status_code, request=self.request) + super().__init__( + status_code=status_code, + message=message, + headers=headers, + request=self.request, + response=self.response, + ) \ No newline at end of file diff --git a/litellm/llms/voyage/embedding/transformation.py b/litellm/llms/voyage/embedding/transformation.py index 91811e03927d..f66f5d3c8525 100644 --- a/litellm/llms/voyage/embedding/transformation.py +++ b/litellm/llms/voyage/embedding/transformation.py @@ -8,26 +8,7 @@ from litellm.secret_managers.main import get_secret_str from litellm.types.llms.openai import AllEmbeddingInputValues, AllMessageValues from litellm.types.utils import EmbeddingResponse, Usage - - -class VoyageError(BaseLLMException): - def __init__( - self, - status_code: int, - message: str, - headers: Union[dict, httpx.Headers] = {}, - ): - self.status_code = status_code - self.message = message - self.request = httpx.Request( - method="POST", url="https://api.voyageai.com/v1/embeddings" - ) - self.response = httpx.Response(status_code=status_code, request=self.request) - super().__init__( - status_code=status_code, - message=message, - headers=headers, - ) +from litellm.llms.voyage.common_utils import VoyageError class VoyageEmbeddingConfig(BaseEmbeddingConfig): diff --git a/litellm/llms/voyage/rerank/transformation.py b/litellm/llms/voyage/rerank/transformation.py new file mode 100644 index 000000000000..e6903c9885ea --- /dev/null +++ b/litellm/llms/voyage/rerank/transformation.py @@ -0,0 +1,155 @@ +import uuid +from typing import Any, Dict, List, Optional, Union + +import httpx + +from litellm.secret_managers.main import get_secret_str +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig +from litellm.types.rerank import OptionalRerankParams, RerankRequest, RerankResponse +from litellm.llms.voyage.common_utils import VoyageError +from litellm.types.rerank import ( + RerankBilledUnits, + RerankResponseDocument, + RerankResponseMeta, + RerankResponseResult, + RerankTokens, +) + +class VoyageRerankConfig(BaseRerankConfig): + def __init__(self) -> None: + pass + + def get_complete_url(self, api_base: Optional[str], model: str) -> str: + if api_base: + # Remove trailing slashes and ensure clean base URL + api_base = api_base.rstrip("/") + if not api_base.endswith("/v1/rerank"): + api_base = f"{api_base}/v1/rerank" + return api_base + return "https://api.voyageai.com/v1/rerank" + + def validate_environment( + self, + headers: dict, + model: str, + api_key: Optional[str] = None, + ) -> dict: + if api_key is None: + api_key = ( + get_secret_str("VOYAGE_API_KEY") + or get_secret_str("VOYAGE_AI_API_KEY") + or get_secret_str("VOYAGE_AI_TOKEN") + ) + return { + "Authorization": f"Bearer {api_key}", + "content-type": "application/json", + } + + def get_supported_cohere_rerank_params(self, model: str) -> list: + return [ + "query", + "documents", + "top_k", + "return_documents", + ] + + def map_cohere_rerank_params( + self, + non_default_params: dict, + model: str, + drop_params: bool, + query: str, + documents: List[Union[str, Dict[str, Any]]], + custom_llm_provider: Optional[str] = None, + top_n: Optional[int] = None, + rank_fields: Optional[List[str]] = None, + return_documents: Optional[bool] = True, + max_chunks_per_doc: Optional[int] = None, + max_tokens_per_doc: Optional[int] = None, + ) -> OptionalRerankParams: + """ + Map Voyage rerank params + """ + optional_params = {} + supported_params = self.get_supported_cohere_rerank_params(model) + for k, v in non_default_params.items(): + if k in supported_params: + optional_params[k] = v + + # Voyage API uses top_k instead of top_n + # Assign top_k to top_n if top_n is not None + if top_n is not None: + optional_params["top_k"] = top_n + optional_params["top_n"] = None + + return OptionalRerankParams( + **optional_params, + ) + def transform_rerank_request(self, model: str, optional_rerank_params: OptionalRerankParams, headers: dict) -> dict: + # Transform request to RerankRequest spec + if "query" not in optional_rerank_params: + raise ValueError("query is required for Cohere rerank") + if "documents" not in optional_rerank_params: + raise ValueError("documents is required for Voyage rerank") + rerank_request = RerankRequest( + model=model, + query=optional_rerank_params["query"], + documents=optional_rerank_params["documents"], + # Voyage API uses top_k instead of top_n + top_k=optional_rerank_params.get("top_k", None), + return_documents=optional_rerank_params.get("return_documents", None), + ) + return rerank_request.model_dump(exclude_none=True) + + def transform_rerank_response( + self, + model: str, + raw_response: httpx.Response, + model_response: RerankResponse, + logging_obj: LiteLLMLoggingObj, + api_key: Optional[str] = None, + request_data: dict = {}, + optional_params: dict = {}, + litellm_params: dict = {}, + ) -> RerankResponse: + """ + Transform Voyage rerank response + No transformation required, litellm follows Voyage API response format + """ + try: + raw_response_json = raw_response.json() + except Exception: + raise VoyageError( + message=raw_response.text, status_code=raw_response.status_code + ) + _billed_units = RerankBilledUnits(**raw_response_json.get("usage", {})) + _tokens = RerankTokens( + input_tokens=raw_response_json.get("usage", {}).get("prompt_tokens", 0), + output_tokens=( + raw_response_json.get("usage", {}).get("total_tokens", 0) + - raw_response_json.get("usage", {}).get("prompt_tokens", 0) + ), + ) + rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens) + + voyage_results: List[RerankResponseResult] = [] + if raw_response_json.get("data"): + for result in raw_response_json.get("data"): + _rerank_response = RerankResponseResult( + index=result.get("index"), + relevance_score=result.get("relevance_score"), + ) + if result.get("document"): + _rerank_response["document"] = RerankResponseDocument( + text=result.get("document") + ) + voyage_results.append(_rerank_response) + if voyage_results is None: + raise ValueError(f"No results found in the response={raw_response_json}") + + return RerankResponse( + id=raw_response_json.get("id") or str(uuid.uuid4()), + results=voyage_results, + meta=rerank_meta, + ) # Return response \ No newline at end of file diff --git a/litellm/rerank_api/main.py b/litellm/rerank_api/main.py index 9307ce5a5500..3e26de60d5d7 100644 --- a/litellm/rerank_api/main.py +++ b/litellm/rerank_api/main.py @@ -75,7 +75,7 @@ def rerank( # noqa: PLR0915 query: str, documents: List[Union[str, Dict[str, Any]]], custom_llm_provider: Optional[ - Literal["cohere", "together_ai", "azure_ai", "infinity", "litellm_proxy"] + Literal["cohere", "together_ai", "azure_ai", "infinity", "litellm_proxy", "voyage"] ] = None, top_n: Optional[int] = None, rank_fields: Optional[List[str]] = None, @@ -323,6 +323,30 @@ def rerank( # noqa: PLR0915 logging_obj=litellm_logging_obj, client=client, ) + elif _custom_llm_provider == "voyage": + api_key = ( + dynamic_api_key or optional_params.api_key or litellm.api_key + ) + api_base = ( + dynamic_api_base + or optional_params.api_base + or litellm.api_base + or get_secret("VOYAGE_API_BASE") # type: ignore + ) + response = base_llm_http_handler.rerank( + model=model, + custom_llm_provider=_custom_llm_provider, + provider_config=rerank_provider_config, + optional_rerank_params=optional_rerank_params, + logging_obj=litellm_logging_obj, + timeout=optional_params.timeout, + api_key=dynamic_api_key or optional_params.api_key, + api_base=api_base, + _is_async=_is_async, + headers=headers or litellm.headers or {}, + client=client, + model_response=model_response, + ) else: raise ValueError(f"Unsupported provider: {_custom_llm_provider}") diff --git a/litellm/types/rerank.py b/litellm/types/rerank.py index fb6dae0d1df9..92fbbbcad43a 100644 --- a/litellm/types/rerank.py +++ b/litellm/types/rerank.py @@ -19,6 +19,7 @@ class RerankRequest(BaseModel): return_documents: Optional[bool] = None max_chunks_per_doc: Optional[int] = None max_tokens_per_doc: Optional[int] = None + top_k: Optional[int] = None class OptionalRerankParams(TypedDict, total=False): @@ -29,6 +30,7 @@ class OptionalRerankParams(TypedDict, total=False): return_documents: Optional[bool] max_chunks_per_doc: Optional[int] max_tokens_per_doc: Optional[int] + top_k: Optional[int] class RerankBilledUnits(TypedDict, total=False): diff --git a/litellm/utils.py b/litellm/utils.py index d586962ba954..92c39026110f 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -6377,6 +6377,8 @@ def get_provider_rerank_config( return litellm.InfinityRerankConfig() elif litellm.LlmProviders.JINA_AI == provider: return litellm.JinaAIRerankConfig() + elif litellm.LlmProviders.VOYAGE == provider: + return litellm.VoyageRerankConfig() return litellm.CohereRerankConfig() @staticmethod diff --git a/tests/litellm/llms/voyage/rerank/test_transformation.py b/tests/litellm/llms/voyage/rerank/test_transformation.py new file mode 100644 index 000000000000..aca3e135f107 --- /dev/null +++ b/tests/litellm/llms/voyage/rerank/test_transformation.py @@ -0,0 +1,80 @@ +import json +import os +import sys +from datetime import datetime +from unittest.mock import AsyncMock +import pytest + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path + + +import litellm +from unittest.mock import patch + + + +### Rerank Tests +@pytest.mark.asyncio() +async def test_voyage_ai_rerank(): + mock_response = AsyncMock() + + def return_val(): + return { + "id": "cmpl-mockid", + "results": [{"index": 2, "relevance_score": 0.84375}], + "usage": {"total_tokens": 150}, + } + + mock_response.json = return_val + mock_response.headers = {"key": "value"} + mock_response.status_code = 200 + + expected_payload = { + "model": "rerank-model", + "query": "What is the capital of the United States?", + # Voyage API uses top_k instead of top_n + "top_k": 1, + "documents": [ + "Carson City is the capital city of the American state of Nevada.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.", + "Washington, D.C. is the capital of the United States.", + "Capital punishment has existed in the United States since before it was a country." + ], + } + + with patch( + "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post", + return_value=mock_response, + ) as mock_post: + response = await litellm.arerank( + model="voyage/rerank-model", + query="What is the capital of the United States?", + documents=["Carson City is the capital city of the American state of Nevada.", "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.", "Washington, D.C. is the capital of the United States.", "Capital punishment has existed in the United States since before it was a country."], + top_n=1, # This will be converted to top_k internally + api_base="https://api.voyageai.ai" + ) + + print("async re rank response: ", response) + + # Assert + mock_post.assert_called_once() + print("call args", mock_post.call_args) + args_to_api = mock_post.call_args.kwargs["data"] + _url = mock_post.call_args.kwargs["url"] + print("Arguments passed to API=", args_to_api) + print("url = ", _url) + assert _url == "https://api.voyageai.ai/v1/rerank" + + request_data = json.loads(args_to_api) + print("request data to voyage ai", json.dumps(request_data, indent=4)) + assert request_data["query"] == expected_payload["query"] + assert request_data["documents"] == expected_payload["documents"] + assert request_data["top_k"] == expected_payload["top_k"] + assert request_data["model"] == expected_payload["model"] + + assert response.id is not None + assert response.results is not None + assert response.meta["tokens"]["output_tokens"] == 150 + \ No newline at end of file From e650a8bead19e59365ac2accc02e2031de733b22 Mon Sep 17 00:00:00 2001 From: Prathamesh Date: Wed, 7 May 2025 10:54:11 +0530 Subject: [PATCH 2/3] Created a new test file for testing the Voyage rerank transformation. --- .../test_voyage_rerank_transformation.py | 80 +++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 tests/litellm/llms/voyage/rerank/test_voyage_rerank_transformation.py diff --git a/tests/litellm/llms/voyage/rerank/test_voyage_rerank_transformation.py b/tests/litellm/llms/voyage/rerank/test_voyage_rerank_transformation.py new file mode 100644 index 000000000000..aca3e135f107 --- /dev/null +++ b/tests/litellm/llms/voyage/rerank/test_voyage_rerank_transformation.py @@ -0,0 +1,80 @@ +import json +import os +import sys +from datetime import datetime +from unittest.mock import AsyncMock +import pytest + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path + + +import litellm +from unittest.mock import patch + + + +### Rerank Tests +@pytest.mark.asyncio() +async def test_voyage_ai_rerank(): + mock_response = AsyncMock() + + def return_val(): + return { + "id": "cmpl-mockid", + "results": [{"index": 2, "relevance_score": 0.84375}], + "usage": {"total_tokens": 150}, + } + + mock_response.json = return_val + mock_response.headers = {"key": "value"} + mock_response.status_code = 200 + + expected_payload = { + "model": "rerank-model", + "query": "What is the capital of the United States?", + # Voyage API uses top_k instead of top_n + "top_k": 1, + "documents": [ + "Carson City is the capital city of the American state of Nevada.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.", + "Washington, D.C. is the capital of the United States.", + "Capital punishment has existed in the United States since before it was a country." + ], + } + + with patch( + "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post", + return_value=mock_response, + ) as mock_post: + response = await litellm.arerank( + model="voyage/rerank-model", + query="What is the capital of the United States?", + documents=["Carson City is the capital city of the American state of Nevada.", "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.", "Washington, D.C. is the capital of the United States.", "Capital punishment has existed in the United States since before it was a country."], + top_n=1, # This will be converted to top_k internally + api_base="https://api.voyageai.ai" + ) + + print("async re rank response: ", response) + + # Assert + mock_post.assert_called_once() + print("call args", mock_post.call_args) + args_to_api = mock_post.call_args.kwargs["data"] + _url = mock_post.call_args.kwargs["url"] + print("Arguments passed to API=", args_to_api) + print("url = ", _url) + assert _url == "https://api.voyageai.ai/v1/rerank" + + request_data = json.loads(args_to_api) + print("request data to voyage ai", json.dumps(request_data, indent=4)) + assert request_data["query"] == expected_payload["query"] + assert request_data["documents"] == expected_payload["documents"] + assert request_data["top_k"] == expected_payload["top_k"] + assert request_data["model"] == expected_payload["model"] + + assert response.id is not None + assert response.results is not None + assert response.meta["tokens"]["output_tokens"] == 150 + \ No newline at end of file From 9028877a4b52b7f5e1b673dd5c97af396a45e831 Mon Sep 17 00:00:00 2001 From: Prathamesh Date: Wed, 7 May 2025 11:11:04 +0530 Subject: [PATCH 3/3] Removed the old transformation file --- .../llms/voyage/rerank/test_transformation.py | 80 ------------------- 1 file changed, 80 deletions(-) delete mode 100644 tests/litellm/llms/voyage/rerank/test_transformation.py diff --git a/tests/litellm/llms/voyage/rerank/test_transformation.py b/tests/litellm/llms/voyage/rerank/test_transformation.py deleted file mode 100644 index aca3e135f107..000000000000 --- a/tests/litellm/llms/voyage/rerank/test_transformation.py +++ /dev/null @@ -1,80 +0,0 @@ -import json -import os -import sys -from datetime import datetime -from unittest.mock import AsyncMock -import pytest - -sys.path.insert( - 0, os.path.abspath("../..") -) # Adds the parent directory to the system path - - -import litellm -from unittest.mock import patch - - - -### Rerank Tests -@pytest.mark.asyncio() -async def test_voyage_ai_rerank(): - mock_response = AsyncMock() - - def return_val(): - return { - "id": "cmpl-mockid", - "results": [{"index": 2, "relevance_score": 0.84375}], - "usage": {"total_tokens": 150}, - } - - mock_response.json = return_val - mock_response.headers = {"key": "value"} - mock_response.status_code = 200 - - expected_payload = { - "model": "rerank-model", - "query": "What is the capital of the United States?", - # Voyage API uses top_k instead of top_n - "top_k": 1, - "documents": [ - "Carson City is the capital city of the American state of Nevada.", - "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.", - "Washington, D.C. is the capital of the United States.", - "Capital punishment has existed in the United States since before it was a country." - ], - } - - with patch( - "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post", - return_value=mock_response, - ) as mock_post: - response = await litellm.arerank( - model="voyage/rerank-model", - query="What is the capital of the United States?", - documents=["Carson City is the capital city of the American state of Nevada.", "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.", "Washington, D.C. is the capital of the United States.", "Capital punishment has existed in the United States since before it was a country."], - top_n=1, # This will be converted to top_k internally - api_base="https://api.voyageai.ai" - ) - - print("async re rank response: ", response) - - # Assert - mock_post.assert_called_once() - print("call args", mock_post.call_args) - args_to_api = mock_post.call_args.kwargs["data"] - _url = mock_post.call_args.kwargs["url"] - print("Arguments passed to API=", args_to_api) - print("url = ", _url) - assert _url == "https://api.voyageai.ai/v1/rerank" - - request_data = json.loads(args_to_api) - print("request data to voyage ai", json.dumps(request_data, indent=4)) - assert request_data["query"] == expected_payload["query"] - assert request_data["documents"] == expected_payload["documents"] - assert request_data["top_k"] == expected_payload["top_k"] - assert request_data["model"] == expected_payload["model"] - - assert response.id is not None - assert response.results is not None - assert response.meta["tokens"]["output_tokens"] == 150 - \ No newline at end of file