Skip to content

Commit

Permalink
Update no_trainer scripts with new Accelerate functionalities (#16617)
Browse files Browse the repository at this point in the history
Adds logging and save/loading to the Accelerate scripts

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
  • Loading branch information
muellerzr and sgugger authored Apr 6, 2022
1 parent 10c15d2 commit febe42b
Show file tree
Hide file tree
Showing 9 changed files with 711 additions and 78 deletions.
89 changes: 80 additions & 9 deletions examples/pytorch/language-modeling/run_clm_no_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,23 @@ def parse_args():
"--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
)
parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
parser.add_argument(
"--checkpointing_steps",
type=str,
default=None,
help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help="If the training should continue from a checkpoint folder.",
)
parser.add_argument(
"--with_tracking",
required=False,
help="Whether to load in all available experiment trackers from the environment and use them for logging.",
)
args = parser.parse_args()

# Sanity checks
Expand All @@ -208,7 +225,8 @@ def main():
args = parse_args()

# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
accelerator = Accelerator()
# If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
accelerator = Accelerator(log_with="all") if args.with_tracking else Accelerator()
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
Expand Down Expand Up @@ -427,18 +445,10 @@ def group_texts(examples):
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

# Prepare everything with our `accelerator`.
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
model, optimizer, train_dataloader, eval_dataloader
)

# On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
if accelerator.distributed_type == DistributedType.TPU:
model.tie_weights()

# Note -> the training dataloader needs to be prepared before we grab his length below (cause its length will be
# shorter in multiprocess)

# Scheduler and math around the number of training steps.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
Expand All @@ -453,6 +463,23 @@ def group_texts(examples):
num_training_steps=args.max_train_steps,
)

# Prepare everything with our `accelerator`.
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)

# Figure out how many steps we should save the Accelerator states
if hasattr(args.checkpointing_steps, "isdigit"):
checkpointing_steps = args.checkpointing_steps
if args.checkpointing_steps.isdigit():
checkpointing_steps = int(args.checkpointing_steps)
else:
checkpointing_steps = None

# We need to initialize the trackers we use, and also store our configuration
if args.with_tracking:
accelerator.init_trackers("clm_no_trainer", args)

# Train!
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

Expand All @@ -467,11 +494,38 @@ def group_texts(examples):
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0

# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
accelerator.load_state(args.resume_from_checkpoint)
resume_step = None
path = args.resume_from_checkpoint
else:
# Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
if "epoch" in path:
args.num_train_epochs -= int(path.replace("epoch_", ""))
else:
resume_step = int(path.replace("step_", ""))
args.num_train_epochs -= resume_step // len(train_dataloader)
resume_step = (args.num_train_epochs * len(train_dataloader)) - resume_step

for epoch in range(args.num_train_epochs):
model.train()
if args.with_tracking:
total_loss = 0
for step, batch in enumerate(train_dataloader):
# We need to skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == 0 and step < resume_step:
continue
outputs = model(**batch)
loss = outputs.loss
# We keep track of the loss at each epoch
if args.with_tracking:
total_loss += loss.detach().float()
loss = loss / args.gradient_accumulation_steps
accelerator.backward(loss)
if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
Expand All @@ -481,6 +535,10 @@ def group_texts(examples):
progress_bar.update(1)
completed_steps += 1

if isinstance(checkpointing_steps, int):
if completed_steps % checkpointing_steps == 0:
accelerator.save_state(f"step_{completed_steps}")

if completed_steps >= args.max_train_steps:
break

Expand All @@ -502,6 +560,16 @@ def group_texts(examples):

logger.info(f"epoch {epoch}: perplexity: {perplexity}")

if args.with_tracking:
accelerator.log(
{
"perplexity": perplexity,
"train_loss": total_loss,
"epoch": epoch,
},
step=completed_steps,
)

if args.push_to_hub and epoch < args.num_train_epochs - 1:
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
Expand All @@ -512,6 +580,9 @@ def group_texts(examples):
commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True
)

if args.checkpointing_steps == "epoch":
accelerator.save_state(f"epoch_{epoch}")

if args.output_dir is not None:
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
Expand Down
86 changes: 80 additions & 6 deletions examples/pytorch/language-modeling/run_mlm_no_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,23 @@ def parse_args():
"--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
)
parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
parser.add_argument(
"--checkpointing_steps",
type=str,
default=None,
help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help="If the training should continue from a checkpoint folder.",
)
parser.add_argument(
"--with_tracking",
required=False,
help="Whether to load in all available experiment trackers from the environment and use them for logging.",
)
args = parser.parse_args()

# Sanity checks
Expand All @@ -219,7 +236,8 @@ def main():
args = parse_args()

# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
accelerator = Accelerator()
# If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
accelerator = Accelerator(log_with="all") if args.with_tracking else Accelerator()
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
Expand Down Expand Up @@ -468,11 +486,6 @@ def group_texts(examples):
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

# Prepare everything with our `accelerator`.
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
model, optimizer, train_dataloader, eval_dataloader
)

# On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
if accelerator.distributed_type == DistributedType.TPU:
model.tie_weights()
Expand All @@ -494,6 +507,23 @@ def group_texts(examples):
num_training_steps=args.max_train_steps,
)

# Prepare everything with our `accelerator`.
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)

# Figure out how many steps we should save the Accelerator states
if hasattr(args.checkpointing_steps, "isdigit"):
checkpointing_steps = args.checkpointing_steps
if args.checkpointing_steps.isdigit():
checkpointing_steps = int(args.checkpointing_steps)
else:
checkpointing_steps = None

# We need to initialize the trackers we use, and also store our configuration
if args.with_tracking:
accelerator.init_trackers("clm_no_trainer", args)

# Train!
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

Expand All @@ -508,11 +538,38 @@ def group_texts(examples):
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0

# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
accelerator.load_state(args.resume_from_checkpoint)
resume_step = None
path = args.resume_from_checkpoint
else:
# Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
if "epoch" in path:
args.num_train_epochs -= int(path.replace("epoch_", ""))
else:
resume_step = int(path.replace("step_", ""))
args.num_train_epochs -= resume_step // len(train_dataloader)
resume_step = (args.num_train_epochs * len(train_dataloader)) - resume_step

for epoch in range(args.num_train_epochs):
model.train()
if args.with_tracking:
total_loss = 0
for step, batch in enumerate(train_dataloader):
# We need to skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == 0 and step < resume_step:
continue
outputs = model(**batch)
loss = outputs.loss
# We keep track of the loss at each epoch
if args.with_tracking:
total_loss += loss.detach().float()
loss = loss / args.gradient_accumulation_steps
accelerator.backward(loss)
if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
Expand All @@ -522,6 +579,10 @@ def group_texts(examples):
progress_bar.update(1)
completed_steps += 1

if isinstance(checkpointing_steps, int):
if completed_steps % checkpointing_steps == 0:
accelerator.save_state(f"step_{completed_steps}")

if completed_steps >= args.max_train_steps:
break

Expand All @@ -543,6 +604,16 @@ def group_texts(examples):

logger.info(f"epoch {epoch}: perplexity: {perplexity}")

if args.with_tracking:
accelerator.log(
{
"perplexity": perplexity,
"train_loss": total_loss,
"epoch": epoch,
},
step=completed_steps,
)

if args.push_to_hub and epoch < args.num_train_epochs - 1:
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
Expand All @@ -553,6 +624,9 @@ def group_texts(examples):
commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True
)

if args.checkpointing_steps == "epoch":
accelerator.save_state(f"epoch_{epoch}")

if args.output_dir is not None:
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
Expand Down
Loading

0 comments on commit febe42b

Please # to comment.