
Commit 94fc5a2

Implemented the possibility to load predictions from details files and continue evaluating from there (#488)

* Implemented the possibility to load predictions from details files and continue evaluating from there.
* Run model as fallback when no details can be loaded.
* Improved loading speed and added more useful error messages.
* Fixed typo.
* Fixed gnarly bug with details loading to prevent loading too many examples.
* Unpacking predictions to fix issue with weirdly saved predictions.
* Made bulk loading easier by also allowing first timestamp more generally.
* Made loading details more robust against tensors being saved in the details files.

1 parent 48d0c28 · commit 94fc5a2

File tree: 5 files changed, +238 −7 lines changed
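
The feature in a nutshell: a previous run saves per-sample details as parquet files under details/<model_name>/<timestamp>/, and this commit lets a later run replay those responses instead of querying the model again. A rough usage sketch, assuming an already-configured tracker and model config (the task spec, paths, and the model_config variable below are placeholders, not taken from the commit):

from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters

evaluation_tracker = EvaluationTracker(output_dir="./results")  # same output_dir as the earlier run

pipeline_params = PipelineParameters(
    launcher_type=ParallelismManager.ACCELERATE,
    # Replay responses from an earlier run's details files instead of running the model.
    # "first"/"last" resolve to the earliest/latest timestamp folder; an explicit
    # timestamp folder name from a previous run also works.
    load_responses_from_details_date_id="last",
)
pipeline = Pipeline(
    tasks="leaderboard|truthfulqa:mc|0|0",  # placeholder task spec
    pipeline_parameters=pipeline_params,
    evaluation_tracker=evaluation_tracker,
    model_config=model_config,  # assumed to be defined elsewhere
)
pipeline.evaluate()

If nothing can be loaded for the requested date_id, evaluate() logs a warning and falls back to running the model (see the pipeline.py diff below).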

src/lighteval/logging/evaluation_tracker.py

+38 −2
@@ -235,9 +235,45 @@ def save_results(self, date_id: str, results_dict: dict):
         with self.fs.open(output_results_file, "w") as f:
             f.write(json.dumps(results_dict, cls=EnhancedJSONEncoder, indent=2, ensure_ascii=False))

-    def save_details(self, date_id: str, details_datasets: dict[str, Dataset]):
+    def _get_details_sub_folder(self, date_id: str):
         output_dir_details = Path(self.output_dir) / "details" / self.general_config_logger.model_name
-        output_dir_details_sub_folder = output_dir_details / date_id
+        if date_id in ["first", "last"]:
+            # Get all folders in output_dir_details
+            if not self.fs.exists(output_dir_details):
+                raise FileNotFoundError(f"Details directory {output_dir_details} does not exist")
+
+            # List all folders and filter out files
+            folders = [f["name"] for f in self.fs.listdir(output_dir_details) if f["type"] == "directory"]
+
+            if not folders:
+                raise FileNotFoundError(f"No timestamp folders found in {output_dir_details}")
+
+            # Parse timestamps and get first or last
+            date_id = max(folders) if date_id == "last" else min(folders)
+        return output_dir_details / date_id
+
+    def load_details_datasets(self, date_id: str, task_names: list[str]) -> dict[str, Dataset]:
+        output_dir_details_sub_folder = self._get_details_sub_folder(date_id)
+        logger.info(f"Loading details from {output_dir_details_sub_folder}")
+        date_id = output_dir_details_sub_folder.name  # Overwrite date_id in case of latest
+        details_datasets = {}
+        for file in self.fs.glob(str(output_dir_details_sub_folder / f"details_*_{date_id}.parquet")):
+            task_name = Path(file).stem.replace("details_", "").replace(f"_{date_id}", "")
+            if "|".join(task_name.split("|")[:-1]) not in task_names:
+                logger.info(f"Skipping {task_name} because it is not in the task_names list")
+                continue
+            dataset = load_dataset("parquet", data_files=file, split="train")
+            details_datasets[task_name] = dataset
+
+        for task_name in task_names:
+            if not any(task_name.startswith(task_name) for task_name in details_datasets.keys()):
+                raise ValueError(
+                    f"Task {task_name} not found in details datasets. Check the tasks to be evaluated or the date_id used to load the details ({date_id})."
+                )
+        return details_datasets
+
+    def save_details(self, date_id: str, details_datasets: dict[str, Dataset]):
+        output_dir_details_sub_folder = self._get_details_sub_folder(date_id)
         self.fs.mkdirs(output_dir_details_sub_folder, exist_ok=True)
         logger.info(f"Saving details to {output_dir_details_sub_folder}")
         for task_name, dataset in details_datasets.items():
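
The tracker-side loading can also be exercised on its own. A minimal sketch, assuming the tracker's output_dir and logged model name match an earlier run that saved details (the path and task name below are placeholders):

from lighteval.logging.evaluation_tracker import EvaluationTracker

tracker = EvaluationTracker(output_dir="./results")  # placeholder path
# Note: the details path also contains the model name, which is normally set on the
# tracker during a run; it has to match the run whose details are being loaded.

# "last" is resolved to the newest timestamp folder by _get_details_sub_folder, then each
# details_*_<date_id>.parquet file whose task prefix appears in task_names is loaded via
# load_dataset("parquet", ...).
details = tracker.load_details_datasets(
    date_id="last",
    task_names=["leaderboard|truthfulqa:mc|0"],  # placeholder task name
)
for task_name, dataset in details.items():
    print(task_name, dataset.num_rows)

A FileNotFoundError is raised when the details directory or its timestamp folders are missing; that is the error the pipeline's fallback path catches.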

src/lighteval/main_accelerate.py

+4
@@ -67,6 +67,9 @@ def accelerate( # noqa C901
     num_fewshot_seeds: Annotated[
         int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANEL_NAME_1)
     ] = 1,
+    load_responses_from_details_date_id: Annotated[
+        Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1)
+    ] = None,
     # === saving ===
     output_dir: Annotated[
         str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2)
@@ -137,6 +140,7 @@ def accelerate( # noqa C901
         max_samples=max_samples,
         use_chat_template=use_chat_template,
         system_prompt=system_prompt,
+        load_responses_from_details_date_id=load_responses_from_details_date_id,
     )

     # TODO (nathan): better handling of model_args
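
The same optional flag is added to every entry point in this commit (accelerate, inference_endpoint, tgi, litellm, vllm) and simply forwarded to PipelineParameters. For readers unfamiliar with the Annotated/Option pattern used here, a standalone Typer sketch of an equivalent option (illustrative only, not the project's code):

from typing import Annotated, Optional

import typer

app = typer.Typer()


@app.command()
def evaluate(
    load_responses_from_details_date_id: Annotated[
        Optional[str],
        typer.Option(help="Load responses from details directory ('first', 'last', or a timestamp)."),
    ] = None,
):
    # None means no replay was requested, so the model would be run as usual.
    print(f"load_responses_from_details_date_id={load_responses_from_details_date_id}")


if __name__ == "__main__":
    app()

With Typer's defaults such a parameter surfaces on the command line as --load-responses-from-details-date-id.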

src/lighteval/main_endpoint.py

+12
@@ -179,6 +179,9 @@ def inference_endpoint(
     num_fewshot_seeds: Annotated[
         int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANEL_NAME_1)
     ] = 1,
+    load_responses_from_details_date_id: Annotated[
+        Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1)
+    ] = None,
     # === saving ===
     output_dir: Annotated[
         str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2)
@@ -247,6 +250,7 @@ def inference_endpoint(
         max_samples=max_samples,
         use_chat_template=use_chat_template,
         system_prompt=system_prompt,
+        load_responses_from_details_date_id=load_responses_from_details_date_id,
     )
     pipeline = Pipeline(
         tasks=tasks,
@@ -292,6 +296,9 @@ def tgi(
     num_fewshot_seeds: Annotated[
         int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANEL_NAME_1)
     ] = 1,
+    load_responses_from_details_date_id: Annotated[
+        Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1)
+    ] = None,
     # === saving ===
     output_dir: Annotated[
         str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2)
@@ -355,6 +362,7 @@ def tgi(
         max_samples=max_samples,
         use_chat_template=use_chat_template,
         system_prompt=system_prompt,
+        load_responses_from_details_date_id=load_responses_from_details_date_id,
     )
     pipeline = Pipeline(
         tasks=tasks,
@@ -400,6 +408,9 @@ def litellm(
     num_fewshot_seeds: Annotated[
         int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANEL_NAME_1)
     ] = 1,
+    load_responses_from_details_date_id: Annotated[
+        Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1)
+    ] = None,
     # === saving ===
     output_dir: Annotated[
         str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2)
@@ -464,6 +475,7 @@ def litellm(
         max_samples=max_samples,
         use_chat_template=use_chat_template,
         system_prompt=system_prompt,
+        load_responses_from_details_date_id=load_responses_from_details_date_id,
     )
     pipeline = Pipeline(
         tasks=tasks,

src/lighteval/main_vllm.py

+4
@@ -63,6 +63,9 @@ def vllm(
     num_fewshot_seeds: Annotated[
         int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANEL_NAME_1)
     ] = 1,
+    load_responses_from_details_date_id: Annotated[
+        Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1)
+    ] = None,
     # === saving ===
     output_dir: Annotated[
         str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2)
@@ -124,6 +127,7 @@ def vllm(
         max_samples=max_samples,
         use_chat_template=use_chat_template,
         system_prompt=system_prompt,
+        load_responses_from_details_date_id=load_responses_from_details_date_id,
     )

     if model_args.endswith(".yaml"):

src/lighteval/pipeline.py

+180 −5
@@ -20,24 +20,33 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.

+import ast
 import collections
 import os
 import random
+import re
 import shutil
 from contextlib import nullcontext
 from dataclasses import dataclass, field
 from datetime import timedelta
 from enum import Enum, auto

 import numpy as np
+from tqdm import tqdm

 from lighteval.logging.evaluation_tracker import EvaluationTracker
 from lighteval.metrics.utils.metric_utils import MetricCategory
 from lighteval.models.model_loader import TransformersModel, load_model
-from lighteval.models.model_output import ModelResponse
+from lighteval.models.model_output import (
+    GenerativeMultiturnResponse,
+    GenerativeResponse,
+    LoglikelihoodResponse,
+    LoglikelihoodSingleTokenResponse,
+    ModelResponse,
+)
 from lighteval.tasks.lighteval_task import LightevalTask, create_requests_from_tasks
 from lighteval.tasks.registry import Registry, taskinfo_selector
-from lighteval.tasks.requests import SampleUid
+from lighteval.tasks.requests import RequestType, SampleUid
 from lighteval.utils.imports import (
     NO_ACCELERATE_ERROR_MSG,
     NO_NANOTRON_ERROR_MSG,
@@ -95,6 +104,7 @@ class PipelineParameters:
     max_samples: int | None = None
     use_chat_template: bool = False
     system_prompt: str | None = None
+    load_responses_from_details_date_id: str | None = None

     def __post_init__(self):  # noqa C901
         if self.launcher_type == ParallelismManager.ACCELERATE:
@@ -245,7 +255,17 @@ def evaluate(self):
             config=self.model_config,
         )

-        sample_id_to_responses = self._run_model()
+        if self.pipeline_parameters.load_responses_from_details_date_id:
+            try:
+                sample_id_to_responses = self._load_responses_from_details()
+            except FileNotFoundError as e:
+                logger.warning(
+                    f"No responses found for {self.pipeline_parameters.load_responses_from_details_date_id} in details directory: {e}. Running model instead."
+                )
+                sample_id_to_responses = self._run_model()
+        else:
+            sample_id_to_responses = self._run_model()
+
         self._compute_metrics(sample_id_to_responses)

         if self.is_main_process():
@@ -261,6 +281,158 @@ def evaluate(self):
             except OSError:
                 pass

+    def _unpack(self, x):
+        if isinstance(x, str):
+            return x
+        elif isinstance(x, (list, tuple)):
+            return self._unpack(x[0])
+        else:
+            raise ValueError(f"Unknown type {type(x)} of prediction {x}")
+
+    def _parse_tensor_string(self, tensor_string):
+        """
+        Convert a string containing PyTorch-like `tensor([...], device='cuda:0', ...)`
+        into a Python list (or nested lists) of numbers.
+
+        Example:
+            "[tensor([1, 2, 3], device='cuda:0'), tensor([[4,5],[6,7]], dtype=torch.int64)]"
+            -> [[1, 2, 3], [[4, 5], [6, 7]]]
+        """
+
+        # Regex explanation:
+        #   - tensor\(\s*: Matches "tensor(" (possibly with spaces after), literally.
+        #   - (.*?): Captures everything lazily into group(1), until the first subsequent part matches.
+        #     We rely on the next pattern to anchor the end of this capture.
+        #   - \): The literal closing parenthesis, but we anchor the match by ignoring
+        #     further arguments (device=..., dtype=..., etc.) inside.
+        #
+        # The tricky part: a tensor might look like
+        #   tensor([ ... ], device='cuda:0', dtype=torch.int64)
+        # so the bracket portion is `[ ... ]`, but it can have newlines, etc.
+        #
+        # We'll handle that by first capturing the entire content up to the final parenthesis,
+        # then parse out the bracket portion. This can be done in a function-based re.sub.
+
+        pattern = re.compile(
+            r"tensor\s*\(\s*(.*?)\s*\)",  # capture everything inside tensor(...)
+            flags=re.DOTALL,
+        )
+
+        def tensor_replacer(match):
+            inside = match.group(1).strip()
+            # `inside` might look like: [1, 2, 3], device='cuda:0'
+            # or:
+            #   [
+            #     1, 2, 3,
+            #     4, 5, ...
+            #   ], device='cuda:0', dtype=torch.int64
+            #
+            # 1) Extract the bracketed array portion: the first [ ... ] block
+            #    which might be multi-line. We'll use another regex for that.

+            # We look for the bracketed portion from the first '[' to its matching ']'.
+            # Because the inside can be multi-line, we use DOTALL. But we still need
+            # to ensure we don't accidentally go beyond the matching bracket.
+            #
+            # A robust approach to properly match brackets can be done with a small parser,
+            # but for typical well-formed strings, a lazy match of the form
+            # r"\[.*?\]" DOTALL often suffices, assuming no nested brackets inside.
+
+            bracket_pattern = re.compile(r"\[.*?\]", re.DOTALL)
+            bracket_match = bracket_pattern.search(inside)
+            if not bracket_match:
+                # If we fail to find a bracket, just return something safe.
+                # This means the string didn't match the expected format.
+                return "[]"
+
+            # The bracketed portion (e.g. "[1, 2, 3\n, 4]").
+            bracketed_content = bracket_match.group(0)
+
+            # Return just the bracketed content,
+            # effectively replacing "tensor(...)" with "[...]".
+            return bracketed_content
+
+        # Step 1: Replace every `tensor(...)` occurrence with just the bracketed list.
+        processed = pattern.sub(tensor_replacer, tensor_string)
+
+        # Step 2: Now we can safely parse the result with literal_eval.
+        #         If there's still something weird, it may throw ValueError.
+        try:
+            return ast.literal_eval(processed)
+        except Exception as e:
+            raise ValueError(f"Failed to parse after preprocessing. " f"Processed string:\n{processed}\n\nError: {e}")
+
+    def _load_responses_from_details(self):
+        logger.info("--- LOADING RESPONSES FROM DETAILS ---")
+        sample_id_to_responses: dict[(SampleUid, MetricCategory), list[ModelResponse]] = collections.defaultdict(list)
+
+        request_types = list(self.requests.keys())
+        if len(request_types) > 1:
+            raise ValueError(
+                "Loading responses from details when there are multiple request types is currently not supported"
+            )
+        model_response_type = self._get_model_response_type(request_types[0])
+
+        details_datasets = self.evaluation_tracker.load_details_datasets(
+            self.pipeline_parameters.load_responses_from_details_date_id, self.task_names_list
+        )
+
+        for task_name, dataset in tqdm(details_datasets.items(), desc="Loading responses from details for tasks"):
+            task: LightevalTask = self._get_task(task_name)
+            num_samples = len(set(dataset["specifics"]))
+            max_samples = self.pipeline_parameters.max_samples if self.pipeline_parameters.max_samples else num_samples
+            if num_samples > max_samples:
+                logger.warning(
+                    f"Skipping {num_samples - max_samples} samples for {task_name} when loading responses from details because max_samples is set to {max_samples}"
+                )
+                num_samples = self.pipeline_parameters.max_samples
+
+            predictions = [self._unpack(ast.literal_eval(p)) for p in dataset["predictions"][:num_samples]]
+            input_tokens = [self._parse_tensor_string(t) for t in dataset["input_tokens"][:num_samples]]
+            cont_tokens = [self._parse_tensor_string(t) for t in dataset["cont_tokens"][:num_samples]]
+            truncated = [ast.literal_eval(t)[0] for t in dataset["truncated"][:num_samples]]
+            padded = [ast.literal_eval(p)[0] for p in dataset["padded"][:num_samples]]
+
+            if model_response_type == GenerativeResponse:
+                logits = [ast.literal_eval(p) for p in dataset["pred_logits"][:num_samples]]
+
+            for metric_category, has_metric_category in task.has_metric_category.items():
+                if not has_metric_category:
+                    continue
+
+                for idx in range(num_samples):
+                    kwargs = {
+                        "result": predictions[idx],
+                        "input_tokens": input_tokens[idx],
+                        "generated_tokens": cont_tokens[idx],
+                        "truncated_tokens_count": truncated[idx],
+                        "padded_tokens_count": padded[idx],
+                    }
+                    if model_response_type == GenerativeResponse:
+                        kwargs["logits"] = logits[idx]
+
+                    response = model_response_type(**kwargs)
+                    sample_id_to_responses[(SampleUid(task_name, f"{idx}_{0}"), metric_category)] = [response]
+        return sample_id_to_responses
+
+    def _get_model_response_type(self, request_type):
+        if request_type == RequestType.LOGLIKELIHOOD:
+            model_response_type = LoglikelihoodResponse
+        elif request_type == RequestType.LOGLIKELIHOOD_SINGLE_TOKEN:
+            model_response_type = LoglikelihoodSingleTokenResponse
+        elif request_type == RequestType.LOGLIKELIHOOD_ROLLING:
+            model_response_type = LoglikelihoodResponse
+        elif request_type == RequestType.GREEDY_UNTIL_MULTI_TURN:
+            model_response_type = GenerativeMultiturnResponse
+        elif request_type == RequestType.GREEDY_UNTIL:
+            model_response_type = GenerativeResponse
+        else:
+            raise ValueError(
+                f"Loading responses from details for request type {request_type} is currently not supported"
+            )
+
+        return model_response_type
+
     def _run_model(self):
         # Running all requests depending on the model call type (log likelihood, generative, ...)
         # to be able to batch them
@@ -283,6 +455,10 @@ def _run_model(self):

         return sample_id_to_responses

+    def _get_task(self, task_name: str):
+        short_task_name = task_name.rsplit("|", 1)[0]
+        return self.task_dict[short_task_name]
+
     def _compute_metrics(self, sample_id_to_responses):
         # To compute the metrics we first group the samples and task and then by metrics.
         # This way we can batch the metrics computation for each task and metric category
@@ -307,8 +483,7 @@ def _compute_metrics(self, sample_id_to_responses):
             task_metric_category_groups[sample_id.task_name][metric_category]["docs"].append(self.docs[sample_id])

         for task_name, samples_per_metric in task_metric_category_groups.items():
-            short_task_name = task_name.rsplit("|", 1)[0]
-            task: LightevalTask = self.task_dict[short_task_name]
+            task: LightevalTask = self._get_task(task_name)

             for metric_category, samples in samples_per_metric.items():
                 sample_ids = samples["ids"]
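
Because the details parquet files store token ids as stringified PyTorch tensors, the new _parse_tensor_string helper first rewrites every tensor(...) wrapper into its bare bracketed list and only then calls ast.literal_eval. A standalone sketch of the same two-step idea, simplified from the method above (the free-function name is mine):

import ast
import re


def parse_tensor_string(tensor_string: str):
    # Step 1: replace each "tensor(<payload>)" with just the first bracketed list in <payload>,
    # dropping trailing arguments such as device=... or dtype=...
    pattern = re.compile(r"tensor\s*\(\s*(.*?)\s*\)", flags=re.DOTALL)

    def tensor_replacer(match):
        inside = match.group(1).strip()
        bracket_match = re.search(r"\[.*?\]", inside, re.DOTALL)
        return bracket_match.group(0) if bracket_match else "[]"

    processed = pattern.sub(tensor_replacer, tensor_string)
    # Step 2: the remaining string contains only plain Python literals, so literal_eval can parse it.
    return ast.literal_eval(processed)


raw = "[tensor([1, 2, 3], device='cuda:0'), tensor([4, 5], dtype=torch.int64)]"
print(parse_tensor_string(raw))  # -> [[1, 2, 3], [4, 5]]

As the commit's own comments point out, the lazy \[.*?\] match assumes there are no nested brackets inside a single tensor's payload; multi-dimensional tensors would need a small bracket-matching parser instead.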
