
Update train_lora.py
zejunwang1 authored Jul 19, 2023
1 parent 93febd1 · commit d70a9ac
Showing 1 changed file, train_lora.py, with 4 additions and 0 deletions.
@@ -105,6 +105,10 @@ def train():
     # cast all non INT8 parameters to fp32
     model = prepare_model_for_kbit_training(model,
         use_gradient_checkpointing=training_args.gradient_checkpointing)
+
+    if training_args.gradient_checkpointing:
+        model.enable_input_require_grads()
+        model.gradient_checkpointing_enable()

     # Get our peft model and print the number of trainable parameters
     checkpoint_dir = training_args.resume_from_checkpoint
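Why the `enable_input_require_grads()` call matters: with a quantized base model all base weights are frozen, so a (re-entrant) gradient-checkpointed segment may see no tensors that require grad and the trainable adapter weights get no gradients. Marking the segment's input as requiring grad restores the backward path. A minimal PyTorch sketch of this effect, assuming frozen `base` and trainable `adapter` layers as hypothetical stand-ins for the quantized model and its LoRA weights (not the repo's code):

```python
import torch
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)

# Frozen "base" layer: stands in for the INT8/4-bit base model.
base = torch.nn.Linear(4, 4)
for p in base.parameters():
    p.requires_grad_(False)

# Trainable "adapter" head: stands in for the LoRA weights.
adapter = torch.nn.Linear(4, 2)

def segment(t):
    return adapter(base(t))

# Mimic enable_input_require_grads(): mark the checkpointed segment's
# input as requiring grad so re-entrant checkpointing has a grad path.
x = torch.randn(3, 4).requires_grad_(True)

out = checkpoint(segment, x, use_reentrant=True)
out.sum().backward()

print(adapter.weight.grad is not None)  # adapter still trains
```

Without `x.requires_grad_(True)`, the re-entrant checkpoint recomputes the forward with no grad-requiring inputs and the adapter's gradients come back as `None`; `enable_input_require_grads()` installs the equivalent hook on the model's embedding outputs.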
