Skip to content

Commit dc21498

Browse files
dg845patil-suraj
andauthored
Update LCMScheduler Inference Timesteps to be More Evenly Spaced (#5836)
* Change LCMScheduler.set_timesteps to pick more evenly spaced inference timesteps. * Change inference_indices implementation to better match previous behavior. * Add num_inference_steps=26 test case to test_inference_steps. * run CI --------- Co-authored-by: patil-suraj <surajp815@gmail.com>
1 parent 3303aec commit dc21498

File tree

2 files changed

+12
-7
lines changed

2 files changed

+12
-7
lines changed

src/diffusers/schedulers/scheduling_lcm.py

+11-6
Original file line numberDiff line numberDiff line change
@@ -371,10 +371,11 @@ def set_timesteps(
371371
)
372372

373373
# LCM Timesteps Setting
374-
# Currently, only linear spacing is supported.
375-
c = self.config.num_train_timesteps // original_steps
376-
# LCM Training Steps Schedule
377-
lcm_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * c - 1
374+
# The skipping step parameter k from the paper.
375+
k = self.config.num_train_timesteps // original_steps
376+
# LCM Training/Distillation Steps Schedule
377+
# Currently, only a linearly-spaced schedule is supported (same as in the LCM distillation scripts).
378+
lcm_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * k - 1
378379
skipping_step = len(lcm_origin_timesteps) // num_inference_steps
379380

380381
if skipping_step < 1:
@@ -383,9 +384,13 @@ def set_timesteps(
383384
)
384385

385386
# LCM Inference Steps Schedule
386-
timesteps = lcm_origin_timesteps[::-skipping_step][:num_inference_steps]
387+
lcm_origin_timesteps = lcm_origin_timesteps[::-1].copy()
388+
# Select (approximately) evenly spaced indices from lcm_origin_timesteps.
389+
inference_indices = np.linspace(0, len(lcm_origin_timesteps), num=num_inference_steps, endpoint=False)
390+
inference_indices = np.floor(inference_indices).astype(np.int64)
391+
timesteps = lcm_origin_timesteps[inference_indices]
387392

388-
self.timesteps = torch.from_numpy(timesteps.copy()).to(device=device, dtype=torch.long)
393+
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.long)
389394

390395
self._step_index = None
391396

tests/schedulers/test_scheduler_lcm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def test_time_indices(self):
8484

8585
def test_inference_steps(self):
8686
# Hardcoded for now
87-
for t, num_inference_steps in zip([99, 39, 19], [10, 25, 50]):
87+
for t, num_inference_steps in zip([99, 39, 39, 19], [10, 25, 26, 50]):
8888
self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)
8989

9090
# Override test_add_noise_device because the hardcoded num_inference_steps of 100 doesn't work

0 commit comments

Comments
 (0)