Commit 481d76a

Add beta sigmas / beta noise schedule

1 parent 2b5bc5b

1 file changed: +50 −5 lines

src/diffusers/schedulers/scheduling_euler_discrete.py (+50 −5)
@@ -17,10 +17,11 @@
 from typing import List, Optional, Tuple, Union

 import numpy as np
+import scipy.stats
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, logging
+from ..utils import BaseOutput, is_scipy_available, logging
 from ..utils.torch_utils import randn_tensor
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin


@@ -160,6 +161,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             the sigmas are determined according to a sequence of noise levels {σi}.
         use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
+        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to
+            [Beta Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         timestep_spacing (`str`, defaults to `"linspace"`):
             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
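For context, a minimal usage sketch of the new option (hedged: the checkpoint ID is a placeholder, and any pipeline built on EulerDiscreteScheduler should work the same way):

import torch
from diffusers import DiffusionPipeline, EulerDiscreteScheduler

# Placeholder checkpoint; substitute any Euler-based pipeline.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
# Rebuild the scheduler from its own config with the new flag enabled.
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config, use_beta_sigmas=True
)
image = pipe("a photo of an astronaut riding a horse", num_inference_steps=30).images[0]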
@@ -189,6 +193,7 @@ def __init__(
         interpolation_type: str = "linear",
         use_karras_sigmas: Optional[bool] = False,
         use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
         sigma_min: Optional[float] = None,
         sigma_max: Optional[float] = None,
         timestep_spacing: str = "linspace",
@@ -197,6 +202,12 @@ def __init__(
         rescale_betas_zero_snr: bool = False,
         final_sigmas_type: str = "zero",  # can be "zero" or "sigma_min"
     ):
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
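A quick sketch of the new constructor-time validation (assuming scipy is installed; the error messages are quoted from the diff above):

from diffusers import EulerDiscreteScheduler

# Fine: at most one special sigma schedule is enabled.
EulerDiscreteScheduler(use_beta_sigmas=True)

# Raises ValueError: only one of use_beta_sigmas / use_exponential_sigmas /
# use_karras_sigmas may be enabled at a time.
EulerDiscreteScheduler(use_beta_sigmas=True, use_karras_sigmas=True)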
@@ -239,6 +250,7 @@ def __init__(
         self.is_scale_input_called = False
         self.use_karras_sigmas = use_karras_sigmas
         self.use_exponential_sigmas = use_exponential_sigmas
+        self.use_beta_sigmas = use_beta_sigmas

         self._step_index = None
         self._begin_index = None
@@ -338,10 +350,8 @@ def set_timesteps(
             raise ValueError("Cannot set `timesteps` with `config.use_karras_sigmas = True`.")
         if timesteps is not None and self.config.use_exponential_sigmas:
             raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
-        if self.config.use_exponential_sigmas and self.config.use_karras_sigmas:
-            raise ValueError(
-                "Cannot set both `config.use_exponential_sigmas = True` and config.use_karras_sigmas = True`"
-            )
+        if timesteps is not None and self.config.use_beta_sigmas:
+            raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")
         if (
             timesteps is not None
             and self.config.timestep_type == "continuous"
@@ -410,6 +420,10 @@ def set_timesteps(
             sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

+        elif self.config.use_beta_sigmas:
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+
         if self.config.final_sigmas_type == "sigma_min":
             sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
         elif self.config.final_sigmas_type == "zero":
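To see the new branch in action, a small hedged sketch driving the scheduler directly (exact sigma values depend on the configured beta schedule; scipy must be installed):

from diffusers import EulerDiscreteScheduler

scheduler = EulerDiscreteScheduler(use_beta_sigmas=True)
scheduler.set_timesteps(num_inference_steps=10)
print(scheduler.sigmas)     # 10 beta-spaced sigmas plus a trailing 0.0 (final_sigmas_type="zero")
print(scheduler.timesteps)  # matching timesteps recovered from the sigmas via _sigma_to_t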
@@ -504,6 +518,37 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
         sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
         return sigmas

+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et al., 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = torch.FloatTensor(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     def index_for_timestep(self, timestep, schedule_timesteps=None):
         if schedule_timesteps is None:
             schedule_timesteps = self.timesteps
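For intuition, the spacing above reduces to evaluating the Beta(α, β) quantile function on a uniform grid and mapping the result onto [sigma_min, sigma_max]; a self-contained sketch (the sigma bounds below are illustrative, not taken from the scheduler):

import numpy as np
import scipy.stats

num_inference_steps = 5
alpha, beta = 0.6, 0.6             # defaults used by _convert_to_beta above
sigma_min, sigma_max = 0.03, 14.6  # illustrative values

# Beta(0.6, 0.6) has a U-shaped density, so its quantiles cluster near 0 and 1;
# mapped onto the sigma range, steps pack near both sigma_max and sigma_min,
# relative to a linear schedule.
ppfs = scipy.stats.beta.ppf(1 - np.linspace(0, 1, num_inference_steps), alpha, beta)
sigmas = sigma_min + ppfs * (sigma_max - sigma_min)
print(sigmas)  # descending from sigma_max to sigma_min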
