Skip to content

Implement reranking for Voyage Models #10521

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions litellm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -953,6 +953,7 @@ def add_known_models():
from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig
from .llms.groq.chat.transformation import GroqChatConfig
from .llms.voyage.embedding.transformation import VoyageEmbeddingConfig
from .llms.voyage.rerank.transformation import VoyageRerankConfig
from .llms.infinity.embedding.transformation import InfinityEmbeddingConfig
from .llms.azure_ai.chat.transformation import AzureAIStudioConfig
from .llms.mistral.mistral_chat_transformation import MistralConfig
Expand Down
26 changes: 26 additions & 0 deletions litellm/llms/voyage/common_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from typing import Union

import httpx


class VoyageError(BaseLLMException):
    """Exception raised when a call to the Voyage AI API fails.

    Stores the status code and message on the instance, builds a synthetic
    ``httpx.Request`` / ``httpx.Response`` pair carrying that status code, and
    forwards everything to ``BaseLLMException``.

    Args:
        status_code: HTTP status code associated with the failure.
        message: Error text (callers in this package pass the raw response body).
        headers: Optional headers to attach to the exception. Defaults to a
            fresh empty dict per instance.
    """

    def __init__(
        self,
        status_code: int,
        message: str,
        headers: Union[dict, httpx.Headers, None] = None,
    ):
        # Avoid a mutable default argument: a literal ``{}`` in the signature
        # would be one dict object shared by every VoyageError instance.
        if headers is None:
            headers = {}

        self.status_code = status_code
        self.message = message
        # NOTE(review): the synthetic request always points at the embeddings
        # endpoint, even when the error originates from another Voyage
        # endpoint (e.g. rerank). Kept as-is to preserve existing behavior.
        self.request = httpx.Request(
            method="POST", url="https://api.voyageai.com/v1/embeddings"
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            status_code=status_code,
            message=message,
            headers=headers,
            request=self.request,
            response=self.response,
        )
21 changes: 1 addition & 20 deletions litellm/llms/voyage/embedding/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,7 @@
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllEmbeddingInputValues, AllMessageValues
from litellm.types.utils import EmbeddingResponse, Usage


class VoyageError(BaseLLMException):
    """Error raised for a failed Voyage AI request.

    Records the status code and message, attaches a synthetic POST
    request/response pair, and then delegates to ``BaseLLMException`` with the
    status code, message and headers.
    """

    def __init__(
        self,
        status_code: int,
        message: str,
        headers: Union[dict, httpx.Headers] = {},
    ):
        # Placeholder request used only to give the response something to reference.
        synthetic_request = httpx.Request(
            url="https://api.voyageai.com/v1/embeddings", method="POST"
        )
        self.status_code, self.message = status_code, message
        self.request = synthetic_request
        self.response = httpx.Response(
            request=synthetic_request, status_code=status_code
        )
        super().__init__(headers=headers, message=message, status_code=status_code)
from litellm.llms.voyage.common_utils import VoyageError


class VoyageEmbeddingConfig(BaseEmbeddingConfig):
Expand Down
155 changes: 155 additions & 0 deletions litellm/llms/voyage/rerank/transformation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
import uuid
from typing import Any, Dict, List, Optional, Union

import httpx

from litellm.secret_managers.main import get_secret_str
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
from litellm.types.rerank import OptionalRerankParams, RerankRequest, RerankResponse
from litellm.llms.voyage.common_utils import VoyageError
from litellm.types.rerank import (
RerankBilledUnits,
RerankResponseDocument,
RerankResponseMeta,
RerankResponseResult,
RerankTokens,
)

