diff --git a/cm/karras_diffusion.py b/cm/karras_diffusion.py index 87b0e18..34ecae1 100644 --- a/cm/karras_diffusion.py +++ b/cm/karras_diffusion.py @@ -187,7 +187,7 @@ def euler_solver(samples, t, next_t, x0): x_t = x_start + noise * append_dims(t, dims) - dropout_state = th.get_rng_state() + dropout_state = (th.get_rng_state(), th.cuda.get_rng_state()) distiller = denoise_fn(x_t, t) if teacher_model is None: @@ -195,7 +195,8 @@ def euler_solver(samples, t, next_t, x0): else: x_t2 = heun_solver(x_t, t, t2, x_start).detach() - th.set_rng_state(dropout_state) + th.set_rng_state(dropout_state[0]) + th.cuda.set_rng_state(dropout_state[1]) distiller_target = target_denoise_fn(x_t2, t2) distiller_target = distiller_target.detach()