From edc1619987933320ea156dd6a3f1d4abf3e18ab6 Mon Sep 17 00:00:00 2001
From: hlky <hlky@hlky.ac>
Date: Tue, 24 Sep 2024 01:42:23 +0100
Subject: [PATCH] Add beta sigmas / beta noise schedule

---
 .../schedulers/scheduling_euler_discrete.py   | 55 ++++++++++++++++++-
 1 file changed, 52 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py
index e79dbe3fe8ab..5c39583356ad 100644
--- a/src/diffusers/schedulers/scheduling_euler_discrete.py
+++ b/src/diffusers/schedulers/scheduling_euler_discrete.py
@@ -20,11 +20,14 @@
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, logging
+from ..utils import BaseOutput, is_scipy_available, logging
 from ..utils.torch_utils import randn_tensor
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
 
 
+if is_scipy_available():
+    import scipy.stats
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
@@ -160,6 +163,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             the sigmas are determined according to a sequence of noise levels {σi}.
         use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
+        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
+            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         timestep_spacing (`str`, defaults to `"linspace"`):
             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
@@ -189,6 +195,7 @@ def __init__(
         interpolation_type: str = "linear",
         use_karras_sigmas: Optional[bool] = False,
         use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
         sigma_min: Optional[float] = None,
         sigma_max: Optional[float] = None,
         timestep_spacing: str = "linspace",
@@ -197,8 +204,12 @@ def __init__(
         rescale_betas_zero_snr: bool = False,
         final_sigmas_type: str = "zero",  # can be "zero" or "sigma_min"
     ):
-        if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
-            raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
@@ -241,6 +252,7 @@ def __init__(
         self.is_scale_input_called = False
         self.use_karras_sigmas = use_karras_sigmas
         self.use_exponential_sigmas = use_exponential_sigmas
+        self.use_beta_sigmas = use_beta_sigmas
 
         self._step_index = None
         self._begin_index = None
@@ -340,6 +352,8 @@ def set_timesteps(
             raise ValueError("Cannot set `timesteps` with `config.use_karras_sigmas = True`.")
         if timesteps is not None and self.config.use_exponential_sigmas:
             raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
+        if timesteps is not None and self.config.use_beta_sigmas:
+            raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")
         if (
             timesteps is not None
             and self.config.timestep_type == "continuous"
@@ -408,6 +422,10 @@ def set_timesteps(
                 sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
                 timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
 
+            elif self.config.use_beta_sigmas:
+                sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+                timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+
             if self.config.final_sigmas_type == "sigma_min":
                 sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
             elif self.config.final_sigmas_type == "zero":
@@ -502,6 +520,37 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
         sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
         return sigmas
 
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = torch.Tensor(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     def index_for_timestep(self, timestep, schedule_timesteps=None):
         if schedule_timesteps is None:
             schedule_timesteps = self.timesteps