class VoyageRerankConfig(BaseRerankConfig):
    """Rerank provider configuration for the Voyage AI ``/v1/rerank`` endpoint.

    Maps the Cohere-style rerank parameters used by the caller onto the
    request body sent to Voyage (``top_n`` is sent as ``top_k``) and converts
    the Voyage response (``data`` / ``usage`` keys) into a ``RerankResponse``.
    """

    def __init__(self) -> None:
        pass

    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
        """Return the full rerank URL.

        Uses the public Voyage endpoint when ``api_base`` is empty; otherwise
        strips trailing slashes and appends ``/v1/rerank`` unless the base
        already ends with it. ``model`` is not used.
        """
        if not api_base:
            return "https://api.voyageai.com/v1/rerank"

        api_base = api_base.rstrip("/")
        if not api_base.endswith("/v1/rerank"):
            api_base = f"{api_base}/v1/rerank"
        return api_base

    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
    ) -> dict:
        """Build the HTTP headers for a Voyage rerank request.

        When ``api_key`` is not supplied it is looked up from the secrets
        ``VOYAGE_API_KEY``, ``VOYAGE_AI_API_KEY`` and ``VOYAGE_AI_TOKEN`` in
        that order. The incoming ``headers`` and ``model`` are not used.
        """
        if api_key is None:
            api_key = (
                get_secret_str("VOYAGE_API_KEY")
                or get_secret_str("VOYAGE_AI_API_KEY")
                or get_secret_str("VOYAGE_AI_TOKEN")
            )
        # NOTE(review): if no key is found anywhere, this still returns
        # "Bearer None" rather than raising; left unchanged so callers that
        # rely on it (e.g. fully mocked requests) keep working.
        return {
            "Authorization": f"Bearer {api_key}",
            "content-type": "application/json",
        }

    def get_supported_cohere_rerank_params(self, model: str) -> list:
        """Return the parameter names forwarded to the Voyage request."""
        return [
            "query",
            "documents",
            "top_k",
            "return_documents",
        ]

    def map_cohere_rerank_params(
        self,
        non_default_params: dict,
        model: str,
        drop_params: bool,
        query: str,
        documents: List[Union[str, Dict[str, Any]]],
        custom_llm_provider: Optional[str] = None,
        top_n: Optional[int] = None,
        rank_fields: Optional[List[str]] = None,
        return_documents: Optional[bool] = True,
        max_chunks_per_doc: Optional[int] = None,
        max_tokens_per_doc: Optional[int] = None,
    ) -> OptionalRerankParams:
        """Map Cohere-style rerank params onto Voyage rerank params.

        Only keys of ``non_default_params`` that appear in
        ``get_supported_cohere_rerank_params`` are kept. When ``top_n`` is
        given it is stored under ``top_k`` (overriding any ``top_k`` already
        present) and ``top_n`` itself is set to ``None``.
        """
        supported_params = self.get_supported_cohere_rerank_params(model)
        optional_params: Dict[str, Any] = {
            k: v for k, v in non_default_params.items() if k in supported_params
        }

        # Voyage API uses top_k instead of top_n.
        if top_n is not None:
            optional_params["top_k"] = top_n
            optional_params["top_n"] = None

        return OptionalRerankParams(**optional_params)

    def transform_rerank_request(
        self,
        model: str,
        optional_rerank_params: OptionalRerankParams,
        headers: dict,
    ) -> dict:
        """Build the JSON-serialisable request body for Voyage.

        Raises:
            ValueError: if ``query`` or ``documents`` is missing from
                ``optional_rerank_params``.
        """
        if "query" not in optional_rerank_params:
            # Previously this message named the wrong provider ("Cohere").
            raise ValueError("query is required for Voyage rerank")
        if "documents" not in optional_rerank_params:
            raise ValueError("documents is required for Voyage rerank")

        rerank_request = RerankRequest(
            model=model,
            query=optional_rerank_params["query"],
            documents=optional_rerank_params["documents"],
            # Voyage API uses top_k instead of top_n.
            top_k=optional_rerank_params.get("top_k", None),
            return_documents=optional_rerank_params.get("return_documents", None),
        )
        # Fields left as None are dropped from the outgoing payload.
        return rerank_request.model_dump(exclude_none=True)

    def transform_rerank_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: RerankResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        request_data: dict = {},
        optional_params: dict = {},
        litellm_params: dict = {},
    ) -> RerankResponse:
        """Convert a Voyage rerank HTTP response into a ``RerankResponse``.

        Reads results from the ``data`` key and token usage from ``usage``.
        ``input_tokens`` is ``usage.prompt_tokens`` (0 when absent) and
        ``output_tokens`` is ``total_tokens - prompt_tokens``. A response with
        no ``data`` entries yields an empty ``results`` list. When the
        response carries no ``id`` a random UUID is generated.

        The ``dict`` defaults in the signature are never mutated here.

        Raises:
            VoyageError: if the response body cannot be parsed as JSON.
        """
        try:
            raw_response_json = raw_response.json()
        except Exception as err:
            raise VoyageError(
                message=raw_response.text, status_code=raw_response.status_code
            ) from err

        # ``or {}`` also covers an explicit ``"usage": null`` in the body.
        usage = raw_response_json.get("usage") or {}
        prompt_tokens = usage.get("prompt_tokens", 0)
        _billed_units = RerankBilledUnits(**usage)
        _tokens = RerankTokens(
            input_tokens=prompt_tokens,
            output_tokens=usage.get("total_tokens", 0) - prompt_tokens,
        )
        rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)

        voyage_results: List[RerankResponseResult] = []
        for result in raw_response_json.get("data") or []:
            _rerank_response = RerankResponseResult(
                index=result.get("index"),
                relevance_score=result.get("relevance_score"),
            )
            if result.get("document"):
                _rerank_response["document"] = RerankResponseDocument(
                    text=result.get("document")
                )
            voyage_results.append(_rerank_response)

        return RerankResponse(
            id=raw_response_json.get("id") or str(uuid.uuid4()),
            results=voyage_results,
            meta=rerank_meta,
        )
