Commit f018acd

a-r-r-o-w, sayakpaul, and yiyixuxu committed Dec 23, 2024
[bug] Precedence of operations in VAE should be slicing -> tiling (#9342)
* bugfix: precedence of operations should be slicing -> tiling
* fix typo
* fix another typo
* deprecate current implementation of tiled_encode and use new impl
* Update src/diffusers/models/autoencoders/autoencoder_kl.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/autoencoders/autoencoder_kl.py

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
1 parent 5cef1c5 commit f018acd
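The precedence matters when both memory-saving modes are enabled at once: `encode` now consults `use_slicing` first and defers the tiling check to the per-slice `_encode`, so a batch of oversized images is split along the batch dimension before each single-image slice gets tiled. A minimal sketch of a caller exercising both paths (the checkpoint name and input size are illustrative, not part of this commit):

import torch
from diffusers import AutoencoderKL

# Illustrative checkpoint; any AutoencoderKL weights behave the same way.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
vae.enable_slicing()  # encode batch elements one at a time
vae.enable_tiling()   # tile any element larger than tile_sample_min_size

# With the fixed order, the 4-image batch is sliced first, and each
# 1-image slice is then tiled inside _encode -> _tiled_encode.
x = torch.randn(4, 3, 1024, 1024)
posterior = vae.encode(x).latent_dist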

File tree

1 file changed: +71 -11


src/diffusers/models/autoencoders/autoencoder_kl.py

@@ -18,6 +18,7 @@
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders.single_file_model import FromOriginalModelMixin
+from ...utils import deprecate
 from ...utils.accelerate_utils import apply_forward_hook
 from ..attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
@@ -245,6 +246,18 @@ def set_default_attn_processor(self):
 
         self.set_attn_processor(processor)
 
+    def _encode(self, x: torch.Tensor) -> torch.Tensor:
+        batch_size, num_channels, height, width = x.shape
+
+        if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
+            return self._tiled_encode(x)
+
+        enc = self.encoder(x)
+        if self.quant_conv is not None:
+            enc = self.quant_conv(enc)
+
+        return enc
+
     @apply_forward_hook
     def encode(
         self, x: torch.Tensor, return_dict: bool = True
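Because `_encode` operates on whatever tensor the caller hands it, the tiling decision is now made per slice whenever slicing is active. A toy sketch of the resulting dispatch order, with strings standing in for the real encoder (all names and shapes here are hypothetical):

import torch

def encode_dispatch(x: torch.Tensor, use_slicing: bool, use_tiling: bool, tile_size: int = 512):
    # Mirrors the fixed precedence: slicing first, then tiling per slice.
    def _encode(sample: torch.Tensor) -> str:
        height, width = sample.shape[-2], sample.shape[-1]
        if use_tiling and (width > tile_size or height > tile_size):
            return "tiled"
        return "plain"

    if use_slicing and x.shape[0] > 1:
        return [_encode(s) for s in x.split(1)]  # one slice per batch element
    return [_encode(x)]

# A 2-image batch of 1024px inputs is sliced, then each slice is tiled:
print(encode_dispatch(torch.randn(2, 3, 1024, 1024), use_slicing=True, use_tiling=True))
# -> ['tiled', 'tiled']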
@@ -261,21 +274,13 @@ def encode(
                 The latent representations of the encoded images. If `return_dict` is True, a
                 [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
         """
-        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
-            return self.tiled_encode(x, return_dict=return_dict)
-
         if self.use_slicing and x.shape[0] > 1:
-            encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
+            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
             h = torch.cat(encoded_slices)
         else:
-            h = self.encoder(x)
-
-        if self.quant_conv is not None:
-            moments = self.quant_conv(h)
-        else:
-            moments = h
+            h = self._encode(x)
 
-        posterior = DiagonalGaussianDistribution(moments)
+        posterior = DiagonalGaussianDistribution(h)
 
         if not return_dict:
             return (posterior,)
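After the rewrite, `encode` always wraps the moments returned by `_encode` in a `DiagonalGaussianDistribution`, whichever path produced them. Continuing the sketch above, a typical caller samples latents like this (the scaling step is the usual Stable Diffusion convention, not something this diff touches):

posterior = vae.encode(x).latent_dist   # DiagonalGaussianDistribution
z = posterior.sample()                  # stochastic latents
z = z * vae.config.scaling_factor       # conventional SD latent scaling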
@@ -337,6 +342,54 @@ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
             b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
         return b
 
+    def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
+        r"""Encode a batch of images using a tiled encoder.
+
+        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
+        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
+        different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
+        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
+        output, but they should be much less noticeable.
+
+        Args:
+            x (`torch.Tensor`): Input batch of images.
+
+        Returns:
+            `torch.Tensor`:
+                The latent representation of the encoded images.
+        """
+
+        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
+        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
+        row_limit = self.tile_latent_min_size - blend_extent
+
+        # Split the image into 512x512 tiles and encode them separately.
+        rows = []
+        for i in range(0, x.shape[2], overlap_size):
+            row = []
+            for j in range(0, x.shape[3], overlap_size):
+                tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
+                tile = self.encoder(tile)
+                if self.config.use_quant_conv:
+                    tile = self.quant_conv(tile)
+                row.append(tile)
+            rows.append(row)
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_extent)
+                result_row.append(tile[:, :, :row_limit, :row_limit])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        enc = torch.cat(result_rows, dim=2)
+        return enc
+
     def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
         r"""Encode a batch of images using a tiled encoder.
 
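The tile geometry in `_tiled_encode` hinges on three derived quantities. Plugging in the usual AutoencoderKL values (`tile_sample_min_size=512`, `tile_latent_min_size=64`, `tile_overlap_factor=0.25`; assumed defaults, the diff itself does not pin them) makes the arithmetic concrete:

tile_sample_min_size = 512   # pixel-space tile edge (assumed default)
tile_latent_min_size = 64    # latent-space tile edge (assumed default)
tile_overlap_factor = 0.25   # fraction of each tile shared with its neighbors

overlap_size = int(tile_sample_min_size * (1 - tile_overlap_factor))  # 384, stride between tile origins
blend_extent = int(tile_latent_min_size * tile_overlap_factor)        # 16, latent rows/cols blended
row_limit = tile_latent_min_size - blend_extent                       # 48, latent rows/cols kept per tile

# A 1024x1024 input yields len(range(0, 1024, 384)) = 3 tiles per axis,
# each encoded to 64x64 latents, blended over 16 and cropped to 48.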
@@ -356,6 +409,13 @@ def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
                 If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
                 `tuple` is returned.
         """
+        deprecation_message = (
+            "The tiled_encode implementation supporting the `return_dict` parameter is deprecated. In the future, the "
+            "implementation of this method will be replaced with that of `_tiled_encode` and you will no longer be able "
+            "to pass `return_dict`. You will also have to create a `DiagonalGaussianDistribution()` from the returned value."
+        )
+        deprecate("tiled_encode", "1.0.0", deprecation_message, standard_warn=False)
+
         overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
         blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
         row_limit = self.tile_latent_min_size - blend_extent
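The deprecation message spells out the migration path: once `tiled_encode` adopts the `_tiled_encode` body, callers receive raw moments and build the distribution themselves. A before/after sketch based on that message (note `_tiled_encode` is private, so this is illustrative rather than a supported API):

from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution

# Deprecated path: tiled_encode returns an AutoencoderKLOutput.
posterior = vae.tiled_encode(x, return_dict=True).latent_dist

# Future path per the deprecation message: raw moments come back and
# the caller constructs the distribution itself.
moments = vae._tiled_encode(x)
posterior = DiagonalGaussianDistribution(moments)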
