feat: Standalone human eval reproduction script
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/azureuser/.conda/envs/llm_env/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      " from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import sys\n",
    "import tempfile\n",
    "sys.path.append('../')\n",
    "\n",
    "import torch\n",
    "from human_eval.data import stream_jsonl, write_jsonl, read_problems\n",
    "from human_eval.evaluation import evaluate_functional_correctness\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "\n",
    "os.environ['TOKENIZERS_PARALLELISM'] = 'true'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "output_dir = tempfile.gettempdir()\n",
    "\n",
    "n_samples_per_task = 1\n",
    "batch_size = 32\n",
    "n_workers = 8\n",
    "\n",
    "max_gen_length = 512\n",
    "\n",
    "use_instruct_model = True\n",
    "model_size = '1.3b'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cleanup_code(code: str, instruct_format: bool = False) -> str:\n",
    "    \"\"\"\n",
    "    Cleans up a generated completion: extracts the fenced code block for\n",
    "    instruct-style output, or truncates at common stop words for base models.\n",
    "    \"\"\"\n",
    "    if instruct_format:\n",
    "        code = code.replace(\"\\r\", \"\")\n",
    "        if \"```python\" in code:\n",
    "            code_start_idx = code.index(\"```python\")\n",
    "            code = code[code_start_idx:].replace(\"```python\", \"\").strip()\n",
    "            end_idx = code.find(\"```\") if \"```\" in code else len(code)\n",
    "            code = code[:end_idx].strip()\n",
    "\n",
    "    else:\n",
    "        stop_words = set([\"\\ndef\", \"\\nclass\", \"\\nif\", \"\\n#\", \"\\nprint\"])\n",
    "        min_stop_idx = len(code)\n",
    "        for stop_word in stop_words:\n",
    "            stop_index = code.find(stop_word)\n",
    "            if 0 <= stop_index < min_stop_idx:\n",
    "                min_stop_idx = stop_index\n",
    "        code = code[:min_stop_idx]\n",
    "\n",
    "    return code"
   ]
  },
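  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check of `cleanup_code` on two toy completions (hypothetical strings, not real model output): the instruct path should extract only the fenced code block, and the base path should truncate at the first stop word."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy inputs, made up for illustration only\n",
    "instruct_sample = \"Sure, here you go:\\n```python\\ndef add(a, b):\\n    return a + b\\n```\\nHope this helps!\"\n",
    "base_sample = \"    return a + b\\n\\ndef unrelated():\\n    pass\"\n",
    "\n",
    "print(repr(cleanup_code(instruct_sample, instruct_format=True)))  # expect the fenced block only\n",
    "print(repr(cleanup_code(base_sample, instruct_format=False)))  # expect truncation at '\\ndef'"
   ]
  },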
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/azureuser/.conda/envs/llm_env/lib/python3.9/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      " warnings.warn(\n",
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    }
   ],
   "source": [
    "device = 'cuda'\n",
    "\n",
    "model_type = 'instruct' if use_instruct_model else 'base'\n",
    "model_name = f'deepseek-ai/deepseek-coder-{model_size}-{model_type}'\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "tokenizer.padding_side = 'left'\n",
    "# tokenizer.pad_token = tokenizer.eos_token  # fallback if the tokenizer has no pad token\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_name, attn_implementation='flash_attention_2',\n",
    "    torch_dtype=torch.bfloat16, device_map=device, trust_remote_code=True,\n",
    ")\n",
    "model = torch.compile(model)\n",
    "model = model.eval()"
   ]
  },
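  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An optional smoke test before the full run: greedy-decode a short continuation of a toy prompt (the prompt below is made up) to confirm the model, tokenizer, and device are wired up correctly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke test with a made-up prompt; a short greedy generation is enough\n",
    "test_prompt = 'def fib(n):\\n    \"\"\"Return the n-th Fibonacci number.\"\"\"\\n'\n",
    "test_inputs = tokenizer(test_prompt, return_tensors='pt').to(device)\n",
    "with torch.no_grad():\n",
    "    test_out = model.generate(**test_inputs, max_new_tokens=64)\n",
    "print(tokenizer.decode(test_out[0], skip_special_tokens=True))"
   ]
  },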
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# Problems: 164\n"
     ]
    }
   ],
   "source": [
    "problems = read_problems()\n",
    "print(f'# Problems: {len(problems)}')\n",
    "\n",
    "problem_tuples = [(k, v['prompt']) for k, v in problems.items()]\n",
    "task_ids, prompts = zip(*problem_tuples)\n",
    "\n",
    "# Repeat each task id and prompt n_samples_per_task times\n",
    "input_tasks = [\n",
    "    task_id\n",
    "    for task_id in task_ids\n",
    "    for _ in range(n_samples_per_task)\n",
    "]\n",
    "inputs = [\n",
    "    prompt\n",
    "    for prompt in prompts\n",
    "    for _ in range(n_samples_per_task)\n",
    "]\n",
    "\n",
    "if use_instruct_model:\n",
    "    instruct_template = \\\n",
    "        \"Below is an instruction that describes a task, paired with an input that provides further context.\\n\" + \\\n",
    "        \"Write a response that appropriately completes the request.\\n\\n### Instruction:\\nWrite a program to \" + \\\n",
    "        \"perform the given task.\\n\\nInput:\\n{}\\n\\n### Response:\\n\"\n",
    "    # Apply the template to the repeated prompts so n_samples_per_task > 1 still works\n",
    "    inputs = [instruct_template.format(prompt) for prompt in inputs]\n",
    "\n",
    "inputs = tokenizer(inputs, padding=True, return_tensors='pt').to(device)"
   ]
  },
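  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick check that the instruct template and left padding were applied, decode the first tokenized prompt and show the padded batch shape."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decode the first prompt back from token ids to verify the formatting\n",
    "print(tokenizer.decode(inputs['input_ids'][0], skip_special_tokens=True))\n",
    "print('Padded batch shape:', tuple(inputs['input_ids'].shape))"
   ]
  },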
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:32021 for open-end generation.\n",
      "Setting `pad_token_id` to `eos_token_id`:32021 for open-end generation.\n",
      "Setting `pad_token_id` to `eos_token_id`:32021 for open-end generation.\n",
      "Setting `pad_token_id` to `eos_token_id`:32021 for open-end generation.\n",
      "Setting `pad_token_id` to `eos_token_id`:32021 for open-end generation.\n",
      "Setting `pad_token_id` to `eos_token_id`:32021 for open-end generation.\n"
     ]
    }
   ],
   "source": [
    "completions = []\n",
    "\n",
    "for i in range(0, len(inputs['input_ids']), batch_size):\n",
    "    batch_inputs = {k: v[i:i+batch_size] for k, v in inputs.items()}\n",
    "\n",
    "    with torch.no_grad():\n",
    "        generated_ids = model.generate(**batch_inputs, max_new_tokens=max_gen_length)\n",
    "        # generated_ids = model.generate(\n",
    "        #     **batch_inputs,\n",
    "        #     max_new_tokens=max_gen_length,\n",
    "        #     do_sample=False,\n",
    "        #     eos_token_id=tokenizer.eos_token_id,\n",
    "        #     pad_token_id=tokenizer.eos_token_id,\n",
    "        # )\n",
    "\n",
    "    completion_ids = generated_ids[:, batch_inputs['input_ids'].shape[1]:]\n",
    "    batch_completions = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)\n",
    "    completions.extend(batch_completions)\n",
    "\n",
    "cleaned_completions = [cleanup_code(c, use_instruct_model) for c in completions]"
   ]
  },
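  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compare one raw completion with its cleaned counterpart to see what `cleanup_code` removed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('--- raw ---')\n",
    "print(completions[0])\n",
    "print('--- cleaned ---')\n",
    "print(cleaned_completions[0])"
   ]
  },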
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reading samples...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "164it [00:00, 21024.72it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running test suites...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 164/164 [00:25<00:00, 6.37it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing results to /tmp/human_eval_samples.jsonl_results.jsonl...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 164/164 [00:00<00:00, 59762.45it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'pass@1': 0.6524390243902439}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "samples = [\n",
    "    dict(task_id=task_id, completion=completion)\n",
    "    for task_id, completion in zip(input_tasks, cleaned_completions)\n",
    "]\n",
    "\n",
    "# Write the generated samples to a JSONL file\n",
    "filepath = os.path.join(output_dir, 'human_eval_samples.jsonl')\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "write_jsonl(filepath, samples)\n",
    "\n",
    "print(evaluate_functional_correctness(filepath, k=[1], n_workers=n_workers, timeout=20))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Passed: 0.65\n"
     ]
    }
   ],
   "source": [
    "# Read the per-task results written by the evaluator\n",
    "results = list(stream_jsonl(filepath + '_results.jsonl'))\n",
    "passed = [r['passed'] for r in results]\n",
    "passed_frac = sum(passed) / len(passed)\n",
    "print(f'Passed: {passed_frac:.2f}')"
   ]
  },
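  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A small follow-up sketch, assuming each results record carries `task_id` and `passed` fields (the cell above already relies on `passed`): list a few failing task ids as a starting point for error analysis."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the task ids whose completions failed the test suite\n",
    "failed = [r['task_id'] for r in results if not r['passed']]\n",
    "print(f'{len(failed)} failed tasks, e.g.:')\n",
    "print(failed[:10])"
   ]
  }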
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}