 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders.single_file_model import FromOriginalModelMixin
+from ...utils import deprecate
 from ...utils.accelerate_utils import apply_forward_hook
 from ..attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
@@ -245,6 +246,18 @@ def set_default_attn_processor(self):
         self.set_attn_processor(processor)
 
+    def _encode(self, x: torch.Tensor) -> torch.Tensor:
+        batch_size, num_channels, height, width = x.shape
+
+        if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
+            return self._tiled_encode(x)
+
+        enc = self.encoder(x)
+        if self.quant_conv is not None:
+            enc = self.quant_conv(enc)
+
+        return enc
+
     @apply_forward_hook
     def encode(
         self, x: torch.Tensor, return_dict: bool = True
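For context, a hedged usage sketch of what the new `_encode` dispatch means for callers (the checkpoint name and input sizes below are illustrative assumptions, not part of this PR): with tiling enabled, inputs larger than `tile_sample_min_size` take the tiled path, while smaller inputs keep the plain `encoder` + optional `quant_conv` path.

```python
# Illustrative only: assumes the public AutoencoderKL API; checkpoint and sizes are made up.
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
vae.enable_tiling()  # sets use_tiling, so _encode can defer to _tiled_encode

small = torch.randn(1, 3, 256, 256)    # at or below tile_sample_min_size: single encoder pass
large = torch.randn(1, 3, 1024, 1024)  # above tile_sample_min_size: tiled path

with torch.no_grad():
    latents = vae.encode(large).latent_dist.sample()
```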
@@ -261,21 +274,13 @@ def encode(
                 The latent representations of the encoded images. If `return_dict` is True, a
                 [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
         """
-        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
-            return self.tiled_encode(x, return_dict=return_dict)
-
         if self.use_slicing and x.shape[0] > 1:
-            encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
+            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
             h = torch.cat(encoded_slices)
         else:
-            h = self.encoder(x)
-
-        if self.quant_conv is not None:
-            moments = self.quant_conv(h)
-        else:
-            moments = h
+            h = self._encode(x)
 
-        posterior = DiagonalGaussianDistribution(moments)
+        posterior = DiagonalGaussianDistribution(h)
 
         if not return_dict:
             return (posterior,)
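Because the slicing branch now goes through `_encode` per sample, slicing and tiling compose: each element of a large batch can itself be encoded tile by tile before the results are concatenated and wrapped in a single `DiagonalGaussianDistribution`. A sketch under the same illustrative assumptions as above:

```python
# Illustrative only: slicing + tiling together with the refactored encode path.
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
vae.enable_slicing()  # encode batch elements one at a time via _encode
vae.enable_tiling()   # each element is tiled if it exceeds tile_sample_min_size

batch = torch.randn(4, 3, 1024, 1024)
with torch.no_grad():
    posterior = vae.encode(batch).latent_dist  # per-slice encodings are concatenated first
latents = posterior.sample()
```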
@@ -337,6 +342,54 @@ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.
             b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
         return b
 
+    def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
+        r"""Encode a batch of images using a tiled encoder.
+
+        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
+        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
+        different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
+        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
+        output, but they should be much less noticeable.
+
+        Args:
+            x (`torch.Tensor`): Input batch of images.
+
+        Returns:
+            `torch.Tensor`:
+                The latent representation of the encoded images.
+        """
+
+        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
+        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
+        row_limit = self.tile_latent_min_size - blend_extent
+
+        # Split the image into 512x512 tiles and encode them separately.
+        rows = []
+        for i in range(0, x.shape[2], overlap_size):
+            row = []
+            for j in range(0, x.shape[3], overlap_size):
+                tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
+                tile = self.encoder(tile)
+                if self.config.use_quant_conv:
+                    tile = self.quant_conv(tile)
+                row.append(tile)
+            rows.append(row)
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_extent)
+                result_row.append(tile[:, :, :row_limit, :row_limit])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        enc = torch.cat(result_rows, dim=2)
+        return enc
+
     def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
         r"""Encode a batch of images using a tiled encoder.
 
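To make the tiling constants above concrete, here is a worked example with assumed (but typical) values of `tile_sample_min_size=512`, `tile_latent_min_size=64` (8x spatial downsampling), and `tile_overlap_factor=0.25`; these numbers are illustrative, not read from any config:

```python
# Worked example of the overlap/blend arithmetic in _tiled_encode (values assumed).
tile_sample_min_size = 512
tile_latent_min_size = 64
tile_overlap_factor = 0.25

overlap_size = int(tile_sample_min_size * (1 - tile_overlap_factor))  # 384: pixel stride between tiles
blend_extent = int(tile_latent_min_size * tile_overlap_factor)        # 16: latent rows/cols blended
row_limit = tile_latent_min_size - blend_extent                       # 48: latent rows/cols kept per tile

# A 384-pixel stride is 384 // 8 = 48 latent positions, which equals row_limit,
# so the cropped tiles concatenate into a seamless latent grid.
print(overlap_size, blend_extent, row_limit)  # 384 16 48
```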
@@ -356,6 +409,13 @@ def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> Autoencoder
                 If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
                 `tuple` is returned.
         """
+        deprecation_message = (
+            "The tiled_encode implementation supporting the `return_dict` parameter is deprecated. In the future, the "
+            "implementation of this method will be replaced with that of `_tiled_encode` and you will no longer be able "
+            "to pass `return_dict`. You will also have to create a `DiagonalGaussianDistribution()` from the returned value."
+        )
+        deprecate("tiled_encode", "1.0.0", deprecation_message, standard_warn=False)
+
         overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
         blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
         row_limit = self.tile_latent_min_size - blend_extent
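Per the deprecation message, downstream code that calls `tiled_encode(..., return_dict=...)` directly will eventually need to switch to `_tiled_encode` and build the posterior itself; most users can simply keep calling `encode()`. A migration sketch (the checkpoint is an assumption, and the `DiagonalGaussianDistribution` import path varies across diffusers versions):

```python
# Migration sketch for the tiled_encode deprecation (illustrative, not from this PR).
import torch
from diffusers import AutoencoderKL
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution  # path may differ by version

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
vae.enable_tiling()
x = torch.randn(1, 3, 1024, 1024)

with torch.no_grad():
    # Deprecated pattern: tiled_encode returns an AutoencoderKLOutput wrapping the posterior.
    posterior_old = vae.tiled_encode(x, return_dict=True).latent_dist

    # Future pattern per the deprecation message: _tiled_encode returns raw moments,
    # and the caller wraps them in a DiagonalGaussianDistribution.
    moments = vae._tiled_encode(x)
    posterior_new = DiagonalGaussianDistribution(moments)

latents = posterior_new.sample()
```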