From 3c2f7391caeb2653eb763adcdfc6731096140711 Mon Sep 17 00:00:00 2001 From: "Philippe Laban (canny1)" Date: Fri, 2 Jul 2021 11:35:53 -0400 Subject: [PATCH] Adding example for running Keep it Simple model. --- run_keep_it_simple.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 run_keep_it_simple.py diff --git a/run_keep_it_simple.py b/run_keep_it_simple.py new file mode 100644 index 0000000..ad9d4b5 --- /dev/null +++ b/run_keep_it_simple.py @@ -0,0 +1,35 @@ +import utils_misc, argparse, os, utils_edits + +utils_misc.select_freer_gpu() +from model_generator import Generator +# from model_access_simplifier import AccessSimplifier + +MODELS_FOLDER = os.environ["MODELS_FOLDER"] + +parser = argparse.ArgumentParser() +parser.add_argument("--model_card", type=str, default="gpt2-medium", help="Either `gpt2` or `gpt2-medium`") +parser.add_argument("--model_file", type=str, required=True, help="Use for example `gpt2_med_keep_it_simple.bin` provided in the codebase.") + +args = parser.parse_args() + +model = Generator(args.model_card, max_output_length=90, device='cuda') + +if len(args.model_file) > 0: + model.reload(args.model_file) +model.eval() + +paragraph = """NASA's Curiosity rover just celebrated a major milestone — 3,000 days on the surface of Mars. + To mark the occasion, the space agency has released a stunning new panorama of the red planet, captured by the rover.""" + +model_output = model.generate([paragraph], num_runs=8, sample=True)[0] + +print("ORIGINAL TEXT:") +print(paragraph) + +print(utils_edits.show_diff_word("Legend: Deletions are in red. Additions are in.", "Legend: Deletions are in. Additions are in green.")) + +print(model_output) + +for candidate in model_output: + print("----") + print(utils_edits.show_diff_word(paragraph, candidate["output_text"]))