Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Issue was coming from the definition of SpatialPad (self.ldm_resizer)… #449

Merged
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
39 changes: 23 additions & 16 deletions generative/inferers/inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from monai.data import decollate_batch
from monai.inferers import Inferer
from monai.transforms import CenterSpatialCrop, SpatialPad
from monai.utils import optional_import
Expand Down Expand Up @@ -348,8 +349,8 @@ def __init__(
self.ldm_latent_shape = ldm_latent_shape
self.autoencoder_latent_shape = autoencoder_latent_shape
if self.ldm_latent_shape is not None:
self.ldm_resizer = SpatialPad(spatial_size=[-1] + self.ldm_latent_shape)
self.autoencoder_resizer = CenterSpatialCrop(roi_size=[-1] + self.autoencoder_latent_shape)
self.ldm_resizer = SpatialPad(spatial_size=self.ldm_latent_shape)
self.autoencoder_resizer = CenterSpatialCrop(roi_size=self.autoencoder_latent_shape)

def __call__(
self,
Expand Down Expand Up @@ -379,7 +380,7 @@ def __call__(
latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor

if self.ldm_latent_shape is not None:
latent = self.ldm_resizer(latent)
latent = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latent)], 0)

call = super().__call__
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
Expand Down Expand Up @@ -454,14 +455,15 @@ def sample(
else:
latent = outputs

if self.ldm_latent_shape is not None:
latent = self.autoencoder_resizer(latent)
latent_intermediates = [self.autoencoder_resizer(l) for l in latent_intermediates]
if self.autoencoder_latent_shape is not None:
latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
latent_intermediates = [
torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
]

decode = autoencoder_model.decode_stage_2_outputs
if isinstance(autoencoder_model, SPADEAutoencoderKL):
decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg)

image = decode(latent / self.scale_factor)

if save_intermediates:
Expand Down Expand Up @@ -521,7 +523,7 @@ def get_likelihood(
latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor

if self.ldm_latent_shape is not None:
latents = self.ldm_resizer(latents)
latents = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latents)], 0)

get_likelihood = super().get_likelihood
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
Expand Down Expand Up @@ -598,7 +600,7 @@ def __call__(

diffuse = diffusion_model
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
diffuse = partial(diffusion_model, seg = seg)
diffuse = partial(diffusion_model, seg=seg)

prediction = diffuse(
x=noisy_image,
Expand Down Expand Up @@ -746,7 +748,7 @@ def get_likelihood(

diffuse = diffusion_model
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
diffuse = partial(diffusion_model, seg = seg)
diffuse = partial(diffusion_model, seg=seg)

if mode == "concat":
noisy_image = torch.cat([noisy_image, conditioning], dim=1)
Expand Down Expand Up @@ -832,6 +834,7 @@ def get_likelihood(
else:
return total_kl


class ControlNetLatentDiffusionInferer(ControlNetDiffusionInferer):
"""
ControlNetLatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), diffusion model, controlnet,
Expand Down Expand Up @@ -861,7 +864,7 @@ def __init__(
self.ldm_latent_shape = ldm_latent_shape
self.autoencoder_latent_shape = autoencoder_latent_shape
if self.ldm_latent_shape is not None:
self.ldm_resizer = SpatialPad(spatial_size=[-1] + self.ldm_latent_shape)
self.ldm_resizer = SpatialPad(spatial_size=self.ldm_latent_shape)
self.autoencoder_resizer = CenterSpatialCrop(roi_size=[-1] + self.autoencoder_latent_shape)

def __call__(
Expand Down Expand Up @@ -896,7 +899,8 @@ def __call__(
latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor

if self.ldm_latent_shape is not None:
latent = self.ldm_resizer(latent)
latent = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latent)], 0)

if cn_cond.shape[2:] != latent.shape[2:]:
cn_cond = F.interpolate(cn_cond, latent.shape[2:])

Expand Down Expand Up @@ -985,9 +989,11 @@ def sample(
else:
latent = outputs

if self.ldm_latent_shape is not None:
latent = self.autoencoder_resizer(latent)
latent_intermediates = [self.autoencoder_resizer(l) for l in latent_intermediates]
if self.autoencoder_latent_shape is not None:
latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
latent_intermediates = [
torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
]

decode = autoencoder_model.decode_stage_2_outputs
if isinstance(autoencoder_model, SPADEAutoencoderKL):
Expand Down Expand Up @@ -1060,7 +1066,7 @@ def get_likelihood(
cn_cond = F.interpolate(cn_cond, latents.shape[2:])

if self.ldm_latent_shape is not None:
latents = self.ldm_resizer(latents)
latents = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latents)], 0)

get_likelihood = super().get_likelihood
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
Expand All @@ -1085,6 +1091,7 @@ def get_likelihood(
outputs = (outputs[0], intermediates)
return outputs


class VQVAETransformerInferer(Inferer):
"""
Class to perform inference with a VQVAE + Transformer model.
Expand Down