Skip to content
This repository has been archived by the owner on Feb 7, 2025. It is now read-only.

support convtranspose and activation checkpointing #415

Merged
merged 3 commits into from
Aug 11, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 46 additions & 13 deletions generative/networks/nets/autoencoderkl.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,21 +45,38 @@ class Upsample(nn.Module):
Args:
spatial_dims: number of spatial dimensions (1D, 2D, 3D).
in_channels: number of input channels to the layer.
use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder.
"""

def __init__(self, spatial_dims: int, in_channels: int) -> None:
def __init__(self, spatial_dims: int, in_channels: int, use_convtranspose: bool) -> None:
super().__init__()
self.conv = Convolution(
spatial_dims=spatial_dims,
in_channels=in_channels,
out_channels=in_channels,
strides=1,
kernel_size=3,
padding=1,
conv_only=True,
)
if use_convtranspose:
self.conv = Convolution(
spatial_dims=spatial_dims,
in_channels=in_channels,
out_channels=in_channels,
strides=2,
kernel_size=3,
padding=1,
conv_only=True,
is_transposed=True,
)
else:
self.conv = Convolution(
spatial_dims=spatial_dims,
in_channels=in_channels,
out_channels=in_channels,
strides=1,
kernel_size=3,
padding=1,
conv_only=True,
)
self.use_convtranspose = use_convtranspose

def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.use_convtranspose:
return self.conv(x)

# Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
# https://github.com/pytorch/pytorch/issues/86679
dtype = x.dtype
Expand Down Expand Up @@ -450,6 +467,7 @@ class Decoder(nn.Module):
attention_levels: indicate which level from num_channels contain an attention block.
with_nonlocal_attn: if True use non-local attention block.
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder.
"""