26 changes: 25 additions & 1 deletion litellm/rerank_api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def rerank( # noqa: PLR0915
query: str,
documents: List[Union[str, Dict[str, Any]]],
custom_llm_provider: Optional[
Literal["cohere", "together_ai", "azure_ai", "infinity", "litellm_proxy"]
Literal["cohere", "together_ai", "azure_ai", "infinity", "litellm_proxy", "voyage"]
] = None,
top_n: Optional[int] = None,
rank_fields: Optional[List[str]] = None,
Expand Down Expand Up @@ -323,6 +323,30 @@ def rerank( # noqa: PLR0915
logging_obj=litellm_logging_obj,
client=client,
)
elif _custom_llm_provider == "voyage":
api_key = (
dynamic_api_key or optional_params.api_key or litellm.api_key
)
api_base = (
dynamic_api_base
or optional_params.api_base
or litellm.api_base
or get_secret("VOYAGE_API_BASE") # type: ignore
)
response = base_llm_http_handler.rerank(
model=model,
custom_llm_provider=_custom_llm_provider,
provider_config=rerank_provider_config,
optional_rerank_params=optional_rerank_params,
logging_obj=litellm_logging_obj,
timeout=optional_params.timeout,
api_key=dynamic_api_key or optional_params.api_key,
api_base=api_base,
_is_async=_is_async,
headers=headers or litellm.headers or {},
client=client,
model_response=model_response,
)
else:
raise ValueError(f"Unsupported provider: {_custom_llm_provider}")

Expand Down
2 changes: 2 additions & 0 deletions litellm/types/rerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class RerankRequest(BaseModel):
return_documents: Optional[bool] = None
max_chunks_per_doc: Optional[int] = None
max_tokens_per_doc: Optional[int] = None
top_k: Optional[int] = None


class OptionalRerankParams(TypedDict, total=False):
Expand All @@ -29,6 +30,7 @@ class OptionalRerankParams(TypedDict, total=False):
return_documents: Optional[bool]
max_chunks_per_doc: Optional[int]
max_tokens_per_doc: Optional[int]
top_k: Optional[int]


class RerankBilledUnits(TypedDict, total=False):
Expand Down
2 changes: 2 additions & 0 deletions litellm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6377,6 +6377,8 @@ def get_provider_rerank_config(
return litellm.InfinityRerankConfig()
elif litellm.LlmProviders.JINA_AI == provider:
return litellm.JinaAIRerankConfig()
elif litellm.LlmProviders.VOYAGE == provider:
return litellm.VoyageRerankConfig()
return litellm.CohereRerankConfig()

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import json
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock
import pytest

sys.path.insert(
0, os.path.abspath("../..")
) # Adds the directory two levels above the current working directory to the system path


import litellm
from unittest.mock import patch



### Rerank Tests
@pytest.mark.asyncio()
async def test_voyage_ai_rerank():
    """Check the request sent to, and the response parsed from, Voyage rerank.

    The HTTP layer is patched, so no network call is made. The mocked body
    places results under the ``"data"`` key, which is the key
    ``VoyageRerankConfig.transform_rerank_response`` reads, so that the
    result-parsing path is exercised.
    """
    query = "What is the capital of the United States?"
    documents = [
        "Carson City is the capital city of the American state of Nevada.",
        "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
        "Washington, D.C. is the capital of the United States.",
        "Capital punishment has existed in the United States since before it was a country.",
    ]

    mock_response = AsyncMock()

    def return_val():
        return {
            "id": "cmpl-mockid",
            "data": [{"index": 2, "relevance_score": 0.84375}],
            "usage": {"total_tokens": 150},
        }

    # ``json`` is replaced with a plain function so it returns a dict, not a coroutine.
    mock_response.json = return_val
    mock_response.headers = {"key": "value"}
    mock_response.status_code = 200

    expected_payload = {
        "model": "rerank-model",
        "query": query,
        # Voyage API uses top_k instead of top_n
        "top_k": 1,
        "documents": documents,
    }

    with patch(
        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
        return_value=mock_response,
    ) as mock_post:
        response = await litellm.arerank(
            model="voyage/rerank-model",
            query=query,
            documents=documents,
            top_n=1,  # This will be converted to top_k internally
            api_base="https://api.voyageai.ai",
        )

        print("async re rank response: ", response)

        # Assert on the outgoing request
        mock_post.assert_called_once()
        print("call args", mock_post.call_args)
        args_to_api = mock_post.call_args.kwargs["data"]
        _url = mock_post.call_args.kwargs["url"]
        print("Arguments passed to API=", args_to_api)
        print("url = ", _url)
        assert _url == "https://api.voyageai.ai/v1/rerank"

        request_data = json.loads(args_to_api)
        print("request data to voyage ai", json.dumps(request_data, indent=4))
        assert request_data["query"] == expected_payload["query"]
        assert request_data["documents"] == expected_payload["documents"]
        assert request_data["top_k"] == expected_payload["top_k"]
        assert request_data["model"] == expected_payload["model"]
        # top_n must not leak into the Voyage payload; it is sent as top_k.
        assert "top_n" not in request_data

        # Assert on the parsed response
        assert response.id is not None
        assert response.results is not None
        assert len(response.results) == 1
        assert response.results[0]["index"] == 2
        assert response.results[0]["relevance_score"] == 0.84375
        # Voyage reports only total_tokens, so all of it is counted as output.
        assert response.meta["tokens"]["output_tokens"] == 150

Loading