Skip to content

Commit 42adcc8

Browse files
maxdebayserflaviabeo
authored andcommitted
Support Cross encoder models (vllm-project#10400)
Signed-off-by: Max de Bayser <maxdebayser@gmail.com> Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Signed-off-by: Flavia Beo <flavia.beo@ibm.com> Co-authored-by: Flavia Beo <flavia.beo@ibm.com> Signed-off-by: Maxime Fournioux <55544262+mfournioux@users.noreply.github.com>
1 parent 4ac4813 commit 42adcc8

28 files changed

+1370
-62
lines changed

docs/source/serving/openai_compatible_server.md

+142
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,148 @@ We currently support the following OpenAI APIs:
4444
- This enables multi-modal inputs to be passed to embedding models, see [Using VLMs](../models/vlm.rst).
4545
- *Note: You should run `vllm serve` with `--task embedding` to ensure that the model is being run in embedding mode.*
4646

47+
## Score API for Cross Encoder Models
48+
49+
vLLM supports *cross encoders models* at the **/v1/score** endpoint, which is not an OpenAI API standard endpoint. You can find the documentation for these kind of models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html).
50+
51+
A ***Cross Encoder*** takes exactly two sentences / texts as input and either predicts a score or label for this sentence pair. It can for example predict the similarity of the sentence pair on a scale of 0 … 1.
52+
53+
### Example of usage for a pair of a string and a list of texts
54+
55+
In this case, the model will compare the first given text to each of the texts containing the list.
56+
57+
```bash
58+
curl -X 'POST' \
59+
'http://127.0.0.1:8000/v1/score' \
60+
-H 'accept: application/json' \
61+
-H 'Content-Type: application/json' \
62+
-d '{
63+
"model": "BAAI/bge-reranker-v2-m3",
64+
"text_1": "What is the capital of France?",
65+
"text_2": [
66+
"The capital of Brazil is Brasilia.",
67+
"The capital of France is Paris."
68+
]
69+
}'
70+
```
71+
72+
Response:
73+
74+
```bash
75+
{
76+
"id": "score-request-id",
77+
"object": "list",
78+
"created": 693570,
79+
"model": "BAAI/bge-reranker-v2-m3",
80+
"data": [
81+
{
82+
"index": 0,
83+
"object": "score",
84+
"score": [
85+
0.001094818115234375
86+
]
87+
},
88+
{
89+
"index": 1,
90+
"object": "score",
91+
"score": [
92+
1
93+
]
94+
}
95+
],
96+
"usage": {}
97+
}
98+
```
99+
100+
### Example of usage for a pair of two lists of texts
101+
102+
In this case, the model will compare the one by one, making pairs by same index correspondent in each list.
103+
104+
```bash
105+
curl -X 'POST' \
106+
'http://127.0.0.1:8000/v1/score' \
107+
-H 'accept: application/json' \
108+
-H 'Content-Type: application/json' \
109+
-d '{
110+
"model": "BAAI/bge-reranker-v2-m3",
111+
"encoding_format": "float",
112+
"text_1": [
113+
"What is the capital of Brazil?",
114+
"What is the capital of France?"
115+
],
116+
"text_2": [
117+
"The capital of Brazil is Brasilia.",
118+
"The capital of France is Paris."
119+
]
120+
}'
121+
```
122+
123+
Response:
124+
125+
```bash
126+
{
127+
"id": "score-request-id",
128+
"object": "list",
129+
"created": 693447,
130+
"model": "BAAI/bge-reranker-v2-m3",
131+
"data": [
132+
{
133+
"index": 0,
134+
"object": "score",
135+
"score": [
136+
1
137+
]
138+
},
139+
{
140+
"index": 1,
141+
"object": "score",
142+
"score": [
143+
1
144+
]
145+
}
146+
],
147+
"usage": {}
148+
}
149+
```
150+
151+
### Example of usage for a pair of two strings
152+
153+
In this case, the model will compare the strings of texts.
154+
155+
```bash
156+
curl -X 'POST' \
157+
'http://127.0.0.1:8000/v1/score' \
158+
-H 'accept: application/json' \
159+
-H 'Content-Type: application/json' \
160+
-d '{
161+
"model": "BAAI/bge-reranker-v2-m3",
162+
"encoding_format": "float",
163+
"text_1": "What is the capital of France?",
164+
"text_2": "The capital of France is Paris."
165+
}'
166+
```
167+
168+
Response:
169+
170+
```bash
171+
{
172+
"id": "score-request-id",
173+
"object": "list",
174+
"created": 693447,
175+
"model": "BAAI/bge-reranker-v2-m3",
176+
"data": [
177+
{
178+
"index": 0,
179+
"object": "score",
180+
"score": [
181+
1
182+
]
183+
}
184+
],
185+
"usage": {}
186+
}
187+
```
188+
47189
## Extra Parameters
48190

