diff --git a/train_gpt.py b/train_gpt.py index 2b069ed..f7df125 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -23,7 +23,7 @@ # CONFIG ########################################### -parser = argparse.ArgumentParser(description="Train VQGAN") +parser = argparse.ArgumentParser(description="Train GPT") parser.add_argument("--ds", type=str, default="coco", help="dataset name") parser.add_argument("--gpt", type=str, default="gpt2_medium", help="GPT model") parser.add_argument("--dropout", type=float, default=0.1, help="dropout rate")