Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Add prompt_name to feature-extraction + update types #2363

Merged
merged 3 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/huggingface_hub/inference/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -940,7 +940,9 @@ def feature_extraction(
text: str,
*,
normalize: Optional[bool] = None,
prompt_name: Optional[str] = None,
truncate: Optional[bool] = None,
truncation_direction: Optional[Literal["Left", "Right"]] = None,
model: Optional[str] = None,
) -> "np.ndarray":
"""
Expand All @@ -956,9 +958,17 @@ def feature_extraction(
normalize (`bool`, *optional*):
Whether to normalize the embeddings or not. Defaults to None.
Only available on server powered by Text-Embedding-Inference.
prompt_name (`str`, *optional*):
The name of the prompt that should be used for encoding. If not set, no prompt will be applied.
Must be a key in the `Sentence Transformers` configuration `prompts` dictionary.
For example, if ``prompt_name`` is "query" and ``prompts`` is {"query": "query: ", ...},
then the sentence "What is the capital of France?" will be encoded as "query: What is the capital of France?"
because the prompt text will be prepended before any text to encode.
truncate (`bool`, *optional*):
Whether to truncate the embeddings or not. Defaults to None.
Only available on server powered by Text-Embedding-Inference.
truncation_direction (`Literal["Left", "Right"]`, *optional*):
Which side of the input should be truncated when `truncate=True` is passed.

Returns:
`np.ndarray`: The embedding representing the input text as a float32 numpy array.
Expand All @@ -983,8 +993,12 @@ def feature_extraction(
payload: Dict = {"inputs": text}
if normalize is not None:
payload["normalize"] = normalize
if prompt_name is not None:
payload["prompt_name"] = prompt_name
if truncate is not None:
payload["truncate"] = truncate
if truncation_direction is not None:
payload["truncation_direction"] = truncation_direction
response = self.post(json=payload, model=model, task="feature-extraction")
np = _import_numpy()
return np.array(_bytes_to_dict(response), dtype="float32")
Expand Down
14 changes: 14 additions & 0 deletions src/huggingface_hub/inference/_generated/_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,7 +944,9 @@ async def feature_extraction(
text: str,
*,
normalize: Optional[bool] = None,
prompt_name: Optional[str] = None,
truncate: Optional[bool] = None,
truncation_direction: Optional[Literal["Left", "Right"]] = None,
model: Optional[str] = None,
) -> "np.ndarray":
"""
Expand All @@ -960,9 +962,17 @@ async def feature_extraction(
normalize (`bool`, *optional*):
Whether to normalize the embeddings or not. Defaults to None.
Only available on server powered by Text-Embedding-Inference.
prompt_name (`str`, *optional*):
The name of the prompt that should be used for encoding. If not set, no prompt will be applied.
Must be a key in the `Sentence Transformers` configuration `prompts` dictionary.
For example, if ``prompt_name`` is "query" and ``prompts`` is {"query": "query: ", ...},
then the sentence "What is the capital of France?" will be encoded as "query: What is the capital of France?"
because the prompt text will be prepended before any text to encode.
truncate (`bool`, *optional*):
Whether to truncate the embeddings or not. Defaults to None.
Only available on server powered by Text-Embedding-Inference.
truncation_direction (`Literal["Left", "Right"]`, *optional*):
Which side of the input should be truncated when `truncate=True` is passed.

Returns:
`np.ndarray`: The embedding representing the input text as a float32 numpy array.
Expand All @@ -988,8 +998,12 @@ async def feature_extraction(
payload: Dict = {"inputs": text}
if normalize is not None:
payload["normalize"] = normalize
if prompt_name is not None:
payload["prompt_name"] = prompt_name
if truncate is not None:
payload["truncate"] = truncate
if truncation_direction is not None:
payload["truncation_direction"] = truncation_direction
response = await self.post(json=payload, model=model, task="feature-extraction")
np = _import_numpy()
return np.array(_bytes_to_dict(response), dtype="float32")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from dataclasses import dataclass
from typing import List, Optional, Union
from typing import Literal, Optional

from .base import BaseInferenceType


FeatureExtractionInputTruncationDirection = Literal["Left", "Right"]


@dataclass
class FeatureExtractionInput(BaseInferenceType):
"""Feature Extraction Input.
Expand All @@ -17,6 +20,18 @@ class FeatureExtractionInput(BaseInferenceType):
https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tei-import.ts.
"""

inputs: Union[List[str], str]
inputs: str
"""The text to embed."""
normalize: Optional[bool] = None
prompt_name: Optional[str] = None
"""The name of the prompt that should be used by for encoding. If not set, no prompt
will be applied.
Must be a key in the `Sentence Transformers` configuration `prompts` dictionary.
For example, if ``prompt_name`` is "query" and ``prompts`` is {"query": "query: ",
...},
then the sentence "What is the capital of France?" will be encoded as
"query: What is the capital of France?" because the prompt text will be prepended before
any text to encode.
"""
truncate: Optional[bool] = None
truncation_direction: Optional["FeatureExtractionInputTruncationDirection"] = None
Loading