Explanation of the 0.18215 factor in textual_inversion? #437
That's the exact same value used in the original textual inversion code for the 'learning rate' setting: https://github.com/rinongal/textual_inversion/blob/main/configs/stable-diffusion/v1-finetune.yaml. Going by Wikipedia, it seems to be how much a weight value can shift on each batch iteration (I suspect the weights are 0 to 1 or -1 to 1), probably a scalar applied to the difference between the current weight and the assumed ideal target weight, or something along those lines.
Hey @garrett361, that comes from the original Stable Diffusion training config, cf. https://github.com/CompVis/stable-diffusion/blob/main/configs/stable-diffusion/v1-inference.yaml#L17, where the first-stage `scale_factor` is set to this value.
Hi @garrett361 @patil-suraj @CodeExplode, we introduced the scale factor in the latent diffusion paper. The goal was to handle different latent spaces (from different autoencoders, which can be scaled quite differently from images) with similar noise schedules. The scale factor rescales the latents so that the space the diffusion model operates on has approximately unit variance.
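For anyone who wants to see where the factor enters the pipeline, here is a minimal sketch (not the actual training script; the random `images` tensor is just a stand-in for a real, [-1, 1]-normalized batch) of the latents being scaled before diffusion and unscaled before decoding:

```python
# Minimal sketch of how the scale factor is applied around the VAE.
import torch
from diffusers import AutoencoderKL

scale_factor = 0.18215

vae = AutoencoderKL.from_pretrained(
    'CompVis/stable-diffusion-v1-4', subfolder='vae'
)
vae.eval()

images = torch.randn(1, 3, 512, 512).clamp(-1, 1)  # stand-in for real data

with torch.no_grad():
    # Encode and rescale so the latents have roughly unit variance.
    latents = vae.encode(images).latent_dist.sample() * scale_factor
    # ... the diffusion model is trained / sampled in this scaled space ...
    # Undo the scaling before decoding back to image space.
    reconstruction = vae.decode(latents / scale_factor).sample
```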
Perfect @rromb, yes, I was looking for the principle which led to one number versus another. (Sec. 4.3.2 and Appendices D.1 and G, for anyone looking.) To make sure I'm understanding: it sounds like you arrived at 0.18215 by estimating the standard deviation of the latents over a batch of training data and taking its reciprocal, so that the rescaled latents have roughly unit variance? And if the above is right, I'm curious whether you also tried whitening each latent individually, rather than using a single global scale for all latents? Or tried other normalization schemes?
@garrett361 Yes, your understanding is correct. We did not play much with other normalization schemes because the simple rescaling worked out of the box.
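To make the two options in this exchange concrete, here is a small illustration (my own sketch, not code from the paper or the repo) contrasting the single global scale that was used with the per-latent whitening asked about above:

```python
# Illustration only. `latents` stands in for a batch of VAE latents (N, C, H, W).
import torch

latents = torch.randn(8, 4, 64, 64) * 5.5  # stand-in batch with std far from 1

# (a) What LDM / Stable Diffusion does: one global scalar for all latents.
global_scale = 1.0 / latents.std()   # comes out near 0.18215 on real SD v1 latents
scaled = latents * global_scale      # std(scaled) ~= 1 across the whole batch

# (b) Whitening each latent individually: per-sample mean/std normalization.
mean = latents.mean(dim=(1, 2, 3), keepdim=True)
std = latents.std(dim=(1, 2, 3), keepdim=True)
whitened = (latents - mean) / std    # each sample individually ~ zero mean, unit std
```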
Hypothetically, if we were to retrain a latent diffusion model with more than one autoencoder, would you need a different scaling factor for each autoencoder to get approximately unit variance?
In case this is useful for others, I've written some code to replicate the computation of that magic value. It seems to be a reasonable estimation!

```python
from diffusers import AutoencoderKL
import torch
import torchvision
from torchvision.datasets.utils import download_and_extract_archive
from torchvision import transforms

num_workers = 4
batch_size = 12

# From https://github.com/fastai/imagenette
IMAGENETTE_URL = 'https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz'

torch.manual_seed(0)
torch.set_grad_enabled(False)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the Stable Diffusion v1-4 VAE.
pretrained_model_name_or_path = 'CompVis/stable-diffusion-v1-4'
vae = AutoencoderKL.from_pretrained(
    pretrained_model_name_or_path,
    subfolder='vae',
    revision=None,
)
vae.to(device)

# Standard SD preprocessing: 512x512 center crops, pixels scaled to [-1, 1].
size = 512
image_transform = transforms.Compose([
    transforms.Resize(size),
    transforms.CenterCrop(size),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

root = 'dataset'
download_and_extract_archive(IMAGENETTE_URL, root)
dataset = torchvision.datasets.ImageFolder(root, transform=image_transform)
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)

# Encode the dataset and collect the sampled latents.
all_latents = []
for image_data, _ in loader:
    image_data = image_data.to(device)
    latents = vae.encode(image_data).latent_dist.sample()
    all_latents.append(latents.cpu())

# The scale factor is the reciprocal of the latents' standard deviation.
all_latents_tensor = torch.cat(all_latents)
std = all_latents_tensor.std().item()
normalizer = 1 / std
print(f'{normalizer = }')
```

Output:
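As a quick sanity check (my addition, reusing the variables from the snippet above), one can also verify that scaling the collected latents by the canonical 0.18215 brings their standard deviation close to 1:

```python
# Assumes `all_latents_tensor` from the snippet above is still in scope.
scaled_std = (all_latents_tensor * 0.18215).std().item()
print(f'{scaled_std = }')  # expected to be roughly 1.0
```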
Hello, excuse me, I would like to ask about using the CelebA dataset with an AutoencoderKL model that I trained myself. I am training a 128x128-resolution AutoencoderKL and computing the scale factor the same way. Is it normal for the scale factor to come out at approximately 0.44? I still cannot reach the FID reported in the paper when training an LDM with this autoencoder.
From the perspective of latent variables, should we use ...
It appears that `image = noise_scheduler.step(model_output, t, image, generator=None).prev_sample` effectively clamps the result to [-1, 1] (when the scheduler's `clip_sample` option is enabled, as it is by default for `DDPMScheduler`). This makes it essential that your VAE encoder produces latents in that range, since if it doesn't, the decoder will be receiving LDM output that's in a different range than your VAE encoder's output.
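For reference, a small sketch (my addition, using the public diffusers API, not code from this thread) of the scheduler option in question. Latent-space models like Stable Diffusion typically disable sample clipping precisely because the scaled latents are not constrained to [-1, 1]:

```python
# The `clip_sample` flag controls whether DDPMScheduler.step clamps the
# predicted clean sample to [-clip_sample_range, clip_sample_range]
# ([-1, 1] by default). For latent diffusion the latents need not lie in
# that range, so the flag is usually turned off.
from diffusers import DDPMScheduler

pixel_space_scheduler = DDPMScheduler(num_train_timesteps=1000, clip_sample=True)
latent_space_scheduler = DDPMScheduler(num_train_timesteps=1000, clip_sample=False)

print(pixel_space_scheduler.config.clip_sample)   # True
print(latent_space_scheduler.config.clip_sample)  # False
```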
diffusers/examples/textual_inversion/textual_inversion.py, line 501 at b2b3b1a
Hi, just a small question about the quoted script above which is bothering me: where does this `0.18215` number come from? What computation is being done? Is it from some paper? I have seen the same factor elsewhere, too, without explanation. Any guidance would be very helpful, thanks!