
Commit 8fb4771

[Cog] some minor fixes and nits (#9466)

* fix positional arguments in check_inputs().
* add video and latents to check_inputs().
* prep latents_in_channels.
* quality
* multiple fixes.
* fix
1 parent 53786f5 commit 8fb4771
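
A recurring change across the three files below is that the VAE image scaling factor is now cached on the pipeline as `vae_scaling_factor_image` at construction time (falling back to 0.7 when no VAE is registered) and used in place of `self.vae.config.scaling_factor` when scaling latents. A minimal standalone sketch of that pattern; the `ToyPipeline` class is illustrative only and not part of diffusers:

import torch

class ToyPipeline:
    def __init__(self, vae=None):
        # Cache the scaling factor once; fall back to 0.7 (the CogVideoX default
        # used in this commit) when no VAE is registered.
        self.vae_scaling_factor_image = (
            vae.config.scaling_factor if vae is not None else 0.7
        )

    def scale_latents_for_decode(self, latents: torch.Tensor) -> torch.Tensor:
        # Same arithmetic as decode_latents() in the diffs below, reading the
        # cached attribute instead of self.vae.config.scaling_factor.
        return 1 / self.vae_scaling_factor_image * latents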

File tree

3 files changed: +65 -61 lines

src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py

+14-12
@@ -188,6 +188,9 @@ def __init__(
         self.vae_scale_factor_temporal = (
             self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
         )
+        self.vae_scaling_factor_image = (
+            self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7
+        )

         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

@@ -317,18 +320,19 @@ def encode_prompt(
     def prepare_latents(
         self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
     ):
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
         shape = (
             batch_size,
             (num_frames - 1) // self.vae_scale_factor_temporal + 1,
             num_channels_latents,
             height // self.vae_scale_factor_spatial,
             width // self.vae_scale_factor_spatial,
         )
-        if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-            )

         if latents is None:
             latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

@@ -341,7 +345,7 @@ def prepare_latents(

     def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
         latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
-        latents = 1 / self.vae.config.scaling_factor * latents
+        latents = 1 / self.vae_scaling_factor_image * latents

         frames = self.vae.decode(latents).sample
         return frames

@@ -510,10 +514,10 @@ def __call__(
                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                 less than `1`).
-            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
-                The height in pixels of the generated image. This is set to 1024 by default for the best results.
-            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
-                The width in pixels of the generated image. This is set to 1024 by default for the best results.
+            height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
+                The height in pixels of the generated image. This is set to 480 by default for the best results.
+            width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
+                The width in pixels of the generated image. This is set to 720 by default for the best results.
             num_frames (`int`, defaults to `48`):
                 Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
                 contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where

@@ -587,8 +591,6 @@ def __call__(
         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

-        height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
-        width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
         num_videos_per_prompt = 1

         # 1. Check inputs. Raise error if not correct
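
For context on the docstring changes above, a hedged usage sketch of the text-to-video pipeline; the checkpoint id `THUDM/CogVideoX-2b` is an assumption, and the 480x720 defaults come from the updated docstrings, so verify both against the model card:

import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

# Assumed checkpoint; half precision keeps memory usage manageable on one GPU.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16)
pipe.to("cuda")

# height/width are omitted so the 480x720 defaults documented above apply.
video = pipe(
    prompt="A panda strumming a guitar in a bamboo forest",
    num_inference_steps=50,
    guidance_scale=6.0,
).frames[0]

export_to_video(video, "cogvideox_sample.mp4", fps=8)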

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py

+25-27
@@ -207,6 +207,9 @@ def __init__(
         self.vae_scale_factor_temporal = (
             self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
         )
+        self.vae_scaling_factor_image = (
+            self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7
+        )

         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

@@ -348,6 +351,12 @@ def prepare_latents(
         generator: Optional[torch.Generator] = None,
         latents: Optional[torch.Tensor] = None,
     ):
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
         num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
         shape = (
             batch_size,

@@ -357,12 +366,6 @@ def prepare_latents(
             width // self.vae_scale_factor_spatial,
         )

-        if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-            )
-
         image = image.unsqueeze(2)  # [B, C, F, H, W]

         if isinstance(generator, list):

@@ -373,7 +376,7 @@ def prepare_latents(
             image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image]

         image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]
-        image_latents = self.vae.config.scaling_factor * image_latents
+        image_latents = self.vae_scaling_factor_image * image_latents

         padding_shape = (
             batch_size,

@@ -397,7 +400,7 @@ def prepare_latents(
     # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents
     def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
         latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
-        latents = 1 / self.vae.config.scaling_factor * latents
+        latents = 1 / self.vae_scaling_factor_image * latents

         frames = self.vae.decode(latents).sample
         return frames

@@ -438,7 +441,6 @@ def check_inputs(
         width,
         negative_prompt,
         callback_on_step_end_tensor_inputs,
-        video=None,
         latents=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,

@@ -494,9 +496,6 @@ def check_inputs(
                 f" {negative_prompt_embeds.shape}."
             )

-        if video is not None and latents is not None:
-            raise ValueError("Only one of `video` or `latents` should be provided")
-
     # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.fuse_qkv_projections
     def fuse_qkv_projections(self) -> None:
         r"""Enables fused QKV projections."""

@@ -584,18 +583,18 @@ def __call__(

         Args:
             image (`PipelineImageInput`):
-                The input video to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
+                The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.
             negative_prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                 less than `1`).
-            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
-                The height in pixels of the generated image. This is set to 1024 by default for the best results.
-            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
-                The width in pixels of the generated image. This is set to 1024 by default for the best results.
+            height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
+                The height in pixels of the generated image. This is set to 480 by default for the best results.
+            width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
+                The width in pixels of the generated image. This is set to 720 by default for the best results.
             num_frames (`int`, defaults to `48`):
                 Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
                 contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where

@@ -665,20 +664,19 @@ def __call__(
         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

-        height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
-        width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
         num_videos_per_prompt = 1

         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
-            image,
-            prompt,
-            height,
-            width,
-            negative_prompt,
-            callback_on_step_end_tensor_inputs,
-            prompt_embeds,
-            negative_prompt_embeds,
+            image=image,
+            prompt=prompt,
+            height=height,
+            width=width,
+            negative_prompt=negative_prompt,
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+            latents=latents,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
         )
         self._guidance_scale = guidance_scale
         self._interrupt = False
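
The switch to keyword arguments in the `check_inputs` call above is the substance of the "fix positional arguments" bullet in the commit message: `check_inputs` takes several same-typed optional parameters, so a positional call silently shifts arguments whenever a parameter such as `latents` is inserted into the signature. A minimal illustration with a simplified stub (not the pipeline's real signature):

def check_inputs(prompt, height, width, negative_prompt=None, latents=None,
                 prompt_embeds=None, negative_prompt_embeds=None):
    # Simplified stand-in for the pipeline method: just report what it received.
    return {"latents": latents, "prompt_embeds": prompt_embeds}

prompt_embeds = object()

# Positional call written before `latents` was added to the signature:
# the embeddings now silently land in the `latents` slot.
shifted = check_inputs("a cat", 480, 720, None, prompt_embeds)
assert shifted["latents"] is prompt_embeds and shifted["prompt_embeds"] is None

# Keyword call, as the pipelines now use: unaffected by parameter reordering.
correct = check_inputs(prompt="a cat", height=480, width=720, prompt_embeds=prompt_embeds)
assert correct["prompt_embeds"] is prompt_embeds and correct["latents"] is None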

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py

+26-22
@@ -204,12 +204,16 @@ def __init__(
         self.register_modules(
             tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
         )
+
         self.vae_scale_factor_spatial = (
             2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
         )
         self.vae_scale_factor_temporal = (
             self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
         )
+        self.vae_scaling_factor_image = (
+            self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7
+        )

         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

@@ -351,6 +355,12 @@ def prepare_latents(
         latents: Optional[torch.Tensor] = None,
         timestep: Optional[torch.Tensor] = None,
     ):
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
         num_frames = (video.size(2) - 1) // self.vae_scale_factor_temporal + 1 if latents is None else latents.size(1)

         shape = (

@@ -361,12 +371,6 @@ def prepare_latents(
             width // self.vae_scale_factor_spatial,
         )

-        if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-            )
-
         if latents is None:
             if isinstance(generator, list):
                 if len(generator) != batch_size:

@@ -382,7 +386,7 @@ def prepare_latents(
                 init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]

             init_latents = torch.cat(init_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]
-            init_latents = self.vae.config.scaling_factor * init_latents
+            init_latents = self.vae_scaling_factor_image * init_latents

             noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
             latents = self.scheduler.add_noise(init_latents, noise, timestep)

@@ -396,7 +400,7 @@ def prepare_latents(
     # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents
     def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
         latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
-        latents = 1 / self.vae.config.scaling_factor * latents
+        latents = 1 / self.vae_scaling_factor_image * latents

         frames = self.vae.decode(latents).sample
         return frames

@@ -589,10 +593,10 @@ def __call__(
                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                 less than `1`).
-            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
-                The height in pixels of the generated image. This is set to 1024 by default for the best results.
-            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
-                The width in pixels of the generated image. This is set to 1024 by default for the best results.
+            height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
+                The height in pixels of the generated image. This is set to 480 by default for the best results.
+            width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
+                The width in pixels of the generated image. This is set to 720 by default for the best results.
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.

@@ -658,20 +662,20 @@ def __call__(
         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

-        height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
-        width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
         num_videos_per_prompt = 1

         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
-            prompt,
-            height,
-            width,
-            strength,
-            negative_prompt,
-            callback_on_step_end_tensor_inputs,
-            prompt_embeds,
-            negative_prompt_embeds,
+            prompt=prompt,
+            height=height,
+            width=width,
+            strength=strength,
+            negative_prompt=negative_prompt,
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+            video=video,
+            latents=latents,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
         )
         self._guidance_scale = guidance_scale
         self._attention_kwargs = attention_kwargs
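
The other change repeated in each `prepare_latents` above is hoisting the generator-list length check to the top of the method, so a mismatch fails before any shape computation or VAE encoding is done. A standalone sketch of that validation; the helper function is illustrative, while the error message matches the diff:

from typing import List, Optional, Union

import torch

def validate_generator(generator: Optional[Union[torch.Generator, List[torch.Generator]]], batch_size: int) -> None:
    # Fail fast, as prepare_latents now does, before any latents are allocated.
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

# Two generators but an effective batch size of three raises immediately.
try:
    validate_generator([torch.Generator(), torch.Generator()], batch_size=3)
except ValueError as err:
    print(err)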
