 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, logging
+from ..utils import BaseOutput, is_scipy_available, logging
 from ..utils.torch_utils import randn_tensor
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
 
 
+if is_scipy_available():
+    import scipy.stats
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
@@ -160,6 +163,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             the sigmas are determined according to a sequence of noise levels {σi}.
         use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
+        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
+            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         timestep_spacing (`str`, defaults to `"linspace"`):
             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
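For intuition on the new option: the beta schedule maps evenly spaced quantiles through the inverse CDF (`ppf`) of a Beta(α, β) distribution and rescales the result into `[sigma_min, sigma_max]`. A minimal standalone sketch of that mapping (the sigma bounds here are illustrative; the scheduler derives real values from its sigma table):

```python
import numpy as np
import scipy.stats

# Illustrative values; the scheduler reads these from its config / sigma table.
sigma_min, sigma_max = 0.03, 14.6
num_inference_steps, alpha, beta = 10, 0.6, 0.6

# Evenly spaced quantiles (descending), mapped through the beta inverse CDF.
quantiles = 1 - np.linspace(0, 1, num_inference_steps)
ppf = scipy.stats.beta.ppf(quantiles, alpha, beta)
sigmas = sigma_min + ppf * (sigma_max - sigma_min)
print(sigmas)  # descends from sigma_max to sigma_min; clusters near both ends when alpha = beta < 1
```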
@@ -189,6 +195,7 @@ def __init__(
         interpolation_type: str = "linear",
         use_karras_sigmas: Optional[bool] = False,
         use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
         sigma_min: Optional[float] = None,
         sigma_max: Optional[float] = None,
         timestep_spacing: str = "linspace",
@@ -197,8 +204,12 @@ def __init__(
         rescale_betas_zero_snr: bool = False,
         final_sigmas_type: str = "zero",  # can be "zero" or "sigma_min"
     ):
-        if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
-            raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
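A quick sanity check of the tightened guard (a sketch, assuming a `diffusers` build that includes this change): requesting more than one sigma schedule at construction time should now raise.

```python
from diffusers import EulerDiscreteScheduler

try:
    # Two mutually exclusive sigma-schedule flags at once.
    EulerDiscreteScheduler(use_beta_sigmas=True, use_karras_sigmas=True)
except ValueError as err:
    print(err)  # Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, ...
```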
@@ -241,6 +252,7 @@ def __init__(
         self.is_scale_input_called = False
         self.use_karras_sigmas = use_karras_sigmas
         self.use_exponential_sigmas = use_exponential_sigmas
+        self.use_beta_sigmas = use_beta_sigmas
 
         self._step_index = None
         self._begin_index = None
@@ -340,6 +352,8 @@ def set_timesteps(
             raise ValueError("Cannot set `timesteps` with `config.use_karras_sigmas = True`.")
         if timesteps is not None and self.config.use_exponential_sigmas:
             raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
+        if timesteps is not None and self.config.use_beta_sigmas:
+            raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")
         if (
             timesteps is not None
             and self.config.timestep_type == "continuous"
@@ -408,6 +422,10 @@ def set_timesteps(
             sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
 
+        elif self.config.use_beta_sigmas:
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+
         if self.config.final_sigmas_type == "sigma_min":
             sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
         elif self.config.final_sigmas_type == "zero":
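After converting sigmas, the branch recovers (fractional) timesteps via `_sigma_to_t`, which interpolates in log-sigma space between the two bracketing training sigmas. A simplified sketch of that mapping (mirroring the existing helper, not a verbatim copy):

```python
import numpy as np

def sigma_to_t(sigma, log_sigmas):
    # Locate the two training log-sigmas that bracket log(sigma) and interpolate.
    log_sigma = np.log(np.maximum(sigma, 1e-10))
    dists = log_sigma - log_sigmas[:, np.newaxis]
    low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
    high_idx = low_idx + 1
    low, high = log_sigmas[low_idx], log_sigmas[high_idx]
    w = np.clip((low - log_sigma) / (low - high), 0, 1)
    # Blend the two bracketing indices into a fractional timestep.
    return (1 - w) * low_idx + w * high_idx
```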
@@ -502,6 +520,37 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
         sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
         return sigmas
 
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et al., 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = torch.Tensor(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
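End to end, the new flag is exercised like the existing Karras/exponential options. A hedged usage sketch (scipy must be installed; all other settings left at their defaults):

```python
from diffusers import EulerDiscreteScheduler

# Opt into beta-distributed sigma spacing at construction time.
scheduler = EulerDiscreteScheduler(use_beta_sigmas=True)
scheduler.set_timesteps(num_inference_steps=10)
print(scheduler.sigmas)  # beta-spaced, with a trailing 0.0 since final_sigmas_type defaults to "zero"
```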
     def index_for_timestep(self, timestep, schedule_timesteps=None):
         if schedule_timesteps is None:
             schedule_timesteps = self.timesteps