49191
vLLM supports a set of parameters that are not part of the OpenAI API.
+58
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
"""Examples Python client Score for Cross Encoder Models
2+
"""
3+
4+
import argparse
5+
import json
6+
import pprint
7+
8+
import requests
9+
10+
11+
def post_http_request(prompt: json, api_url: str) -> requests.Response:
12+
headers = {"User-Agent": "Test Client"}
13+
response = requests.post(api_url, headers=headers, json=prompt)
14+
return response
15+
16+
17+
if __name__ == "__main__":
18+
parser = argparse.ArgumentParser()
19+
parser.add_argument("--host", type=str, default="localhost")
20+
parser.add_argument("--port", type=int, default=8000)
21+
parser.add_argument("--model", type=str, default="BAAI/bge-reranker-v2-m3")
22+
args = parser.parse_args()
23+
api_url = f"http://{args.host}:{args.port}/v1/score"
24+
25+
model_name = args.model
26+
27+
text_1 = "What is the capital of France?"
28+
text_2 = [
29+
"The capital of Brazil is Brasilia.", "The capital of France is Paris."
30+
]
31+
prompt = {"model": model_name, "text_1": text_1, "text_2": text_2}
32+
score_response = post_http_request(prompt=prompt, api_url=api_url)
33+
print("Prompt for text_1 is string and text_2 is a list:")
34+
pprint.pprint(prompt)
35+
print("Score Response:")
36+
pprint.pprint(score_response.data)
37+
38+
text_1 = [
39+
"What is the capital of Brazil?", "What is the capital of France?"
40+
]
41+
text_2 = [
42+
"The capital of Brazil is Brasilia.", "The capital of France is Paris."
43+
]
44+
prompt = {"model": model_name, "text_1": text_1, "text_2": text_2}
45+
score_response = post_http_request(prompt=prompt, api_url=api_url)
46+
print("Prompt for text_1 and text_2 are lists:")
47+
pprint.pprint(prompt)
48+
print("Score Response:")
49+
pprint.pprint(score_response.data)
50+
51+
text_1 = "What is the capital of Brazil?"
52+
text_2 = "The capital of Brazil is Brasilia."
53+
prompt = {"model": model_name, "text_1": text_1, "text_2": text_2}
54+
score_response = post_http_request(prompt=prompt, api_url=api_url)
55+
print("Prompt for text_1 and text_2 are strings:")
56+
pprint.pprint(prompt)
57+
print("Score Response:")
58+
pprint.pprint(score_response.data)

tests/conftest.py

