
[research_projects] add shortened flux training script with quantization #11743


Open
wants to merge 8 commits into main

Conversation

DerekLiu35

Adds a shortened script to be referenced in this blog post: huggingface/blog#2888.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@DerekLiu35
Author

@sayakpaul

if __name__ == "__main__":
    class Args:
        pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev"
        data_df_path = "embeddings_alphonse_mucha.parquet"
Member

Add a note on where this is coming from.
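One way to address this, as a rough sketch only (the provenance comment below is an assumption about how the parquet was produced and should be confirmed by the author):

class Args:
    pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev"
    # NOTE (assumption): precomputed embeddings for the Alphonse Mucha dataset,
    # generated ahead of training as described in the accompanying blog post.
    # Document the exact script/notebook that produces this file.
    data_df_path = "embeddings_alphonse_mucha.parquet"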

)

# Compute loss
weighting = compute_loss_weighting_for_sd3("none", sigmas)
Member

We have the weighting_scheme arg, so let's use it from there.
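A minimal sketch of that change, assuming args.weighting_scheme and sigmas are defined in the surrounding script:

from diffusers.training_utils import compute_loss_weighting_for_sd3

# Take the weighting scheme from the parsed CLI args instead of hard-coding "none".
weighting = compute_loss_weighting_for_sd3(
    weighting_scheme=args.weighting_scheme, sigmas=sigmas
)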

noise = torch.randn_like(model_input)
bsz = model_input.shape[0]

u = compute_density_for_timestep_sampling("none", bsz, 0.0, 1.0, 1.29)
@sayakpaul (Member) Jun 19, 2025

Use args.weighting_scheme here as well.

Let's also make constants for the magic numbers.
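A sketch of both suggestions; the constant names are illustrative and the values simply mirror the hard-coded ones above:

from diffusers.training_utils import compute_density_for_timestep_sampling

# Named constants for the previously hard-coded sampling parameters.
LOGIT_MEAN = 0.0
LOGIT_STD = 1.0
MODE_SCALE = 1.29

u = compute_density_for_timestep_sampling(
    weighting_scheme=args.weighting_scheme,
    batch_size=bsz,
    logit_mean=LOGIT_MEAN,
    logit_std=LOGIT_STD,
    mode_scale=MODE_SCALE,
)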

@sayakpaul left a comment (Member)

Thanks! Left some comments.

    model = accelerator.unwrap_model(model)
    return model._orig_mod if is_compiled_module(model) else model

def save_model_hook(models, weights, output_dir):
Member

If we don't have a load model hook, then I don't think it makes sense to have this either, no? Or do we have a utility to resume from a checkpoint in this script?

Author

I kept save_model_hook to save intermediate checkpoints, but I probably didn't need to save the optimizer states too. Though, yeah, I think adding back a load model hook to resume from checkpoints is a good idea (see the sketch below).
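For reference, a minimal sketch of such a load hook, modeled loosely on the full train_dreambooth_lora_flux.py script; transformer, unwrap_model, and accelerator are assumed to exist in the surrounding script:

from peft.utils import set_peft_model_state_dict
from diffusers import FluxPipeline
from diffusers.utils import convert_unet_state_dict_to_peft

def load_model_hook(models, input_dir):
    transformer_ = None
    while len(models) > 0:
        model = models.pop()
        if isinstance(model, type(unwrap_model(transformer))):
            transformer_ = model
        else:
            raise ValueError(f"unexpected save model: {model.__class__}")

    # Load the LoRA weights written by save_model_hook and push them back
    # into the PEFT adapter on the transformer.
    lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
    transformer_state_dict = {
        k.replace("transformer.", ""): v
        for k, v in lora_state_dict.items()
        if k.startswith("transformer.")
    }
    transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
    set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")

accelerator.register_load_state_pre_hook(load_model_hook)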

Member

Yeah, either we don't support intermediate checkpoints at all or we support them fully. I think it's okay to go without them, to keep things minimal.

cast_training_params([transformer], dtype=torch.float32) if args.mixed_precision == "fp16" else None

# Initialize tracking
accelerator.init_trackers("dreambooth-flux-dev-lora-alphonse-mucha", config=vars(args)) if accelerator.is_main_process else None
Member

Suggested change
- accelerator.init_trackers("dreambooth-flux-dev-lora-alphonse-mucha", config=vars(args)) if accelerator.is_main_process else None
+ if accelerator.is_main_process:
+     accelerator.init_trackers("dreambooth-flux-dev-lora-alphonse-mucha", config=vars(args))

Can we do it while creating the output folder?
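A sketch of that, assuming the script exposes args.output_dir:

import os

if accelerator.is_main_process:
    # Create the output folder and start the trackers in the same
    # main-process-only block.
    os.makedirs(args.output_dir, exist_ok=True)
    accelerator.init_trackers(
        "dreambooth-flux-dev-lora-alphonse-mucha", config=vars(args)
    )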

    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
transformer.add_adapter(transformer_lora_config)
Member

Should cast the LoRA params to FP32. Do you have a full run with this script that works without FP32 upcasting?

Author

I think I was already casting to FP32 below with
cast_training_params([transformer], dtype=torch.float32) if args.mixed_precision == "fp16" else None
(I'll probably move it over here and change it to match the original training script better; see the sketch below).

I do have a full run with this script that gives reasonable results without FP32 upcasting. But I noticed the loss curves are slightly different between the nano script (rare-voice-24 run) and the original script (fanciful-totem-2), so I will need to find where the discrepancy is coming from.
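A minimal sketch of that move, placing the upcast right after add_adapter as in the original training script (assumes args.mixed_precision is defined):

import torch
from diffusers.training_utils import cast_training_params

transformer.add_adapter(transformer_lora_config)

# Upcast only the trainable LoRA parameters to FP32 when training in fp16.
if args.mixed_precision == "fp16":
    cast_training_params([transformer], dtype=torch.float32)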

Member

If it doesn't affect results, it's probably okay.

@sayakpaul
Member

@bot /style

Contributor

Style fixes have been applied. View the workflow run here.
