Explanation of the 0.18215 factor in textual_inversion? #437

Closed
garrett361 opened this issue Sep 9, 2022 · 10 comments

@garrett361

latents = latents * 0.18215

Hi, just a small question about the quoted script above which is bothering me: where does this 0.18215 number come from? What computation is being done? Is it from some paper? I have seen the same factor elsewhere, too, without explanation. Any guidance would be very helpful, thanks!

@CodeExplode

CodeExplode commented Sep 9, 2022

That's the exact same value used in the original textual inversion code for the 'learning rate' setting. https://github.com/rinongal/textual_inversion/blob/main/configs/stable-diffusion/v1-finetune.yaml

Going by Wikipedia, it seems to be how much a weight value can shift on each batch iteration (I suspect the weights are 0 to 1 or -1 to 1), probably a scalar applied to the difference between the current weight and the assumed ideal target weight, or something along those lines.

@patil-suraj
Contributor

Hey @garrett361

That comes from the original Stable Diffusion training; cf. https://github.com/CompVis/stable-diffusion/blob/main/configs/stable-diffusion/v1-inference.yaml#L17

This is scale_factor which is used to scale the latents produced by the autoencoder before they are fed to the unet. Maybe @rromb can comment on why the scaling is necessary.
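
A minimal sketch of where the factor sits in the pipeline, assuming a toy "latent" that is just a flat list of floats rather than a real VAE output. The helper names `to_unet_space` and `to_vae_space` are hypothetical, and the division before decoding is an assumption about the inverse step that this comment does not spell out.

```python
# Toy stand-ins: a "latent" here is a flat list of floats, not a real VAE output.
SCALE_FACTOR = 0.18215  # value quoted in this thread


def to_unet_space(vae_latents, scale_factor=SCALE_FACTOR):
    """Scale raw autoencoder latents before they are fed to the UNet."""
    return [z * scale_factor for z in vae_latents]


def to_vae_space(unet_latents, scale_factor=SCALE_FACTOR):
    """Undo the scaling before handing latents back to the VAE decoder (assumed inverse)."""
    return [z / scale_factor for z in unet_latents]


raw = [5.0, -3.2, 0.7, 11.4]      # hypothetical encoder outputs
scaled = to_unet_space(raw)       # what the diffusion model would see
restored = to_vae_space(scaled)   # what the decoder would see

print(scaled)
print(restored)
```

The point of keeping both helpers side by side is that the encoder and decoder always see latents on the autoencoder's own scale, while the diffusion model only ever sees the rescaled ones.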

@rromb

rromb commented Sep 9, 2022

Hi @garrett361 @patil-suraj @CodeExplode

We introduced the scale factor in the latent diffusion paper. The goal was to handle different latent spaces (from different autoencoders, which can be scaled quite differently than images) with similar noise schedules. The scale_factor ensures that the initial latent space on which the diffusion model is operating has approximately unit variance. Hope this helps :)

@garrett361
Author

> The scale_factor ensures that the initial latent space on which the diffusion model is operating has approximately unit variance. Hope this helps :)

Perfect @rromb, yes, I was looking for the principle which led to one number versus another. (Sec. 4.3.2 and Appendices D.1 and G, for anyone looking.)

To make sure I'm understanding, it sounds like you arrived at scale_factor = 0.18215 by averaging over a bunch of examples generated by the vae, in order to ensure they have unit variance with the variance taken over all dimensions simultaneously? And scale_factor = 1 / std(z), schematically?

And if the above is right, I'm curious if you also tried instead whitening each latent individually, rather than using a single global scale for all latents? Or tried using LayerNorm or similar?

@rromb

rromb commented Sep 9, 2022

@garrett361 Yes, your understanding is correct. We did not play much with other normalization schemes because the simple rescaling worked out of the box.
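
The schematic recipe `scale_factor = 1 / std(z)` can be sketched with the standard library alone. This is an illustration under stated assumptions, not the authors' code: the "latents" are synthetic Gaussian draws, and the mean of 0.3 and spread of 5.5 are arbitrary choices, not measurements of any real autoencoder.

```python
import random
import statistics

random.seed(0)

# Hypothetical stand-in for encoder outputs: 20,000 latent values pooled over
# all samples and all dimensions. Mean 0.3 and spread 5.5 are arbitrary.
latents = [random.gauss(0.3, 5.5) for _ in range(20_000)]

# One global factor: the reciprocal of the std taken over everything at once.
scale_factor = 1.0 / statistics.pstdev(latents)

# After rescaling, the pooled latents have unit standard deviation.
scaled = [z * scale_factor for z in latents]
scaled_std = statistics.pstdev(scaled)

print(f"scale_factor = {scale_factor:.4f}, scaled std = {scaled_std:.4f}")
```

Note that a single multiplicative factor fixes the variance only; it does not center the latents, so any nonzero mean is rescaled but not removed.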

@ezhang7423

Hypothetically if we were to retrain a latent diffusion model with more than one autoencoder, would you need a different scaling factor for each autoencoder to get approximately unit variance?

@fepegar

fepegar commented Dec 19, 2022

In case this is useful for others, I've written some code to replicate the computation of that magic value. It seems to be a reasonable estimation!

from diffusers import AutoencoderKL
import torch
import torchvision
from torchvision.datasets.utils import download_and_extract_archive
from torchvision import transforms


num_workers = 4
batch_size = 12
# From https://github.com/fastai/imagenette
IMAGENETTE_URL = 'https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz'

# Fix the seed (the latents are sampled) and disable autograd: inference only.
torch.manual_seed(0)
torch.set_grad_enabled(False)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load only the VAE from the Stable Diffusion v1.4 checkpoint.
pretrained_model_name_or_path = 'CompVis/stable-diffusion-v1-4'
vae = AutoencoderKL.from_pretrained(
    pretrained_model_name_or_path,
    subfolder='vae',
    revision=None,
)
vae.to(device)
vae.eval()

# Resize and crop to 512x512, then map pixel values from [0, 1] to [-1, 1].
size = 512
image_transform = transforms.Compose([
    transforms.Resize(size),
    transforms.CenterCrop(size),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

root = 'dataset'
download_and_extract_archive(IMAGENETTE_URL, root)

# Class labels are ignored below; ImageFolder is only used to walk the images.
dataset = torchvision.datasets.ImageFolder(root, transform=image_transform)
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)

# Encode every image and keep one sample from each latent distribution.
all_latents = []
for image_data, _ in loader:
    image_data = image_data.to(device)
    latents = vae.encode(image_data).latent_dist.sample()
    all_latents.append(latents.cpu())

# A single std over all images, channels, and spatial positions at once.
all_latents_tensor = torch.cat(all_latents)
std = all_latents_tensor.std().item()
normalizer = 1 / std
print(f'{normalizer = }')

Output:

normalizer = 0.19503
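
To put the two numbers side by side, here is a short piece of arithmetic. It only converts each factor back to the latent standard deviation it implies and measures the relative gap of about 7%; the thread does not say what accounts for that gap, so none is assumed here.

```python
reference = 0.18215  # scale_factor from the Stable Diffusion config
estimate = 0.19503   # value printed by the script above

# Each factor is 1 / std, so inverting it recovers the implied latent std.
implied_std_reference = 1 / reference
implied_std_estimate = 1 / estimate

# Relative difference between the estimate and the reference factor.
relative_gap = (estimate - reference) / reference

print(f"implied std (reference): {implied_std_reference:.3f}")
print(f"implied std (estimate):  {implied_std_estimate:.3f}")
print(f"relative gap:            {relative_gap:.1%}")
```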

@wj7486

wj7486 commented Jan 6, 2024

> Hi @garrett361 @patil-suraj @CodeExplode
>
> We introduced the scale factor in the latent diffusion paper. The goal was to handle different latent spaces (from different autoencoders, which can be scaled quite differently than images) with similar noise schedules. The scale_factor ensures that the initial latent space on which the diffusion model is operating has approximately unit variance. Hope this helps :)

Hello, excuse me. I would like to ask about using the CelebA dataset with an AutoencoderKL model that I trained myself. I want to train a 128×128-resolution AutoencoderKL model, and I am using scale_factor. Is it normal for the scale factor to come out at approximately 0.44? I still cannot achieve the FID mentioned in the paper when training an LDM with this AutoencoderKL.
Looking forward to your reply, thank you.

@guomc9

guomc9 commented Jun 18, 2024

From the perspective of latent variables, should we use B to represent the number of samples for N (where N=H×W×C) latent variables? When calculating the standard deviation, we should standardize the N latent variables. Therefore, the observed mean and std calculated from these B samples should both have the shape [1,N]. Then, by normalizing the samples using $$\frac{\text{samples} - \text{mean}}{\text{std}}$$, can we better ensure the uniformity and fairness of the scales of all latent variables while finetuning the unet of LDM?
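
A small sketch of the per-variable standardization proposed here, contrasted with the single global factor discussed earlier in the thread. Everything in it is a toy assumption: N is shrunk to 4 so the result is easy to inspect, and the per-variable means and spreads are made up. It shows the mechanics only and makes no claim about which scheme trains better.

```python
import random
import statistics

random.seed(0)

B, N = 500, 4  # B samples, N = H*W*C latent variables (tiny N for illustration)

# Hypothetical latents: each of the N variables has its own mean and spread.
means = [0.0, 2.0, -1.0, 5.0]
stds = [0.5, 3.0, 6.0, 10.0]
samples = [[random.gauss(m, s) for m, s in zip(means, stds)] for _ in range(B)]

# Per-variable statistics over the B samples: shape [1, N] in the comment's notation.
columns = list(zip(*samples))
mean = [statistics.fmean(col) for col in columns]
std = [statistics.pstdev(col) for col in columns]

# (samples - mean) / std, applied independently to every latent variable.
standardized = [
    [(x - m) / s for x, m, s in zip(row, mean, std)] for row in samples
]
per_dim_std = [statistics.pstdev(col) for col in zip(*standardized)]

# For contrast: one global factor computed over all values at once.
flat = [x for row in samples for x in row]
global_factor = 1.0 / statistics.pstdev(flat)
globally_scaled_std = [statistics.pstdev(col) * global_factor for col in columns]

print("per-variable scheme, std per variable:", [round(s, 3) for s in per_dim_std])
print("global factor, std per variable:      ", [round(s, 3) for s in globally_scaled_std])
```

With the global factor, only the pooled variance is normalized, so individual variables can still end up with very different spreads; the per-variable scheme equalizes them at the cost of storing N means and N standard deviations.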

@jxtps

jxtps commented Feb 28, 2025

It appears that:

image = noise_scheduler.step(model_output, t, image, generator=None).prev_sample

effectively clamps the image to be in the [-1, 1] range.

This makes it essential that your VAE produces output in that range, since if it doesn't, then the decoder will be receiving LDM output that's in a different range than your VAE encoder's output.
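
To illustrate why the range matters, here is a toy check of how many latent values a [-1, 1] clamp would touch, before and after applying the 0.18215 factor. The latents are synthetic zero-mean Gaussian draws whose spread is assumed to be 1 / 0.18215, and the `clamp` helper is a hypothetical stand-in for whatever the scheduler does internally. Whether a given scheduler clamps at all depends on its configuration (in diffusers, settings such as `clip_sample`), so it is worth checking that for your own setup.

```python
import random
import statistics

random.seed(0)


def clamp(x, lo=-1.0, hi=1.0):
    """Restrict a value to the closed interval [lo, hi]."""
    return max(lo, min(hi, x))


def fraction_outside(values, lo=-1.0, hi=1.0):
    """Share of values that a clamp to [lo, hi] would modify."""
    return sum(1 for v in values if v < lo or v > hi) / len(values)


# Hypothetical raw VAE latents with a spread of about 1 / 0.18215.
raw = [random.gauss(0.0, 1 / 0.18215) for _ in range(10_000)]
scaled = [z * 0.18215 for z in raw]

print(f"raw outside [-1, 1]:    {fraction_outside(raw):.1%}")
print(f"scaled outside [-1, 1]: {fraction_outside(scaled):.1%}")

# Clamping the unscaled latents squeezes their spread to below 1.
clamped_raw = [clamp(z) for z in raw]
print(f"std of clamped raw latents: {statistics.pstdev(clamped_raw):.3f}")
```

In this toy setting, rescaling to unit variance sharply reduces the share of values outside [-1, 1] but does not eliminate it, which is consistent with the concern that the VAE's own output range needs checking as well.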
