-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrun_gpt_inference.py
48 lines (40 loc) · 1.75 KB
/
run_gpt_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from contextlib import redirect_stderr, redirect_stdout
from pathlib import Path
from confection import Config
from llm_classification import run_config
def run_log_config(config: Config, log_path: str):
    """Run a config while capturing its console output in a log file.

    Both stdout and stderr are redirected into the file at ``log_path`` for
    the duration of the ``run_config`` call. The file is opened in write
    mode, so any previous content at that path is overwritten.

    Args:
        config: Configuration object handed unchanged to ``run_config``.
        log_path: Path of the file that receives the redirected output.
    """
    with open(log_path, "w") as sink:
        # stdout is redirected first, then stderr, both into the same handle.
        with redirect_stdout(sink), redirect_stderr(sink):
            run_config(config)
def main():
    """Run inference for every model / prompt type / column / task combination.

    The default configuration is read from ``configs/default_gpt_config.cfg``.
    For each combination a fresh copy of it is adjusted and passed to
    ``run_config``:

    * ``paths.out_dir`` is set per prompt type,
    * ``paths.prompt_file`` is overridden only for the "custom" prompt type
      (otherwise the value from the default config is kept),
    * ``model.name``, ``model.task`` and ``inference.y_column`` are set from
      the current combination.

    A banner describing the run is printed before each call, and "DONE" is
    printed once all combinations have been processed.
    """
    base_config = Config().from_disk("configs/default_gpt_config.cfg")
    # The log directory is created up front; nothing in this function writes
    # to it directly.
    logs_folder = Path("gpt_runs_logfiles/")
    logs_folder.mkdir(exist_ok=True)
    print("Collecting configs and log paths.")

    # Enumerate all runs first; the ordering matches nested loops with the
    # model outermost and the task innermost.
    combinations = [
        (model, prompt_type, column, task)
        for model in ["gpt-4", "gpt-3.5-turbo"]
        for prompt_type in ["custom", "generic"]
        for column in ["political", "exemplar"]
        for task in ["zero-shot", "few-shot"]
    ]

    for model, prompt_type, column, task in combinations:
        # Start every run from an unmodified copy of the defaults.
        run_settings = base_config.copy()
        paths_section = run_settings["paths"]
        paths_section["out_dir"] = f"predictions_{prompt_type}/"
        if prompt_type == "custom":
            paths_section["prompt_file"] = f"prompts/gpt_{task}_{column}.txt"
        run_settings["model"]["name"] = model
        run_settings["model"]["task"] = task
        run_settings["inference"]["y_column"] = column

        banner = (
            "------------------------------\n"
            f"Running Inference with {model}\n"
            f" - task: {task}\n"
            f" - prompt: {prompt_type}\n"
            f" - column: {column}\n"
            "------------------------------"
        )
        print(banner)
        run_config(run_settings)

    print("DONE")
# Start the full inference sweep only when this file is executed as a script,
# so that importing the module does not trigger any runs.
if __name__ == "__main__":
    main()