@@ -17,10 +17,11 @@
 from typing import List, Optional, Tuple, Union
 
 import numpy as np
+import scipy.stats
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, logging
+from ..utils import BaseOutput, is_scipy_available, logging
 from ..utils.torch_utils import randn_tensor
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
 
@@ -160,6 +161,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             the sigmas are determined according to a sequence of noise levels {σi}.
         use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
+        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to
+            [Beta Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         timestep_spacing (`str`, defaults to `"linspace"`):
             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
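For context, a minimal usage sketch (not part of the diff): switching an existing pipeline's scheduler over to beta sigmas. The checkpoint name is an arbitrary example; any pipeline built on `EulerDiscreteScheduler` would work the same way.

```python
# Sketch assuming this PR is applied and scipy is installed.
from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1")  # example checkpoint

# Rebuild the scheduler from its existing config, enabling the new flag.
# Note: `use_beta_sigmas` is mutually exclusive with `use_karras_sigmas`
# and `use_exponential_sigmas` (the constructor raises otherwise).
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, use_beta_sigmas=True)
image = pipe("a photo of an astronaut riding a horse").images[0]
```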
@@ -189,6 +193,7 @@ def __init__(
         interpolation_type: str = "linear",
         use_karras_sigmas: Optional[bool] = False,
         use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
         sigma_min: Optional[float] = None,
         sigma_max: Optional[float] = None,
         timestep_spacing: str = "linspace",
@@ -197,6 +202,12 @@ def __init__(
         rescale_betas_zero_snr: bool = False,
         final_sigmas_type: str = "zero",  # can be "zero" or "sigma_min"
     ):
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
@@ -239,6 +250,7 @@ def __init__(
         self.is_scale_input_called = False
         self.use_karras_sigmas = use_karras_sigmas
         self.use_exponential_sigmas = use_exponential_sigmas
+        self.use_beta_sigmas = use_beta_sigmas
 
         self._step_index = None
         self._begin_index = None
@@ -338,10 +350,8 @@ def set_timesteps(
             raise ValueError("Cannot set `timesteps` with `config.use_karras_sigmas = True`.")
         if timesteps is not None and self.config.use_exponential_sigmas:
             raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
-        if self.config.use_exponential_sigmas and self.config.use_karras_sigmas:
-            raise ValueError(
-                "Cannot set both `config.use_exponential_sigmas = True` and config.use_karras_sigmas = True`"
-            )
+        if timesteps is not None and self.config.use_beta_sigmas:
+            raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")
         if (
             timesteps is not None
             and self.config.timestep_type == "continuous"
@@ -410,6 +420,10 @@ def set_timesteps(
             sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
 
+        elif self.config.use_beta_sigmas:
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+
         if self.config.final_sigmas_type == "sigma_min":
             sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
         elif self.config.final_sigmas_type == "zero":
@@ -504,6 +518,37 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
         sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
         return sigmas
 
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et al., 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = torch.FloatTensor(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     def index_for_timestep(self, timestep, schedule_timesteps=None):
         if schedule_timesteps is None:
             schedule_timesteps = self.timesteps
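To see what `_convert_to_beta` produces, here is a standalone sketch of the same computation with assumed sigma bounds (in the scheduler they come from the config or from `in_sigmas`). The inverse CDF (`ppf`) of a Beta(0.6, 0.6) distribution is flat near its quantile endpoints, so the resulting sigmas cluster near both `sigma_max` and `sigma_min` and spread out in the mid-range.

```python
# Standalone sketch of the beta-sigma spacing; the sigma bounds are
# assumed values for illustration only.
import numpy as np
import scipy.stats

sigma_min, sigma_max = 0.03, 14.6  # assumed SD-style bounds
num_inference_steps = 10
alpha = beta = 0.6  # defaults used by _convert_to_beta

# Evenly spaced quantiles, reversed so sigmas run from high to low.
quantiles = 1 - np.linspace(0, 1, num_inference_steps)
ppf = scipy.stats.beta.ppf(quantiles, alpha, beta)  # inverse Beta CDF
sigmas = sigma_min + ppf * (sigma_max - sigma_min)
print(sigmas)  # dense near the endpoints, sparse in the middle
```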