fix: Consistent FIM examples between evals
ejmejm committed Jun 5, 2024
1 parent 50ff01c commit 79bf926
Showing 1 changed file with 16 additions and 5 deletions.
eval/run_eval.py: 16 additions & 5 deletions
@@ -397,6 +397,10 @@ def run_eval(config: DictConfig, model_provider: ModelProvider):
     else:
         logger.info("No benchmarks enabled.")
 
+    # Make a random seed that will be used for both pre- and post-finetune eval
+    # This is required to keep the FIM examples the same for both runs
+    eval_seed = random.randint(0, 2**32 - 1)
+
     # Load name of all eval tasks
     task_metrics = []
     for task_name, task_info in load_task_info():
@@ -409,6 +413,7 @@ def run_eval(config: DictConfig, model_provider: ModelProvider):
         ### Run eval ###
 
         logger.info(f"Running pre-finetune eval...")
+        set_seed(eval_seed)
         eval_metrics = run_task_eval(config, task_info, model, tokenizer)
         eval_metrics['task_name'] = task_name
         eval_metrics['finetuned'] = False
@@ -432,6 +437,7 @@ def run_eval(config: DictConfig, model_provider: ModelProvider):
         ### Run eval again ###
 
         logger.info(f"Running post-finetune eval...")
+        set_seed(eval_seed)
         eval_metrics = run_task_eval(config, task_info, model, tokenizer)
         eval_metrics['task_name'] = task_name
         eval_metrics['finetuned'] = True
@@ -522,6 +528,14 @@ def configure_logging():
     lite_llm.setLevel(logging.INFO)
 
 
+def set_seed(seed: Optional[int] = None):
+    """Set the random seed for reproducibility."""
+    if seed is not None:
+        random.seed(seed)
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+
+
 @hydra.main(version_base=None, config_path='../src/conf', config_name='eval')
 def main(config: DictConfig):
 
@@ -535,11 +549,8 @@ def main(config: DictConfig):
     model_provider = ModelProvider.get_instance(config.model)
 
     # Set random seed
-    if config.get('seed') is not None:
-        random.seed(config.seed)
-        np.random.seed(config.seed)
-        torch.manual_seed(config.seed)
-
+    set_seed(config.get('seed'))
+
     # Run eval
     logger.info("Running eval...")
     eval_results = run_eval(config, model_provider)
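The idea behind the shared eval_seed can be shown with a minimal standalone sketch (assumptions: sample_fim_indices below is a hypothetical stand-in for however run_task_eval picks FIM examples at random; the set_seed helper mirrors the one added in this commit). Re-seeding with the same value before each pass resets the RNG state, so the pre- and post-finetune evals draw exactly the same "random" FIM examples.

# Minimal sketch of the reseeding idea, not the repository's actual eval code.
import random
from typing import Optional

import numpy as np
import torch


def set_seed(seed: Optional[int] = None):
    """Seed Python, NumPy, and PyTorch RNGs (mirrors the helper added in this commit)."""
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)


def sample_fim_indices(n_examples: int, n_samples: int):
    # Hypothetical stand-in for the random choice of which examples get turned
    # into fill-in-the-middle (FIM) prompts during eval.
    return random.sample(range(n_examples), n_samples)


eval_seed = random.randint(0, 2**32 - 1)

set_seed(eval_seed)                   # before the pre-finetune eval
pre_finetune = sample_fim_indices(1000, 5)

set_seed(eval_seed)                   # before the post-finetune eval
post_finetune = sample_fim_indices(1000, 5)

assert pre_finetune == post_finetune  # both passes see the same FIM examples

Without the shared seed, the second eval would start from whatever RNG state finetuning left behind, so the two passes could be scored on different FIM examples and their metrics would not be directly comparable, which appears to be what this commit fixes.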
