diff --git a/eval/run_eval.py b/eval/run_eval.py
index 2062899..5b23fe6 100644
--- a/eval/run_eval.py
+++ b/eval/run_eval.py
@@ -397,6 +397,10 @@ def run_eval(config: DictConfig, model_provider: ModelProvider):
     else:
         logger.info("No benchmarks enabled.")
 
+    # Make a random seed that will be used for both pre- and post-finetune eval
+    # This is required to keep the FIM examples the same for both runs
+    eval_seed = random.randint(0, 2**32 - 1)
+
     # Load name of all eval tasks
     task_metrics = []
     for task_name, task_info in load_task_info():
@@ -409,6 +413,7 @@ def run_eval(config: DictConfig, model_provider: ModelProvider):
 
         ### Run eval ###
         logger.info(f"Running pre-finetune eval...")
+        set_seed(eval_seed)
         eval_metrics = run_task_eval(config, task_info, model, tokenizer)
         eval_metrics['task_name'] = task_name
         eval_metrics['finetuned'] = False
@@ -432,6 +437,7 @@ def run_eval(config: DictConfig, model_provider: ModelProvider):
 
         ### Run eval again ###
         logger.info(f"Running post-finetune eval...")
+        set_seed(eval_seed)
         eval_metrics = run_task_eval(config, task_info, model, tokenizer)
         eval_metrics['task_name'] = task_name
         eval_metrics['finetuned'] = True
@@ -522,6 +528,14 @@ def configure_logging():
     lite_llm.setLevel(logging.INFO)
 
 
+def set_seed(seed: Optional[int] = None):
+    """Set the random seed for reproducibility."""
+    if seed is not None:
+        random.seed(seed)
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+
+
 @hydra.main(version_base=None, config_path='../src/conf', config_name='eval')
 def main(config: DictConfig):
@@ -535,11 +549,8 @@ def main(config: DictConfig):
     model_provider = ModelProvider.get_instance(config.model)
 
     # Set random seed
-    if config.get('seed') is not None:
-        random.seed(config.seed)
-        np.random.seed(config.seed)
-        torch.manual_seed(config.seed)
-
+    set_seed(config.get('seed'))
+
     # Run eval
     logger.info("Running eval...")
    eval_results = run_eval(config, model_provider)
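
A minimal sketch (not part of the patch) of why re-seeding with the same eval_seed before each pass keeps the randomly sampled FIM examples identical across the pre- and post-finetune runs. The sample_fim_examples helper below is a hypothetical stand-in for whatever sampling run_task_eval performs internally; set_seed mirrors the helper added in the patch.

import random
from typing import List, Optional

import numpy as np
import torch


def set_seed(seed: Optional[int] = None):
    """Set the random seed for reproducibility (same logic as the patch)."""
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)


def sample_fim_examples(corpus: List[str], k: int) -> List[str]:
    # Hypothetical stand-in: any sampling that consumes the global RNGs
    # becomes deterministic once set_seed has been called.
    return random.sample(corpus, k)


corpus = [f"file_{i}.py" for i in range(100)]
eval_seed = random.randint(0, 2**32 - 1)

set_seed(eval_seed)              # before the pre-finetune eval
pre = sample_fim_examples(corpus, 5)

set_seed(eval_seed)              # before the post-finetune eval
post = sample_fim_examples(corpus, 5)

assert pre == post               # both runs see the same FIM examples

Because the seed is drawn once in run_eval and replayed before each call to run_task_eval, any difference between the pre- and post-finetune metrics reflects the finetuning itself rather than a different draw of eval examples.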