Skip to content

Commit 3deed72

Browse files
icsl-Jeonsayakpaullinoytsaban
authored
Handling mixed precision for dreambooth flux lora training (#9565)
Handling mixed precision and add unwarp Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
1 parent 7ffbc25 commit 3deed72

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

examples/dreambooth/train_dreambooth_lora_flux.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def log_validation(
177177
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
178178
f" {args.validation_prompt}."
179179
)
180-
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
180+
pipeline = pipeline.to(accelerator.device)
181181
pipeline.set_progress_bar_config(disable=True)
182182

183183
# run inference
@@ -1706,7 +1706,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17061706
)
17071707

17081708
# handle guidance
1709-
if transformer.config.guidance_embeds:
1709+
if accelerator.unwrap_model(transformer).config.guidance_embeds:
17101710
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
17111711
guidance = guidance.expand(model_input.shape[0])
17121712
else:
@@ -1819,6 +1819,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
18191819
# create pipeline
18201820
if not args.train_text_encoder:
18211821
text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
1822+
text_encoder_one.to(weight_dtype)
1823+
text_encoder_two.to(weight_dtype)
18221824
pipeline = FluxPipeline.from_pretrained(
18231825
args.pretrained_model_name_or_path,
18241826
vae=vae,

0 commit comments

Comments
 (0)