Commit fdb12f4

Made litellm judge backend more robust. (#485)
* Made litellm judge backend more robust.
* Added failed flag to ModelResponse.

Co-authored-by: Clémentine Fourrier <22726840+clefourrier@users.noreply.github.com>
1 parent 2073a29 commit fdb12f4

2 files changed: +21 −9 lines changed


src/lighteval/metrics/llm_as_judge.py (+20 −9)
@@ -28,6 +28,7 @@
 
 from tqdm import tqdm
 
+from lighteval.models.model_output import ModelResponse
 from lighteval.utils.imports import is_litellm_available, is_openai_available, is_vllm_available
 
 
@@ -195,20 +196,30 @@ def __call_litellm(self, prompts):
         def __call_api(prompt):
             for _ in range(self.API_MAX_RETRY):
                 try:
-                    response = litellm.completion(
-                        model=self.model,
-                        messages=prompt,
-                        response_format={"type": "text"},
-                        max_tokens=512,
-                        n=1,
-                        caching=True,
-                    )
+                    kwargs = {
+                        "model": self.model,
+                        "messages": prompt,
+                        "response_format": {"type": "text"},
+                        "max_tokens": 512,
+                        "n": 1,
+                        "caching": True,
+                    }
+                    response = litellm.completion(**kwargs)
                     text = response.choices[0].message.content
+                    if not text or response.failed:
+                        kwargs["caching"] = False
+                        response = litellm.completion(**kwargs)
+                        text = response.choices[0].message.content
+                        if not text or response.failed:
+                            # Just return an error response if the second attempt fails too
+                            return ModelResponse(
+                                text="Failed to get response from the API.", model=self.model, failed=True
+                            )
                     return text
                 except Exception as e:
                     logger.warning(f"{type(e), e}")
                     time.sleep(self.API_RETRY_SLEEP)
-            raise Exception("Failed to get response from the API")
+            return ModelResponse(text="Failed to get response from the API.", model=self.model, failed=True)
 
         results = []
         with ThreadPoolExecutor(100) as executor:
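
Outside the diff, the new behavior is easier to see as a standalone function. The sketch below is a minimal, hypothetical reconstruction, not lighteval's actual code: call_judge, MAX_RETRY, and RETRY_SLEEP are stand-in names for the class attributes used above, and it signals failure with None rather than a ModelResponse (and omits the response.failed check) to stay self-contained.

# Minimal sketch of the retry-with-cache-bypass pattern from the hunk above.
# Hypothetical names: MAX_RETRY / RETRY_SLEEP stand in for API_MAX_RETRY /
# API_RETRY_SLEEP; assumes litellm is installed and an API key is configured.
import logging
import time

import litellm

logger = logging.getLogger(__name__)

MAX_RETRY = 3
RETRY_SLEEP = 10


def call_judge(model: str, messages: list[dict]) -> str | None:
    kwargs = {
        "model": model,
        "messages": messages,
        "response_format": {"type": "text"},
        "max_tokens": 512,
        "n": 1,
        "caching": True,  # first attempt may be served from litellm's cache
    }
    for _ in range(MAX_RETRY):
        try:
            response = litellm.completion(**kwargs)
            text = response.choices[0].message.content
            if not text:
                # A cached entry can come back empty; retry once, bypassing the cache.
                kwargs["caching"] = False
                response = litellm.completion(**kwargs)
                text = response.choices[0].message.content
                if not text:
                    return None  # give up on this prompt instead of raising
            return text
        except Exception as e:
            logger.warning(f"{type(e), e}")
            time.sleep(RETRY_SLEEP)
    return None

The key change is that exhausting the retries no longer raises, so one bad prompt no longer aborts the whole ThreadPoolExecutor batch.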

src/lighteval/models/model_output.py (+1 −0)
@@ -33,6 +33,7 @@ class ModelResponse:
     generated_tokens: list[int] = field(default_factory=list)  # model generations
     truncated_tokens_count: Optional[int] = 0  # How many tokens truncated
     padded_tokens_count: Optional[int] = 0  # How many tokens of padding
+    failed: bool = False
 
     def get_result_for_eval(self):
         raise NotImplementedError()
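
The flag matters because __call_api now returns either a plain string (success) or a ModelResponse with failed=True (both attempts exhausted). A downstream consumer might branch on it as sketched below; score_or_skip and its scoring rule are purely illustrative and not part of this commit.

# Hypothetical consumer of the new flag (not lighteval code).
from lighteval.models.model_output import ModelResponse


def score_or_skip(judge_response) -> float | None:
    # __call_api returns a str on success or a ModelResponse with
    # failed=True once retries are exhausted, so callers can degrade
    # gracefully instead of crashing the whole evaluation run.
    if isinstance(judge_response, ModelResponse) and judge_response.failed:
        return None  # skip this sample
    return 1.0 if "yes" in judge_response.lower() else 0.0  # placeholder rule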
