From 89a3af6c15ffc0ebe68d44d3c0805d0c6b2ebbcf Mon Sep 17 00:00:00 2001
From: DerekLiu35
Date: Wed, 18 Jun 2025 16:08:12 +0200
Subject: [PATCH 1/7] add example

---
 .../train_dreambooth_lora_flux_nano.py        | 338 ++++++++++++++++++
 1 file changed, 338 insertions(+)
 create mode 100644 examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_nano.py

diff --git a/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_nano.py b/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_nano.py
new file mode 100644
index 000000000000..963cfd69c3f4
--- /dev/null
+++ b/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_nano.py
@@ -0,0 +1,338 @@
+import copy
+import logging
+import math
+import os
+from pathlib import Path
+import shutil
+
+import numpy as np
+import pandas as pd
+import torch
+import transformers
+from accelerate import Accelerator, DistributedType
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub.utils import insecure_hashlib
+from peft import LoraConfig, prepare_model_for_kbit_training, set_peft_model_state_dict
+from peft.utils import get_peft_model_state_dict
+from PIL.ImageOps import exif_transpose
+from torch.utils.data import Dataset
+from torchvision import transforms
+from torchvision.transforms.functional import crop
+from tqdm.auto import tqdm
+
+import diffusers
+from diffusers import (
+    AutoencoderKL, BitsAndBytesConfig, FlowMatchEulerDiscreteScheduler,
+    FluxPipeline, FluxTransformer2DModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import (
+    cast_training_params, compute_density_for_timestep_sampling,
+    compute_loss_weighting_for_sd3, free_memory,
+)
+from diffusers.utils import convert_unet_state_dict_to_peft, is_wandb_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+logger = get_logger(__name__)
+
+class DreamBoothDataset(Dataset):
+    def __init__(self, data_df_path, dataset_name, width, height, max_sequence_length=77):
+        self.width, self.height, self.max_sequence_length = width, height, max_sequence_length
+        self.data_df_path = Path(data_df_path)
+        if not self.data_df_path.exists():
+            raise ValueError("`data_df_path` doesn't exist.")
+
+        dataset = load_dataset(dataset_name, split="train")
+        self.instance_images = [sample["image"] for sample in dataset]
+        self.image_hashes = [insecure_hashlib.sha256(img.tobytes()).hexdigest() for img in self.instance_images]
+        self.pixel_values = self._apply_transforms()
+        self.data_dict = self._map_embeddings()
+        self._length = len(self.instance_images)
+
+    def __len__(self):
+        return self._length
+
+    def __getitem__(self, index):
+        idx = index % len(self.instance_images)
+        hash_key = self.image_hashes[idx]
+        prompt_embeds, pooled_prompt_embeds, text_ids = self.data_dict[hash_key]
+        return {
+            "instance_images": self.pixel_values[idx],
+            "prompt_embeds": prompt_embeds,
+            "pooled_prompt_embeds": pooled_prompt_embeds,
+            "text_ids": text_ids,
+        }
+
+    def _apply_transforms(self):
+        transform = transforms.Compose([
+            transforms.Resize((self.height, self.width), interpolation=transforms.InterpolationMode.BILINEAR),
+            transforms.RandomCrop((self.height, self.width)),
+            transforms.ToTensor(),
+            transforms.Normalize([0.5], [0.5]),
+        ])
+
+        pixel_values = []
+        for image in self.instance_images:
+            image = exif_transpose(image).convert("RGB") if image.mode != "RGB" else exif_transpose(image)
+            pixel_values.append(transform(image))
+        return pixel_values
+
+    def _map_embeddings(self):
+        df = pd.read_parquet(self.data_df_path)
+        data_dict = {}
+        for _, row in df.iterrows():
+            prompt_embeds = torch.from_numpy(np.array(row["prompt_embeds"]).reshape(self.max_sequence_length, 4096))
+            pooled_prompt_embeds = torch.from_numpy(np.array(row["pooled_prompt_embeds"]).reshape(768))
+            text_ids = torch.from_numpy(np.array(row["text_ids"]).reshape(77, 3))
+            data_dict[row["image_hash"]] = (prompt_embeds, pooled_prompt_embeds, text_ids)
+        return data_dict
+
+def collate_fn(examples):
+    pixel_values = torch.stack([ex["instance_images"] for ex in examples]).float()
+    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+    prompt_embeds = torch.stack([ex["prompt_embeds"] for ex in examples])
+    pooled_prompt_embeds = torch.stack([ex["pooled_prompt_embeds"] for ex in examples])
+    text_ids = torch.stack([ex["text_ids"] for ex in examples])[0]
+
+    return {
+        "pixel_values": pixel_values,
+        "prompt_embeds": prompt_embeds,
+        "pooled_prompt_embeds": pooled_prompt_embeds,
+        "text_ids": text_ids,
+    }
+
+def main(args):
+    # Setup accelerator
+    accelerator = Accelerator(
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        mixed_precision=args.mixed_precision,
+        log_with=args.report_to,
+        project_config=ProjectConfiguration(project_dir=args.output_dir, logging_dir=Path(args.output_dir, "logs")),
+        kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)],
+    )
+
+    # Setup logging
+    logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", level=logging.INFO)
+    if accelerator.is_local_main_process:
+        transformers.utils.logging.set_verbosity_warning()
+        diffusers.utils.logging.set_verbosity_info()
+    else:
+        transformers.utils.logging.set_verbosity_error()
+        diffusers.utils.logging.set_verbosity_error()
+
+    set_seed(args.seed) if args.seed is not None else None
+    os.makedirs(args.output_dir, exist_ok=True) if accelerator.is_main_process else None
+
+    # Load models with quantization
+    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+    noise_scheduler_copy = copy.deepcopy(noise_scheduler)
+
+    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
+
+    nf4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16)
+    transformer = FluxTransformer2DModel.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="transformer",
+        quantization_config=nf4_config, torch_dtype=torch.float16
+    )
+    transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False)
+
+    # Freeze models and setup LoRA
+    transformer.requires_grad_(False)
+    vae.requires_grad_(False)
+    vae.to(accelerator.device, dtype=torch.float16)
+    if args.gradient_checkpointing:
+        transformer.enable_gradient_checkpointing()
+
+    # now we will add new LoRA weights to the attention layers
+    transformer_lora_config = LoraConfig(
+        r=args.rank,
+        lora_alpha=args.rank,
+        init_lora_weights="gaussian",
+        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+    )
+    transformer.add_adapter(transformer_lora_config)
+
+    print(f"trainable params: {transformer.num_parameters(only_trainable=True)} || all params: {transformer.num_parameters()}")
+
+    # Setup optimizer
+    import bitsandbytes as bnb
+    optimizer = bnb.optim.AdamW8bit(
+        [{"params": list(filter(lambda p: p.requires_grad, transformer.parameters())), "lr": args.learning_rate}],
+        betas=(0.9, 0.999), weight_decay=1e-04, eps=1e-08
+    )
+
+    # Setup dataset and dataloader
+    train_dataset = DreamBoothDataset(args.data_df_path, "derekl35/alphonse-mucha-style", args.width, args.height)
+    train_dataloader = torch.utils.data.DataLoader(
+        train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn
+    )
+
+    # Cache latents
+    vae_config = vae.config
+    latents_cache = []
+    for batch in tqdm(train_dataloader, desc="Caching latents"):
+        with torch.no_grad():
+            pixel_values = batch["pixel_values"].to(accelerator.device, dtype=torch.float16)
+            latents_cache.append(vae.encode(pixel_values).latent_dist)
+
+    del vae
+    free_memory()
+
+    # Setup scheduler and training steps
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    args.max_train_steps = args.max_train_steps or args.num_train_epochs * num_update_steps_per_epoch
+
+    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+    lr_scheduler = get_scheduler("constant", optimizer=optimizer, num_warmup_steps=0, num_training_steps=args.max_train_steps)
+
+    # Prepare for training
+    transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(transformer, optimizer, train_dataloader, lr_scheduler)
+
+    # Register save/load hooks
+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        return model._orig_mod if is_compiled_module(model) else model
+
+    def save_model_hook(models, weights, output_dir):
+        if accelerator.is_main_process:
+            for model in models:
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    lora_layers = get_peft_model_state_dict(unwrap_model(model))
+                    FluxPipeline.save_lora_weights(output_dir, transformer_lora_layers=lora_layers, text_encoder_lora_layers=None)
+            weights.pop() if weights else None
+
+    accelerator.register_save_state_pre_hook(save_model_hook)
+    cast_training_params([transformer], dtype=torch.float32) if args.mixed_precision == "fp16" else None
+
+    # Initialize tracking
+    accelerator.init_trackers("dreambooth-flux-dev-lora-alphonse-mucha", config=vars(args)) if accelerator.is_main_process else None
+
+    # Training loop
+    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
+        sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
+        schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
+        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps.to(accelerator.device)]
+        sigma = sigmas[step_indices].flatten()
+        while len(sigma.shape) < n_dim:
+            sigma = sigma.unsqueeze(-1)
+        return sigma
+
+    global_step = 0
+    progress_bar = tqdm(range(args.max_train_steps), desc="Steps", disable=not accelerator.is_local_main_process)
+
+    for epoch in range(args.num_train_epochs):
+        for step, batch in enumerate(train_dataloader):
+            with accelerator.accumulate([transformer]):
+                # Get cached latents
+                model_input = latents_cache[step].sample()
+                model_input = (model_input - vae_config.shift_factor) * vae_config.scaling_factor
+                model_input = model_input.to(dtype=torch.float16)
+
+                # Prepare inputs
+                latent_image_ids = FluxPipeline._prepare_latent_image_ids(
+                    model_input.shape[0], model_input.shape[2] // 2, model_input.shape[3] // 2,
+                    accelerator.device, torch.float16
+                )
+
+                noise = torch.randn_like(model_input)
+                bsz = model_input.shape[0]
+
+                u = compute_density_for_timestep_sampling("none", bsz, 0.0, 1.0, 1.29)
+                indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
+                timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
+
+                sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
+                noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
+
+                packed_noisy_model_input = FluxPipeline._pack_latents(
+                    noisy_model_input, model_input.shape[0], model_input.shape[1],
+                    model_input.shape[2], model_input.shape[3]
+                )
+
+                # Forward pass
+                guidance = torch.tensor([args.guidance_scale], device=accelerator.device).expand(bsz) if unwrap_model(transformer).config.guidance_embeds else None
+
+                model_pred = transformer(
+                    hidden_states=packed_noisy_model_input,
+                    timestep=timesteps / 1000,
+                    guidance=guidance,
+                    pooled_projections=batch["pooled_prompt_embeds"].to(accelerator.device, dtype=torch.float16),
+                    encoder_hidden_states=batch["prompt_embeds"].to(accelerator.device, dtype=torch.float16),
+                    txt_ids=batch["text_ids"].to(accelerator.device, dtype=torch.float16),
+                    img_ids=latent_image_ids,
+                    return_dict=False,
+                )[0]
+
+                vae_scale_factor = 2 ** (len(vae_config.block_out_channels) - 1)
+                model_pred = FluxPipeline._unpack_latents(
+                    model_pred, model_input.shape[2] * vae_scale_factor,
+                    model_input.shape[3] * vae_scale_factor, vae_scale_factor
+                )
+
+                # Compute loss
+                weighting = compute_loss_weighting_for_sd3("none", sigmas)
+                target = noise - model_input
+                loss = torch.mean((weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), 1).mean()
+
+                accelerator.backward(loss)
+
+                if accelerator.sync_gradients:
+                    accelerator.clip_grad_norm_(transformer.parameters(), 1.0)
+
+                optimizer.step()
+                lr_scheduler.step()
+                optimizer.zero_grad()
+
+            if accelerator.sync_gradients:
+                progress_bar.update(1)
+                global_step += 1
+
+                # Checkpointing
+                if global_step % args.checkpointing_steps == 0 and (accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED):
+                    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+                    accelerator.save_state(save_path)
+
+            # Logging
+            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+            progress_bar.set_postfix(**logs)
+            accelerator.log(logs, step=global_step)
+
+            if global_step >= args.max_train_steps:
+                break
+
+    # Final save
+    accelerator.wait_for_everyone()
+    if accelerator.is_main_process:
+        transformer_lora_layers = get_peft_model_state_dict(unwrap_model(transformer))
+        FluxPipeline.save_lora_weights(args.output_dir, transformer_lora_layers=transformer_lora_layers, text_encoder_lora_layers=None)
+
+        if torch.cuda.is_available():
+            print(f"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB")
+        else:
+            print("Training completed. GPU not available for memory tracking.")
+
+    accelerator.end_training()
+
+if __name__ == "__main__":
+    class Args:
+        pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev"
+        data_df_path = "embeddings_alphonse_mucha.parquet"
+        output_dir = "alphonse_mucha_lora_flux_nf4"
+        mixed_precision = "fp16"
+        weighting_scheme = "none"
+        width, height = 512, 768
+        train_batch_size = 1
+        learning_rate = 1e-4
+        guidance_scale = 1.0
+        report_to = "wandb"
+        gradient_accumulation_steps = 4
+        gradient_checkpointing = True
+        rank = 4
+        max_train_steps = 700
+        seed = 0
+        checkpointing_steps = 100
+
+    main(Args())
\ No newline at end of file

From f5a0a4dcd53ec5770b86f90327e186d2f94fbfb5 Mon Sep 17 00:00:00 2001
From: "github-actions[bot]"
Date: Thu, 19 Jun 2025 02:02:09 +0000
Subject: [PATCH 2/7] Apply style fixes

---
 .../train_dreambooth_lora_flux_nano.py        | 113 ++++++++++++------
 1 file changed, 79 insertions(+), 34 deletions(-)

diff --git a/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_nano.py b/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_nano.py
index 963cfd69c3f4..9058505d70da 100644
--- a/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_nano.py
+++ b/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_nano.py
@@ -3,7 +3,6 @@
 import math
 import os
 from pathlib import Path
-import shutil
 
 import numpy as np
 import pandas as pd
 import torch
 import transformers
@@ -14,29 +13,34 @@
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from datasets import load_dataset
 from huggingface_hub.utils import insecure_hashlib
-from peft import LoraConfig, prepare_model_for_kbit_training, set_peft_model_state_dict
+from peft import LoraConfig, prepare_model_for_kbit_training
 from peft.utils import get_peft_model_state_dict
 from PIL.ImageOps import exif_transpose
 from torch.utils.data import Dataset
 from torchvision import transforms
-from torchvision.transforms.functional import crop
 from tqdm.auto import tqdm
 
 import diffusers
 from diffusers import (
-    AutoencoderKL, BitsAndBytesConfig, FlowMatchEulerDiscreteScheduler,
-    FluxPipeline, FluxTransformer2DModel,
+    AutoencoderKL,
+    BitsAndBytesConfig,
+    FlowMatchEulerDiscreteScheduler,
+    FluxPipeline,
+    FluxTransformer2DModel,
 )
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import (
-    cast_training_params, compute_density_for_timestep_sampling,
-    compute_loss_weighting_for_sd3, free_memory,
+    cast_training_params,
+    compute_density_for_timestep_sampling,
+    compute_loss_weighting_for_sd3,
+    free_memory,
 )
-from diffusers.utils import convert_unet_state_dict_to_peft, is_wandb_available
 from diffusers.utils.torch_utils import is_compiled_module
 
+
 logger = get_logger(__name__)
 
+
 class DreamBoothDataset(Dataset):
     def __init__(self, data_df_path, dataset_name, width, height, max_sequence_length=77):
         self.width, self.height, self.max_sequence_length = width, height, max_sequence_length
@@ -66,12 +70,14 @@ def __getitem__(self, index):
         }
 
     def _apply_transforms(self):
-        transform = transforms.Compose([
-            transforms.Resize((self.height, self.width), interpolation=transforms.InterpolationMode.BILINEAR),
-            transforms.RandomCrop((self.height, self.width)),
-            transforms.ToTensor(),
-            transforms.Normalize([0.5], [0.5]),
-        ])
+        transform = transforms.Compose(
+            [
+                transforms.Resize((self.height, self.width), interpolation=transforms.InterpolationMode.BILINEAR),
+                transforms.RandomCrop((self.height, self.width)),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5], [0.5]),
+            ]
+        )
 
         pixel_values = []
         for image in self.instance_images:
@@ -89,6 +95,7 @@ def _map_embeddings(self):
             data_dict[row["image_hash"]] = (prompt_embeds, pooled_prompt_embeds, text_ids)
         return data_dict
 
+
 def collate_fn(examples):
     pixel_values = torch.stack([ex["instance_images"] for ex in examples]).float()
     pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
@@ -103,6 +110,7 @@ def collate_fn(examples):
         "text_ids": text_ids,
     }
 
+
 def main(args):
     # Setup accelerator
     accelerator = Accelerator(
@@ -126,15 +134,19 @@ def main(args):
     os.makedirs(args.output_dir, exist_ok=True) if accelerator.is_main_process else None
 
     # Load models with quantization
-    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="scheduler"
+    )
     noise_scheduler_copy = copy.deepcopy(noise_scheduler)
 
     vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
 
     nf4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16)
     transformer = FluxTransformer2DModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="transformer",
-        quantization_config=nf4_config, torch_dtype=torch.float16
+        args.pretrained_model_name_or_path,
+        subfolder="transformer",
+        quantization_config=nf4_config,
+        torch_dtype=torch.float16,
    )
     transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False)
 
@@ -154,13 +166,18 @@ def main(args):
     )
     transformer.add_adapter(transformer_lora_config)
 
-    print(f"trainable params: {transformer.num_parameters(only_trainable=True)} || all params: {transformer.num_parameters()}")
+    print(
+        f"trainable params: {transformer.num_parameters(only_trainable=True)} || all params: {transformer.num_parameters()}"
+    )
 
     # Setup optimizer
     import bitsandbytes as bnb
+
     optimizer = bnb.optim.AdamW8bit(
         [{"params": list(filter(lambda p: p.requires_grad, transformer.parameters())), "lr": args.learning_rate}],
-        betas=(0.9, 0.999), weight_decay=1e-04, eps=1e-08
+        betas=(0.9, 0.999),
+        weight_decay=1e-04,
+        eps=1e-08,
     )
 
     # Setup dataset and dataloader
@@ -186,10 +203,14 @@ def main(args):
 
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
-    lr_scheduler = get_scheduler("constant", optimizer=optimizer, num_warmup_steps=0, num_training_steps=args.max_train_steps)
+    lr_scheduler = get_scheduler(
+        "constant", optimizer=optimizer, num_warmup_steps=0, num_training_steps=args.max_train_steps
+    )
 
     # Prepare for training
-    transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(transformer, optimizer, train_dataloader, lr_scheduler)
+    transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        transformer, optimizer, train_dataloader, lr_scheduler
+    )
 
     # Register save/load hooks
     def unwrap_model(model):
@@ -201,14 +222,18 @@ def save_model_hook(models, weights, output_dir):
             for model in models:
                 if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
                     lora_layers = get_peft_model_state_dict(unwrap_model(model))
-                    FluxPipeline.save_lora_weights(output_dir, transformer_lora_layers=lora_layers, text_encoder_lora_layers=None)
+                    FluxPipeline.save_lora_weights(
+                        output_dir, transformer_lora_layers=lora_layers, text_encoder_lora_layers=None
+                    )
             weights.pop() if weights else None
 
     accelerator.register_save_state_pre_hook(save_model_hook)
     cast_training_params([transformer], dtype=torch.float32) if args.mixed_precision == "fp16" else None
 
     # Initialize tracking
-    accelerator.init_trackers("dreambooth-flux-dev-lora-alphonse-mucha", config=vars(args)) if accelerator.is_main_process else None
+    accelerator.init_trackers(
+        "dreambooth-flux-dev-lora-alphonse-mucha", config=vars(args)
+    ) if accelerator.is_main_process else None
 
     # Training loop
     def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
@@ -233,8 +258,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
 
                 # Prepare inputs
                 latent_image_ids = FluxPipeline._prepare_latent_image_ids(
-                    model_input.shape[0], model_input.shape[2] // 2, model_input.shape[3] // 2,
-                    accelerator.device, torch.float16
+                    model_input.shape[0],
+                    model_input.shape[2] // 2,
+                    model_input.shape[3] // 2,
+                    accelerator.device,
+                    torch.float16,
                 )
 
                 noise = torch.randn_like(model_input)
@@ -248,12 +276,19 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
 
                 packed_noisy_model_input = FluxPipeline._pack_latents(
-                    noisy_model_input, model_input.shape[0], model_input.shape[1],
-                    model_input.shape[2], model_input.shape[3]
+                    noisy_model_input,
+                    model_input.shape[0],
+                    model_input.shape[1],
+                    model_input.shape[2],
+                    model_input.shape[3],
                 )
 
                 # Forward pass
-                guidance = torch.tensor([args.guidance_scale], device=accelerator.device).expand(bsz) if unwrap_model(transformer).config.guidance_embeds else None
+                guidance = (
+                    torch.tensor([args.guidance_scale], device=accelerator.device).expand(bsz)
+                    if unwrap_model(transformer).config.guidance_embeds
+                    else None
+                )
 
                 model_pred = transformer(
                     hidden_states=packed_noisy_model_input,
@@ -268,14 +303,18 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
 
                 vae_scale_factor = 2 ** (len(vae_config.block_out_channels) - 1)
                 model_pred = FluxPipeline._unpack_latents(
-                    model_pred, model_input.shape[2] * vae_scale_factor,
-                    model_input.shape[3] * vae_scale_factor, vae_scale_factor
+                    model_pred,
+                    model_input.shape[2] * vae_scale_factor,
+                    model_input.shape[3] * vae_scale_factor,
+                    vae_scale_factor,
                 )
 
                 # Compute loss
                 weighting = compute_loss_weighting_for_sd3("none", sigmas)
                 target = noise - model_input
-                loss = torch.mean((weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), 1).mean()
+                loss = torch.mean(
+                    (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), 1
+                ).mean()
 
                 accelerator.backward(loss)
 
@@ -291,7 +330,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 global_step += 1
 
                 # Checkpointing
-                if global_step % args.checkpointing_steps == 0 and (accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED):
+                if global_step % args.checkpointing_steps == 0 and (
+                    accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED
+                ):
                     save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                     accelerator.save_state(save_path)
 
@@ -307,7 +348,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         transformer_lora_layers = get_peft_model_state_dict(unwrap_model(transformer))
-        FluxPipeline.save_lora_weights(args.output_dir, transformer_lora_layers=transformer_lora_layers, text_encoder_lora_layers=None)
+        FluxPipeline.save_lora_weights(
+            args.output_dir, transformer_lora_layers=transformer_lora_layers, text_encoder_lora_layers=None
+        )
 
         if torch.cuda.is_available():
             print(f"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB")
@@ -316,7 +359,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
 
     accelerator.end_training()
 
+
 if __name__ == "__main__":
+
     class Args:
         pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev"
         data_df_path = "embeddings_alphonse_mucha.parquet"
@@ -335,4 +380,4 @@ class Args:
         seed = 0
         checkpointing_steps = 100
 
-    main(Args())
\ No newline at end of file
+    main(Args())

From 5648aaac3f11f4212bb8b5ec742db489ce3b512f Mon Sep 17 00:00:00 2001
From: DerekLiu35 <91234588+DerekLiu35@users.noreply.github.com>
Date: Thu, 19 Jun 2025 09:42:29 +0200
Subject: [PATCH 3/7] Apply suggestions from code review

Co-authored-by: Sayak Paul
---
 .../flux_lora_quantization/train_dreambooth_lora_flux_nano.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_nano.py b/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_nano.py
index 9058505d70da..f2c1d3a0b6fb 100644
--- a/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_nano.py
+++ b/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_nano.py
@@ -194,6 +194,7 @@ def main(args):
             pixel_values = batch["pixel_values"].to(accelerator.device, dtype=torch.float16)
             latents_cache.append(vae.encode(pixel_values).latent_dist)
 
+    vae.cpu()
     del vae
     free_memory()
 

From ae6bd6190d749b8f2bcf86ac357b64a9f0b779e2 Mon Sep 17 00:00:00 2001
From: DerekLiu35
Date: Thu, 19 Jun 2025 11:06:51 +0200
Subject: [PATCH 4/7] apply suggestions from code review

---
 .../train_dreambooth_lora_flux_nano.py        | 61 +++++++++++--------
 1 file changed, 35 insertions(+), 26 deletions(-)

diff --git a/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_nano.py b/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_nano.py
index f2c1d3a0b6fb..d3db13ed6e59 100644
--- a/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_nano.py
+++ b/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_nano.py
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+
 import copy
 import logging
 import math
@@ -8,7 +23,7 @@
 import pandas as pd
 import torch
 import transformers
-from accelerate import Accelerator, DistributedType
+from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from datasets import load_dataset
@@ -170,6 +185,12 @@ def main(args):
         f"trainable params: {transformer.num_parameters(only_trainable=True)} || all params: {transformer.num_parameters()}"
     )
 
+    # Make sure the trainable params are in float32.
+    if args.mixed_precision == "fp16":
+        models = [transformer]
+        # only upcast trainable parameters (LoRA) into fp32
+        cast_training_params(models, dtype=torch.float32)
+
     # Setup optimizer
     import bitsandbytes as bnb
 
@@ -213,24 +234,10 @@ def main(args):
         transformer, optimizer, train_dataloader, lr_scheduler
     )
 
-    # Register save/load hooks
     def unwrap_model(model):
         model = accelerator.unwrap_model(model)
         return model._orig_mod if is_compiled_module(model) else model
 
-    def save_model_hook(models, weights, output_dir):
-        if accelerator.is_main_process:
-            for model in models:
-                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
-                    lora_layers = get_peft_model_state_dict(unwrap_model(model))
-                    FluxPipeline.save_lora_weights(
-                        output_dir, transformer_lora_layers=lora_layers, text_encoder_lora_layers=None
-                    )
-            weights.pop() if weights else None
-
-    accelerator.register_save_state_pre_hook(save_model_hook)
-    cast_training_params([transformer], dtype=torch.float32) if args.mixed_precision == "fp16" else None
-
     # Initialize tracking
     accelerator.init_trackers(
         "dreambooth-flux-dev-lora-alphonse-mucha", config=vars(args)
@@ -269,7 +276,13 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 noise = torch.randn_like(model_input)
                 bsz = model_input.shape[0]
 
-                u = compute_density_for_timestep_sampling("none", bsz, 0.0, 1.0, 1.29)
+                u = compute_density_for_timestep_sampling(
+                    weighting_scheme=args.weighting_scheme,
+                    batch_size=bsz,
+                    logit_mean=args.logit_mean,
+                    logit_std=args.logit_std,
+                    mode_scale=args.mode_scale,
+                )
                 indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
                 timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
 
@@ -311,7 +324,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 )
 
                 # Compute loss
-                weighting = compute_loss_weighting_for_sd3("none", sigmas)
+                weighting = compute_loss_weighting_for_sd3(args.weighting_scheme, sigmas)
                 target = noise - model_input
                 loss = torch.mean(
                     (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), 1
@@ -330,13 +343,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 progress_bar.update(1)
                 global_step += 1
 
-                # Checkpointing
-                if global_step % args.checkpointing_steps == 0 and (
-                    accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED
-                ):
-                    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
-                    accelerator.save_state(save_path)
-
             # Logging
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
@@ -365,10 +371,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
 
     class Args:
         pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev"
-        data_df_path = "embeddings_alphonse_mucha.parquet"
+        data_df_path = "embeddings_alphonse_mucha.parquet"  # first, run compute_embeddings.py with a dataset like https://huggingface.co/datasets/derekl35/alphonse-mucha-style
         output_dir = "alphonse_mucha_lora_flux_nf4"
         mixed_precision = "fp16"
-        weighting_scheme = "none"
+        weighting_scheme = "none"  # "sigma_sqrt", "logit_normal", "mode", "cosmap", "none"
         width, height = 512, 768
         train_batch_size = 1
         learning_rate = 1e-4
@@ -380,5 +386,8 @@ class Args:
         max_train_steps = 700
         seed = 0
         checkpointing_steps = 100
+        logit_mean = 0.0
+        logit_std = 1.0
+        mode_scale = 1.29
 
     main(Args())

From ba5144bb6344633c266a5c1852bc107c87db15af Mon Sep 17 00:00:00 2001
From: DerekLiu35
Date: Thu, 19 Jun 2025 11:44:57 +0200
Subject: [PATCH 5/7] remove unused arg

---
 .../flux_lora_quantization/train_dreambooth_lora_flux_nano.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_nano.py b/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_nano.py
index d3db13ed6e59..be29b920dc69 100644
--- a/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_nano.py
+++ b/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_nano.py
@@ -385,7 +385,6 @@ class Args:
         rank = 4
         max_train_steps = 700
         seed = 0
-        checkpointing_steps = 100
         logit_mean = 0.0
         logit_std = 1.0
         mode_scale = 1.29

From 4a5f73a4d320190b7ab8f64201ef185430b5bb4f Mon Sep 17 00:00:00 2001
From: DerekLiu35
Date: Thu, 19 Jun 2025 12:51:20 +0200
Subject: [PATCH 6/7] init tracker while creating the output folder

---
 .../train_dreambooth_lora_flux_nano.py        | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_nano.py b/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_nano.py
index be29b920dc69..43c93626ea9b 100644
--- a/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_nano.py
+++ b/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_nano.py
@@ -146,8 +146,12 @@ def main(args):
         diffusers.utils.logging.set_verbosity_error()
 
     set_seed(args.seed) if args.seed is not None else None
-    os.makedirs(args.output_dir, exist_ok=True) if accelerator.is_main_process else None
+    if accelerator.is_main_process:
+        if args.output_dir is not None:
+            os.makedirs(args.output_dir, exist_ok=True)
+        accelerator.init_trackers("dreambooth-flux-dev-lora-alphonse-mucha", config=vars(args))
+
 
     # Load models with quantization
     noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="scheduler"

From dc7932e1049b172f46c237ef49f340973a32a1e2 Mon Sep 17 00:00:00 2001
From: DerekLiu35
Date: Thu, 19 Jun 2025 14:30:13 +0200
Subject: [PATCH 7/7] fix style

---
 .../flux_lora_quantization/train_dreambooth_lora_flux_nano.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_nano.py b/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_nano.py
index 43c93626ea9b..f7e3d26dd8c0 100644
--- a/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_nano.py
+++ b/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_nano.py
@@ -151,7 +151,7 @@ def main(args):
         if args.output_dir is not None:
             os.makedirs(args.output_dir, exist_ok=True)
         accelerator.init_trackers("dreambooth-flux-dev-lora-alphonse-mucha", config=vars(args))
-
+
     # Load models with quantization
     noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="scheduler"
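
For reference, once the series is applied and training has written the adapter to Args.output_dir ("alphonse_mucha_lora_flux_nf4"), the LoRA can be exercised roughly as sketched below. This snippet is illustrative and not part of the PR: it assumes the Args defaults above, the stock FluxPipeline LoRA-loading API, and a placeholder prompt/output filename.

import torch
from diffusers import BitsAndBytesConfig, FluxPipeline, FluxTransformer2DModel

# Re-quantize the base transformer to NF4, mirroring what the training script does.
nf4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16)
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=nf4_config,
    torch_dtype=torch.float16,
)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.float16
)
# Attach the adapter saved by FluxPipeline.save_lora_weights at the end of training.
pipe.load_lora_weights("alphonse_mucha_lora_flux_nf4")
pipe.enable_model_cpu_offload()  # keeps peak VRAM low, in line with the script's memory budget

# Placeholder prompt; use whatever style wording the cached embeddings were computed with.
image = pipe("a portrait of a woman, alphonse mucha style", num_inference_steps=28, guidance_scale=3.5).images[0]
image.save("mucha_lora_sample.png")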