-
Notifications
You must be signed in to change notification settings - Fork 96
Conversation
…olNetDiffusionInferer to clean up the code. Ran the notebook.
# Without using an inferer: | ||
# progress_bar_sampling = tqdm(scheduler.timesteps, total=len(scheduler.timesteps), ncols=110) | ||
# progress_bar_sampling.set_description("sampling...") | ||
# sample = torch.randn((1, 1, 64, 64)).to(device) | ||
# for t in progress_bar_sampling: | ||
# with torch.no_grad(): | ||
# with autocast(enabled=True): | ||
# down_block_res_samples, mid_block_res_sample = controlnet( | ||
# x=sample, timesteps=torch.Tensor((t,)).to(device).long(), controlnet_cond=masks[0, None, ...] | ||
# ) | ||
# noise_pred = model( | ||
# sample, | ||
# timesteps=torch.Tensor((t,)).to(device), | ||
# down_block_additional_residuals=down_block_res_samples, | ||
# mid_block_additional_residual=mid_block_res_sample, | ||
# ) | ||
# sample, _ = scheduler.step(model_output=noise_pred, timestep=t, sample=sample) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we want to leave this here or just point people to the definition of the inferer's sample
method?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was just leaving them there in case a user did not want to use the inferer, as it was previously like that
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think its OK to leave it there, it might help users get a bit more insight into what the inferer is actually doing under the hood
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me!
I fixed some typos in the Controlnet tutorial, and included new ControlNetDiffusionInferer functionality to make the code cleaner. The notebook's been ran again to account for modifications.
In response to: #443