 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.

+import ast
 import collections
 import os
 import random
+import re
 import shutil
 from contextlib import nullcontext
 from dataclasses import dataclass, field
 from datetime import timedelta
 from enum import Enum, auto

 import numpy as np
+from tqdm import tqdm

 from lighteval.logging.evaluation_tracker import EvaluationTracker
 from lighteval.metrics.utils.metric_utils import MetricCategory
 from lighteval.models.model_loader import TransformersModel, load_model
-from lighteval.models.model_output import ModelResponse
+from lighteval.models.model_output import (
+    GenerativeMultiturnResponse,
+    GenerativeResponse,
+    LoglikelihoodResponse,
+    LoglikelihoodSingleTokenResponse,
+    ModelResponse,
+)
 from lighteval.tasks.lighteval_task import LightevalTask, create_requests_from_tasks
 from lighteval.tasks.registry import Registry, taskinfo_selector
-from lighteval.tasks.requests import SampleUid
+from lighteval.tasks.requests import RequestType, SampleUid
 from lighteval.utils.imports import (
     NO_ACCELERATE_ERROR_MSG,
     NO_NANOTRON_ERROR_MSG,
@@ -95,6 +104,7 @@ class PipelineParameters:
     max_samples: int | None = None
     use_chat_template: bool = False
     system_prompt: str | None = None
+    load_responses_from_details_date_id: str | None = None

     def __post_init__(self):  # noqa C901
         if self.launcher_type == ParallelismManager.ACCELERATE:
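For orientation, here is a minimal sketch of how the new parameter might be used; the output directory and date id are hypothetical placeholders, and the tracker setup is reduced to the bare minimum:

```python
# Hypothetical usage sketch -- the output_dir and date id below are placeholders.
from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.pipeline import ParallelismManager, PipelineParameters

# Tracker pointing at the directory where a previous run saved its details files.
evaluation_tracker = EvaluationTracker(output_dir="./results")

pipeline_params = PipelineParameters(
    launcher_type=ParallelismManager.ACCELERATE,
    max_samples=10,
    # Reuse the predictions saved under this details timestamp instead of re-running the model.
    load_responses_from_details_date_id="2024-01-01T00-00-00.000000",
)
```

When `evaluate()` is called on a `Pipeline` built with these parameters, `_load_responses_from_details()` is tried first and `_run_model()` is used as a fallback if no details file exists for that date id (see the next hunk).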
@@ -245,7 +255,17 @@ def evaluate(self):
             config=self.model_config,
         )

-        sample_id_to_responses = self._run_model()
+        if self.pipeline_parameters.load_responses_from_details_date_id:
+            try:
+                sample_id_to_responses = self._load_responses_from_details()
+            except FileNotFoundError as e:
+                logger.warning(
+                    f"No responses found for {self.pipeline_parameters.load_responses_from_details_date_id} in details directory: {e}. Running model instead."
+                )
+                sample_id_to_responses = self._run_model()
+        else:
+            sample_id_to_responses = self._run_model()
+
         self._compute_metrics(sample_id_to_responses)

         if self.is_main_process():
@@ -261,6 +281,158 @@ def evaluate(self):
                 except OSError:
                     pass

+    def _unpack(self, x):
+        if isinstance(x, str):
+            return x
+        elif isinstance(x, (list, tuple)):
+            return self._unpack(x[0])
+        else:
+            raise ValueError(f"Unknown type {type(x)} of prediction {x}")
+
+    def _parse_tensor_string(self, tensor_string):
+        """
+        Convert a string containing PyTorch-like `tensor([...], device='cuda:0', ...)`
+        into a Python list (or nested lists) of numbers.
+
+        Example:
+            "[tensor([1, 2, 3], device='cuda:0'), tensor([[4,5],[6,7]], dtype=torch.int64)]"
+            -> [[1, 2, 3], [[4, 5], [6, 7]]]
+        """
+
+        # Regex explanation:
+        #  - tensor\(\s*: Matches "tensor(" (possibly with spaces after), literally.
+        #  - (.*?): Captures everything lazily into group(1), until the first subsequent part matches.
+        #    We rely on the next pattern to anchor the end of this capture.
+        #  - \): The literal closing parenthesis, but we anchor the match by ignoring
+        #    further arguments (device=..., dtype=..., etc.) inside.
+        #
+        # The tricky part: a tensor might look like
+        #   tensor([ ... ], device='cuda:0', dtype=torch.int64)
+        # so the bracket portion is `[ ... ]`, but it can have newlines, etc.
+        #
+        # We'll handle that by first capturing the entire content up to the final parenthesis,
+        # then parse out the bracket portion. This can be done in a function-based re.sub.
+
+        pattern = re.compile(
+            r"tensor\s*\(\s*(.*?)\s*\)",  # capture everything inside tensor(...)
+            flags=re.DOTALL,
+        )
+
+        def tensor_replacer(match):
+            inside = match.group(1).strip()
+            # `inside` might look like: [1, 2, 3], device='cuda:0'
+            # or:
+            #   [
+            #     1, 2, 3,
+            #     4, 5, ...
+            #   ], device='cuda:0', dtype=torch.int64
+            #
+            # 1) Extract the bracketed array portion: the first [ ... ] block
+            #    which might be multi-line. We'll use another regex for that.
+
+            # We look for the bracketed portion from the first '[' to its matching ']'.
+            # Because the inside can be multi-line, we use DOTALL. But we still need
+            # to ensure we don't accidentally go beyond the matching bracket.
+            #
+            # A robust approach to properly match brackets can be done with a small parser,
+            # but for typical well-formed strings, a lazy match of the form
+            # r"\[.*?\]" DOTALL often suffices, assuming no nested brackets inside.
+
+            bracket_pattern = re.compile(r"\[.*?\]", re.DOTALL)
+            bracket_match = bracket_pattern.search(inside)
+            if not bracket_match:
+                # If we fail to find a bracket, just return something safe.
+                # This means the string didn't match the expected format.
+                return "[]"
+
+            # The bracketed portion (e.g. "[1, 2, 3\n, 4]").
+            bracketed_content = bracket_match.group(0)
+
+            # Return just the bracketed content,
+            # effectively replacing "tensor(...)" with "[...]".
+            return bracketed_content
+
+        # Step 1: Replace every `tensor(...)` occurrence with just the bracketed list.
+        processed = pattern.sub(tensor_replacer, tensor_string)
+
+        # Step 2: Now we can safely parse the result with literal_eval.
+        # If there's still something weird, it may throw ValueError.
+        try:
+            return ast.literal_eval(processed)
+        except Exception as e:
+            raise ValueError(f"Failed to parse after preprocessing. " f"Processed string:\n{processed}\n\nError: {e}")
+
+    def _load_responses_from_details(self):
+        logger.info("--- LOADING RESPONSES FROM DETAILS ---")
+        sample_id_to_responses: dict[(SampleUid, MetricCategory), list[ModelResponse]] = collections.defaultdict(list)
+
+        request_types = list(self.requests.keys())
+        if len(request_types) > 1:
+            raise ValueError(
+                "Loading responses from details when there are multiple request types is currently not supported"
+            )
+        model_response_type = self._get_model_response_type(request_types[0])
+
+        details_datasets = self.evaluation_tracker.load_details_datasets(
+            self.pipeline_parameters.load_responses_from_details_date_id, self.task_names_list
+        )
+
+        for task_name, dataset in tqdm(details_datasets.items(), desc="Loading responses from details for tasks"):
+            task: LightevalTask = self._get_task(task_name)
+            num_samples = len(set(dataset["specifics"]))
+            max_samples = self.pipeline_parameters.max_samples if self.pipeline_parameters.max_samples else num_samples
+            if num_samples > max_samples:
+                logger.warning(
+                    f"Skipping {num_samples - max_samples} samples for {task_name} when loading responses from details because max_samples is set to {max_samples}"
+                )
+                num_samples = self.pipeline_parameters.max_samples
+
+            predictions = [self._unpack(ast.literal_eval(p)) for p in dataset["predictions"][:num_samples]]
+            input_tokens = [self._parse_tensor_string(t) for t in dataset["input_tokens"][:num_samples]]
+            cont_tokens = [self._parse_tensor_string(t) for t in dataset["cont_tokens"][:num_samples]]
+            truncated = [ast.literal_eval(t)[0] for t in dataset["truncated"][:num_samples]]
+            padded = [ast.literal_eval(p)[0] for p in dataset["padded"][:num_samples]]
+
+            if model_response_type == GenerativeResponse:
+                logits = [ast.literal_eval(p) for p in dataset["pred_logits"][:num_samples]]
+
+            for metric_category, has_metric_category in task.has_metric_category.items():
+                if not has_metric_category:
+                    continue
+
+                for idx in range(num_samples):
+                    kwargs = {
+                        "result": predictions[idx],
+                        "input_tokens": input_tokens[idx],
+                        "generated_tokens": cont_tokens[idx],
+                        "truncated_tokens_count": truncated[idx],
+                        "padded_tokens_count": padded[idx],
+                    }
+                    if model_response_type == GenerativeResponse:
+                        kwargs["logits"] = logits[idx]
+
+                    response = model_response_type(**kwargs)
+                    sample_id_to_responses[(SampleUid(task_name, f"{idx}_{0}"), metric_category)] = [response]
+        return sample_id_to_responses
+
+    def _get_model_response_type(self, request_type):
+        if request_type == RequestType.LOGLIKELIHOOD:
+            model_response_type = LoglikelihoodResponse
+        elif request_type == RequestType.LOGLIKELIHOOD_SINGLE_TOKEN:
+            model_response_type = LoglikelihoodSingleTokenResponse
+        elif request_type == RequestType.LOGLIKELIHOOD_ROLLING:
+            model_response_type = LoglikelihoodResponse
+        elif request_type == RequestType.GREEDY_UNTIL_MULTI_TURN:
+            model_response_type = GenerativeMultiturnResponse
+        elif request_type == RequestType.GREEDY_UNTIL:
+            model_response_type = GenerativeResponse
+        else:
+            raise ValueError(
+                f"Loading responses from details for request type {request_type} is currently not supported"
+            )
+
+        return model_response_type
+
     def _run_model(self):
         # Running all requests depending on the model call type (log likelihood, generative, ...)
         # to be able to batch them
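To make the parsing step concrete, here is a small standalone illustration (not part of the diff) of why the `_parse_tensor_string` helper added above strips the `tensor(...)` wrappers before calling `ast.literal_eval`: details files store token ids as stringified PyTorch tensors, which `literal_eval` cannot read directly. The sample string is made up, and the one-liner mirrors the lazy bracket match used in the helper, so it shares the same stated assumption of no nested brackets:

```python
import ast
import re

# A made-up details cell, stored as the string form of a list of tensors.
raw = "[tensor([1, 2, 3], device='cuda:0'), tensor([4, 5], dtype=torch.int64)]"

# ast.literal_eval(raw) would raise ValueError: tensor(...) is a call, not a literal.

# Replace each tensor(...) wrapper with just its bracketed payload, then parse.
pattern = re.compile(r"tensor\s*\(\s*(.*?)\s*\)", flags=re.DOTALL)
processed = pattern.sub(lambda m: re.search(r"\[.*?\]", m.group(1), re.DOTALL).group(0), raw)

print(processed)                    # [[1, 2, 3], [4, 5]]
print(ast.literal_eval(processed))  # [[1, 2, 3], [4, 5]]
```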
@@ -283,6 +455,10 @@ def _run_model(self):

         return sample_id_to_responses

+    def _get_task(self, task_name: str):
+        short_task_name = task_name.rsplit("|", 1)[0]
+        return self.task_dict[short_task_name]
+
     def _compute_metrics(self, sample_id_to_responses):
         # To compute the metrics we first group the samples and task and then by metrics.
         # This way we can batch the metrics computation for each task and metric category
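As a quick illustration of the `_get_task` helper above (the task name here is hypothetical, assuming a "suite|task|num_fewshot" key format), it drops the trailing few-shot component before looking the task up in `task_dict`:

```python
task_name = "leaderboard|arc:challenge|0"      # assumed "suite|task|num_fewshot" key
short_task_name = task_name.rsplit("|", 1)[0]  # -> "leaderboard|arc:challenge"
```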
@@ -307,8 +483,7 @@ def _compute_metrics(self, sample_id_to_responses):
             task_metric_category_groups[sample_id.task_name][metric_category]["docs"].append(self.docs[sample_id])

         for task_name, samples_per_metric in task_metric_category_groups.items():
-            short_task_name = task_name.rsplit("|", 1)[0]
-            task: LightevalTask = self.task_dict[short_task_name]
+            task: LightevalTask = self._get_task(task_name)

             for metric_category, samples in samples_per_metric.items():
                 sample_ids = samples["ids"]