+20
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,7 @@ def __init__(
265265
model_kwargs: Optional[Dict[str, Any]] = None,
266266
is_embedding_model: bool = False,
267267
is_sentence_transformer: bool = False,
268+
is_cross_encoder: bool = False,
268269
skip_tokenizer_init: bool = False,
269270
auto_cls: Type[_BaseAutoModelClass] = AutoModelForCausalLM,
270271
postprocess_inputs: Callable[..., BatchEncoding] = identity,
@@ -282,6 +283,14 @@ def __init__(
282283
device="cpu",
283284
trust_remote_code=True,
284285
).to(dtype=torch_dtype))
286+
elif is_cross_encoder:
287+
# Lazy init required for AMD CI
288+
from sentence_transformers import CrossEncoder
289+
self.model = CrossEncoder(model_name,
290+
device="cpu",
291+
trust_remote_code=True)
292+
self.model.model = self.wrap_device(self.model.model)\
293+
.to(dtype=torch_dtype)
285294
else:
286295
model_kwargs = model_kwargs if model_kwargs is not None else {}
287296
self.model = self.wrap_device(
@@ -625,6 +634,9 @@ def generate_encoder_decoder_greedy_logprobs_limit(
625634
def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
626635
return self.model.encode(prompts)
627636

637+
def predict(self, prompts: List[List[str]]) -> torch.Tensor:
638+
return self.model.predict(prompts, convert_to_tensor=True)
639+
628640
def __enter__(self):
629641
return self
630642

@@ -898,6 +910,14 @@ def encode(
898910
req_outputs = self.model.encode(inputs)
899911
return [req_output.outputs.embedding for req_output in req_outputs]
900912

913+
def score(
914+
self,
915+
text_1: Union[str, List[str]],
916+
text_2: Union[str, List[str]],
917+
) -> List[List[float]]:
918+
req_outputs = self.model.score(text_1, text_2)
919+
return [req_output.outputs.embedding for req_output in req_outputs]
920+
901921
def __enter__(self):
902922
return self
903923

+93
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import pytest
2+
import requests
3+
4+
from vllm.entrypoints.openai.protocol import ScoreResponse
5+
6+
from ...utils import RemoteOpenAIServer
7+
8+
MODEL_NAME = "BAAI/bge-reranker-v2-m3"
9+
10+
11+
@pytest.fixture(scope="module")
12+
def server():
13+
args = [
14+
"--enforce-eager",
15+
]
16+
17+
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
18+
yield remote_server
19+
20+
21+
@pytest.mark.asyncio
22+
@pytest.mark.parametrize("model_name", [MODEL_NAME])
23+
async def test_text_1_str_text_2_list(server: RemoteOpenAIServer,
24+
model_name: str):
25+
text_1 = "What is the capital of France?"
26+
text_2 = [
27+
"The capital of Brazil is Brasilia.", "The capital of France is Paris."
28+
]
29+
30+
score_response = requests.post(server.url_for("v1/score"),
31+
json={
32+
"model": model_name,
33+
"text_1": text_1,
34+
"text_2": text_2,
35+
})
36+
score_response.raise_for_status()
37+
score = ScoreResponse.model_validate(score_response.json())
38+
39+
assert score.id is not None
40+
assert score.data is not None
41+
assert len(score.data) == 2
42+
assert score.data[0].score[0] <= 0.01
43+
assert score.data[1].score[0] >= 0.9
44+
45+
46+
@pytest.mark.asyncio
47+
@pytest.mark.parametrize("model_name", [MODEL_NAME])
48+
async def test_text_1_list_text_2_list(server: RemoteOpenAIServer,
49+
model_name: str):
50+
text_1 = [
51+
"What is the capital of the United States?",
52+
"What is the capital of France?"
53+
]
54+
text_2 = [
55+
"The capital of Brazil is Brasilia.", "The capital of France is Paris."
56+
]
57+
58+
score_response = requests.post(server.url_for("v1/score"),
59+
json={
60+
"model": model_name,
61+
"text_1": text_1,
62+
"text_2": text_2,
63+
})
64+
score_response.raise_for_status()
65+
score = ScoreResponse.model_validate(score_response.json())
66+
67+
assert score.id is not None
68+
assert score.data is not None
69+
assert len(score.data) == 2
70+
assert score.data[0].score[0] <= 0.01
71+
assert score.data[1].score[0] >= 0.9
72+
73+
74+
@pytest.mark.asyncio
75+
@pytest.mark.parametrize("model_name", [MODEL_NAME])
76+
async def test_text_1_str_text_2_str(server: RemoteOpenAIServer,
77+
model_name: str):
78+
text_1 = "What is the capital of France?"
79+
text_2 = "The capital of France is Paris."
80+
81+
score_response = requests.post(server.url_for("v1/score"),
82+
json={
83+
"model": model_name,
84+
"text_1": text_1,
85+
"text_2": text_2,
86+
})
87+
score_response.raise_for_status()
88+
score = ScoreResponse.model_validate(score_response.json())
89+
90+
assert score.id is not None
91+
assert score.data is not None
92+
assert len(score.data) == 1
93+
assert score.data[0].score[0] >= 0.9

0 commit comments

Comments
 (0)