def __init__(
Expand All @@ -464,6 +482,7 @@ def __init__(
attention_levels: Sequence[bool],
with_nonlocal_attn: bool = True,
use_flash_attention: bool = False,
use_convtranspose: bool = False,
) -> None:
super().__init__()
self.spatial_dims = spatial_dims
Expand Down Expand Up @@ -553,7 +572,9 @@ def __init__(
)

if not is_final_block:
blocks.append(Upsample(spatial_dims=spatial_dims, in_channels=block_in_ch))
blocks.append(
Upsample(spatial_dims=spatial_dims, in_channels=block_in_ch, use_convtranspose=use_convtranspose)
)

blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True))
blocks.append(
Expand Down Expand Up @@ -595,6 +616,8 @@ class AutoencoderKL(nn.Module):
with_encoder_nonlocal_attn: if True use non-local attention block in the encoder.
with_decoder_nonlocal_attn: if True use non-local attention block in the decoder.
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
use_checkpointing: if True, use activation checkpointing to save memory.
use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder.
"""

def __init__(
Expand All @@ -611,6 +634,8 @@ def __init__(
with_encoder_nonlocal_attn: bool = True,
with_decoder_nonlocal_attn: bool = True,
use_flash_attention: bool = False,
use_checkpointing: bool = False,
use_convtranspose: bool = False,
) -> None:
super().__init__()

Expand Down Expand Up @@ -658,6 +683,7 @@ def __init__(
attention_levels=attention_levels,
with_nonlocal_attn=with_decoder_nonlocal_attn,
use_flash_attention=use_flash_attention,
use_convtranspose=use_convtranspose,
)
self.quant_conv_mu = Convolution(
spatial_dims=spatial_dims,
Expand Down Expand Up @@ -687,6 +713,7 @@ def __init__(
conv_only=True,
)
self.latent_channels = latent_channels
self.use_checkpointing = use_checkpointing

def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""
Expand All @@ -696,7 +723,10 @@ def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
x: BxCx[SPATIAL DIMS] tensor

"""
h = self.encoder(x)
if self.use_checkpointing:
h = torch.utils.checkpoint.checkpoint(self.encoder, x, use_reentrant=False)
else:
h = self.encoder(x)

z_mu = self.quant_conv_mu(h)
z_log_var = self.quant_conv_log_sigma(h)
Expand Down Expand Up @@ -747,7 +777,10 @@ def decode(self, z: torch.Tensor) -> torch.Tensor:
decoded image tensor
"""
z = self.post_quant_conv(z)
dec = self.decoder(z)
if self.use_checkpointing:
dec = torch.utils.checkpoint.checkpoint(self.decoder, z, use_reentrant=False)
else:
dec = self.decoder(z)
return dec

def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
Expand Down
47 changes: 47 additions & 0 deletions tests/test_autoencoderkl.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,18 @@ def test_shape(self, input_param, input_shape, expected_shape, expected_latent_s
self.assertEqual(result[1].shape, expected_latent_shape)
self.assertEqual(result[2].shape, expected_latent_shape)

@parameterized.expand(CASES)
def test_shape_with_convtranspose_and_checkpointing(
    self, input_param, input_shape, expected_shape, expected_latent_shape
):
    """Check forward() output shapes with use_convtranspose and use_checkpointing enabled."""
    # Build a new dict instead of calling update(): the dicts held in CASES are reused by
    # the other tests, so mutating them in place would leak these flags into those tests.
    params = {**input_param, "use_checkpointing": True, "use_convtranspose": True}
    net = AutoencoderKL(**params).to(device)
    with eval_mode(net):
        result = net.forward(torch.randn(input_shape).to(device))
        self.assertEqual(result[0].shape, expected_shape)
        self.assertEqual(result[1].shape, expected_latent_shape)
        self.assertEqual(result[2].shape, expected_latent_shape)

# def test_script(self):
# input_param, input_shape, _, _ = CASES[0]
# net = AutoencoderKL(**input_param)
Expand Down Expand Up @@ -195,6 +207,14 @@ def test_shape_reconstruction(self):
result = net.reconstruct(torch.randn(input_shape).to(device))
self.assertEqual(result.shape, expected_shape)

def test_shape_reconstruction_with_convtranspose_and_checkpointing(self):
    """Check reconstruct() output shape with use_convtranspose and use_checkpointing enabled."""
    input_param, input_shape, expected_shape, _ = CASES[0]
    # Copy-and-extend rather than update(): CASES[0] is read by several other tests, and an
    # in-place update would change the configuration they construct the network with.
    params = {**input_param, "use_checkpointing": True, "use_convtranspose": True}
    net = AutoencoderKL(**params).to(device)
    with eval_mode(net):
        result = net.reconstruct(torch.randn(input_shape).to(device))
        self.assertEqual(result.shape, expected_shape)

def test_shape_encode(self):
input_param, input_shape, _, expected_latent_shape = CASES[0]
net = AutoencoderKL(**input_param).to(device)
Expand All @@ -203,6 +223,15 @@ def test_shape_encode(self):
self.assertEqual(result[0].shape, expected_latent_shape)
self.assertEqual(result[1].shape, expected_latent_shape)

def test_shape_encode_with_convtranspose_and_checkpointing(self):
    """Check encode() output shapes with use_convtranspose and use_checkpointing enabled."""
    input_param, input_shape, _, expected_latent_shape = CASES[0]
    # Copy-and-extend rather than update(): CASES[0] is read by several other tests, and an
    # in-place update would change the configuration they construct the network with.
    params = {**input_param, "use_checkpointing": True, "use_convtranspose": True}
    net = AutoencoderKL(**params).to(device)
    with eval_mode(net):
        result = net.encode(torch.randn(input_shape).to(device))
        self.assertEqual(result[0].shape, expected_latent_shape)
        self.assertEqual(result[1].shape, expected_latent_shape)

def test_shape_sampling(self):
input_param, _, _, expected_latent_shape = CASES[0]
net = AutoencoderKL(**input_param).to(device)
Expand All @@ -212,13 +241,31 @@ def test_shape_sampling(self):
)
self.assertEqual(result.shape, expected_latent_shape)

def test_shape_sampling_convtranspose_and_checkpointing(self):
    """Check sampling() output shape with use_convtranspose and use_checkpointing enabled."""
    input_param, _, _, expected_latent_shape = CASES[0]
    # Copy-and-extend rather than update(): CASES[0] is read by several other tests, and an
    # in-place update would change the configuration they construct the network with.
    params = {**input_param, "use_checkpointing": True, "use_convtranspose": True}
    net = AutoencoderKL(**params).to(device)
    with eval_mode(net):
        result = net.sampling(
            torch.randn(expected_latent_shape).to(device), torch.randn(expected_latent_shape).to(device)
        )
        self.assertEqual(result.shape, expected_latent_shape)

def test_shape_decode(self):
    """Decode a random tensor of the case's latent shape and compare the output shape.

    Uses CASES[0]: the first entry configures the network, the second is the shape the
    decoded tensor must have, and the fourth is the shape of the tensor fed to decode().
    """
    params, decoded_shape, _, latent_shape = CASES[0]
    model = AutoencoderKL(**params).to(device)
    with eval_mode(model):
        latent = torch.randn(latent_shape).to(device)
        decoded = model.decode(latent)
        self.assertEqual(decoded.shape, decoded_shape)

def test_shape_decode_convtranspose_and_checkpointing(self):
    """Check decode() output shape with use_convtranspose and use_checkpointing enabled."""
    input_param, expected_input_shape, _, latent_shape = CASES[0]
    # Copy-and-extend rather than update(): CASES[0] is read by several other tests, and an
    # in-place update would change the configuration they construct the network with.
    params = {**input_param, "use_checkpointing": True, "use_convtranspose": True}
    net = AutoencoderKL(**params).to(device)
    with eval_mode(net):
        result = net.decode(torch.randn(latent_shape).to(device))
        self.assertEqual(result.shape, expected_input_shape)


if __name__ == "__main__":
unittest.main()