Commit edc1619

Add beta sigmas / beta noise schedule
1 parent aa3c46d commit edc1619

File tree

1 file changed: +52 −3 lines


src/diffusers/schedulers/scheduling_euler_discrete.py

+52 −3
@@ -20,11 +20,14 @@
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, logging
+from ..utils import BaseOutput, is_scipy_available, logging
 from ..utils.torch_utils import randn_tensor
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin


+if is_scipy_available():
+    import scipy.stats
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
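The `is_scipy_available()` guard keeps scipy an optional dependency: `scipy.stats` is only imported when scipy is actually installed, and the constructor raises a clear `ImportError` otherwise (see the `__init__` hunk below). As a rough sketch, such availability helpers are typically built on `importlib`; this re-implementation is an assumption for illustration, not diffusers' exact code:

import importlib.util

def is_scipy_available() -> bool:
    # Hypothetical stand-in: report whether scipy can be imported.
    return importlib.util.find_spec("scipy") is not None

if is_scipy_available():
    import scipy.stats  # imported lazily so scipy stays optional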

@@ -160,6 +163,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             the sigmas are determined according to a sequence of noise levels {σi}.
         use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
+        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
+            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         timestep_spacing (`str`, defaults to `"linspace"`):
             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
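The new flag is opted into the same way as `use_karras_sigmas` or `use_exponential_sigmas`. A minimal usage sketch (the pipeline class and model id are illustrative choices, not part of this commit):

from diffusers import EulerDiscreteScheduler, StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
# Rebuild the scheduler from the existing config with beta sigmas enabled.
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, use_beta_sigmas=True)
image = pipe("an astronaut riding a horse").images[0]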
@@ -189,6 +195,7 @@ def __init__(
         interpolation_type: str = "linear",
         use_karras_sigmas: Optional[bool] = False,
         use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
         sigma_min: Optional[float] = None,
         sigma_max: Optional[float] = None,
         timestep_spacing: str = "linspace",
@@ -197,8 +204,12 @@
         rescale_betas_zero_snr: bool = False,
         final_sigmas_type: str = "zero",  # can be "zero" or "sigma_min"
     ):
-        if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
-            raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
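The sigma-schedule flags are mutually exclusive, so selecting two at once now fails fast at construction time:

from diffusers import EulerDiscreteScheduler

# Raises ValueError: only one of use_beta_sigmas, use_exponential_sigmas,
# use_karras_sigmas may be enabled at a time.
scheduler = EulerDiscreteScheduler(use_beta_sigmas=True, use_karras_sigmas=True)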
@@ -241,6 +252,7 @@ def __init__(
         self.is_scale_input_called = False
         self.use_karras_sigmas = use_karras_sigmas
         self.use_exponential_sigmas = use_exponential_sigmas
+        self.use_beta_sigmas = use_beta_sigmas

         self._step_index = None
         self._begin_index = None
@@ -340,6 +352,8 @@ def set_timesteps(
             raise ValueError("Cannot set `timesteps` with `config.use_karras_sigmas = True`.")
         if timesteps is not None and self.config.use_exponential_sigmas:
             raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
+        if timesteps is not None and self.config.use_beta_sigmas:
+            raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")
         if (
             timesteps is not None
             and self.config.timestep_type == "continuous"
@@ -408,6 +422,10 @@ def set_timesteps(
             sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

+        elif self.config.use_beta_sigmas:
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+
         if self.config.final_sigmas_type == "sigma_min":
             sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
         elif self.config.final_sigmas_type == "zero":
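With `use_beta_sigmas=True`, `set_timesteps` routes through `_convert_to_beta` just like the Karras and exponential branches, so the resulting schedule can be inspected directly. A minimal sketch (requires scipy):

from diffusers import EulerDiscreteScheduler

scheduler = EulerDiscreteScheduler(use_beta_sigmas=True)
scheduler.set_timesteps(num_inference_steps=10)
# Sigmas run from high to low, with the last value appended per `final_sigmas_type`.
print(scheduler.sigmas)
print(scheduler.timesteps)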
@@ -502,6 +520,37 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
         sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
         return sigmas

+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et al., 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = torch.Tensor(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     def index_for_timestep(self, timestep, schedule_timesteps=None):
         if schedule_timesteps is None:
             schedule_timesteps = self.timesteps
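Independent of the scheduler plumbing, the core of `_convert_to_beta` is the inverse CDF (percent point function) of a Beta(alpha, beta) distribution: with alpha = beta = 0.6 the density is U-shaped, so evenly spaced quantiles are warped to cluster sigmas near both ends of the range. A self-contained sketch (the sigma bounds are illustrative values, not the scheduler's defaults):

import numpy as np
import scipy.stats

num_inference_steps = 10
sigma_min, sigma_max = 0.03, 14.6  # illustrative bounds
alpha = beta = 0.6  # the defaults used by _convert_to_beta

# Evenly spaced quantiles in [0, 1], reversed so sigmas run high -> low.
quantiles = 1 - np.linspace(0, 1, num_inference_steps)
ppf = scipy.stats.beta.ppf(quantiles, alpha, beta)  # inverse CDF of Beta(0.6, 0.6)
sigmas = sigma_min + ppf * (sigma_max - sigma_min)
print(sigmas)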
