@@ -207,6 +207,9 @@ def __init__(
         self.vae_scale_factor_temporal = (
             self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
         )
+        self.vae_scaling_factor_image = (
+            self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7
+        )
 
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
 
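For context, the cached attributes above follow a read-the-config-or-fall-back pattern. A minimal standalone sketch of that pattern (illustrative only; the class name is hypothetical, while the 4 and 0.7 fallbacks mirror the diff):

```python
# Illustrative sketch only -- not part of the diff.
class _CachedVAEDefaults:
    def __init__(self, vae=None):
        self.vae = vae
        # Read compression/scaling values from the VAE config when a VAE is
        # available; otherwise fall back to the CogVideoX defaults (4 and 0.7).
        self.vae_scale_factor_temporal = (
            vae.config.temporal_compression_ratio if vae is not None else 4
        )
        self.vae_scaling_factor_image = (
            vae.config.scaling_factor if vae is not None else 0.7
        )
```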
@@ -348,6 +351,12 @@ def prepare_latents(
         generator: Optional[torch.Generator] = None,
         latents: Optional[torch.Tensor] = None,
     ):
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
         num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
         shape = (
             batch_size,
@@ -357,12 +366,6 @@ def prepare_latents(
             width // self.vae_scale_factor_spatial,
         )
 
-        if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-            )
-
         image = image.unsqueeze(2)  # [B, C, F, H, W]
 
         if isinstance(generator, list):
@@ -373,7 +376,7 @@ def prepare_latents(
             image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image]
 
         image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]
-        image_latents = self.vae.config.scaling_factor * image_latents
+        image_latents = self.vae_scaling_factor_image * image_latents
 
         padding_shape = (
             batch_size,
@@ -397,7 +400,7 @@ def prepare_latents(
     # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents
     def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
         latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
-        latents = 1 / self.vae.config.scaling_factor * latents
+        latents = 1 / self.vae_scaling_factor_image * latents
 
         frames = self.vae.decode(latents).sample
         return frames
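The two hunks above keep latent scaling symmetric: image latents are multiplied by the scaling factor after VAE encoding and divided by it before VAE decoding. A minimal sketch of that round trip (hypothetical helper names, assuming the CogVideoX-style scaling factor of 0.7):

```python
import torch

def scale_image_latents(latents: torch.Tensor, scaling_factor: float = 0.7) -> torch.Tensor:
    # After VAE encoding, latents are multiplied by the scaling factor
    # (mirrors the prepare_latents change above).
    return scaling_factor * latents

def unscale_for_decode(latents: torch.Tensor, scaling_factor: float = 0.7) -> torch.Tensor:
    # Before VAE decoding, the scaling is undone (mirrors decode_latents above).
    return latents / scaling_factor

x = torch.randn(1, 13, 16, 60, 90)  # example latent shape, illustrative only
assert torch.allclose(unscale_for_decode(scale_image_latents(x)), x)
```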
@@ -438,7 +441,6 @@ def check_inputs(
         width,
         negative_prompt,
         callback_on_step_end_tensor_inputs,
-        video=None,
         latents=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
@@ -494,9 +496,6 @@ def check_inputs(
                 f" {negative_prompt_embeds.shape}."
             )
 
-        if video is not None and latents is not None:
-            raise ValueError("Only one of `video` or `latents` should be provided")
-
     # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.fuse_qkv_projections
     def fuse_qkv_projections(self) -> None:
         r"""Enables fused QKV projections."""
@@ -584,18 +583,18 @@ def __call__(
 
         Args:
             image (`PipelineImageInput`):
-                The input video to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
+                The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                 instead.
             negative_prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                 less than `1`).
-            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
-                The height in pixels of the generated image. This is set to 1024 by default for the best results.
-            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
-                The width in pixels of the generated image. This is set to 1024 by default for the best results.
+            height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
+                The height in pixels of the generated image. This is set to 480 by default for the best results.
+            width (`int`, *optional*, defaults to self.transformer.config.sample_width * self.vae_scale_factor_spatial):
+                The width in pixels of the generated image. This is set to 720 by default for the best results.
             num_frames (`int`, defaults to `48`):
                 Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
                 contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
@@ -665,20 +664,19 @@ def __call__(
         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
 
-        height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
-        width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
         num_videos_per_prompt = 1
 
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
-            image,
-            prompt,
-            height,
-            width,
-            negative_prompt,
-            callback_on_step_end_tensor_inputs,
-            prompt_embeds,
-            negative_prompt_embeds,
+            image=image,
+            prompt=prompt,
+            height=height,
+            width=width,
+            negative_prompt=negative_prompt,
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+            latents=latents,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
         )
         self._guidance_scale = guidance_scale
         self._interrupt = False
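With these changes, `check_inputs` is called with keyword arguments and `latents` is validated alongside the other inputs. For orientation, a hedged end-to-end usage sketch of the image-to-video pipeline this file defines (the checkpoint name, input file, and prompt are assumptions, not taken from the diff):

```python
import torch
from diffusers import CogVideoXImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

# Assumed checkpoint name -- substitute whichever CogVideoX I2V weights you use.
pipe = CogVideoXImageToVideoPipeline.from_pretrained(
    "THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16
).to("cuda")

image = load_image("input.png")  # the conditioning image
video = pipe(image=image, prompt="a short description of the desired motion").frames[0]
export_to_video(video, "output.mp4", fps=8)
```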