 from pathlib import Path
 from typing import List, Optional, Tuple, Union

+import numpy as np
 import torch
+import torchvision.transforms as TT
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
 from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
 from torch.utils.data import DataLoader, Dataset
-from torchvision import transforms
+from torchvision.transforms import InterpolationMode
+from torchvision.transforms.functional import resize
 from tqdm.auto import tqdm
 from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer

 import diffusers
 from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
+from diffusers.image_processor import VaeImageProcessor
 from diffusers.models.embeddings import get_3d_rotary_pos_embed
 from diffusers.optimization import get_scheduler
 from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
@@ -214,6 +218,12 @@ def get_args():
         default=720,
         help="All input videos are resized to this width.",
     )
+    parser.add_argument(
+        "--video_reshape_mode",
+        type=str,
+        default="center",
+        help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']",
+    )
     parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.")
     parser.add_argument(
         "--max_num_frames", type=int, default=49, help="All input videos will be truncated to these many frames."
@@ -413,6 +423,7 @@ def __init__(
         video_column: str = "video",
         height: int = 480,
         width: int = 720,
+        video_reshape_mode: str = "center",
         fps: int = 8,
         max_num_frames: int = 49,
         skip_frames_start: int = 0,
@@ -429,6 +440,7 @@ def __init__(
         self.video_column = video_column
         self.height = height
         self.width = width
+        self.video_reshape_mode = video_reshape_mode
         self.fps = fps
         self.max_num_frames = max_num_frames
         self.skip_frames_start = skip_frames_start
@@ -532,6 +544,38 @@ def _load_dataset_from_local_path(self):

         return instance_prompts, instance_videos

+    def _resize_for_rectangle_crop(self, arr):
+        image_size = self.height, self.width
+        reshape_mode = self.video_reshape_mode
+        if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
+            arr = resize(
+                arr,
+                size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
+                interpolation=InterpolationMode.BICUBIC,
+            )
+        else:
+            arr = resize(
+                arr,
+                size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
+                interpolation=InterpolationMode.BICUBIC,
+            )
+
+        h, w = arr.shape[2], arr.shape[3]
+        arr = arr.squeeze(0)
+
+        delta_h = h - image_size[0]
+        delta_w = w - image_size[1]
+
+        if reshape_mode == "random" or reshape_mode == "none":
+            top = np.random.randint(0, delta_h + 1)
+            left = np.random.randint(0, delta_w + 1)
+        elif reshape_mode == "center":
+            top, left = delta_h // 2, delta_w // 2
+        else:
+            raise NotImplementedError
+        arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
+        return arr
+
     def _preprocess_data(self):
         try:
             import decord
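Roughly, the new helper scales each clip so the target rectangle fits inside it while preserving aspect ratio, then crops the overflow along the longer axis, either centered or at a random offset. A standalone sketch of the same logic on a dummy tensor (all names and shapes here are illustrative, not taken from the script):

```python
import torch
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import crop, resize

# Dummy 49-frame clip at 360x640, target 480x720 (H x W), matching the script defaults.
frames = torch.rand(49, 3, 360, 640)
target_h, target_w = 480, 720

# Resize so the target rectangle fits inside the frames (aspect ratio preserved).
if frames.shape[3] / frames.shape[2] > target_w / target_h:  # clip is wider than the target
    size = [target_h, int(frames.shape[3] * target_h / frames.shape[2])]
else:  # clip is taller than (or equal to) the target
    size = [int(frames.shape[2] * target_w / frames.shape[3]), target_w]
frames = resize(frames, size=size, interpolation=InterpolationMode.BICUBIC)

# "center" mode: crop the overflow symmetrically; "random" would draw top/left offsets instead.
top = (frames.shape[2] - target_h) // 2
left = (frames.shape[3] - target_w) // 2
frames = crop(frames, top=top, left=left, height=target_h, width=target_w)
print(frames.shape)  # torch.Size([49, 3, 480, 720])
```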
@@ -542,15 +586,14 @@ def _preprocess_data(self):

         decord.bridge.set_bridge("torch")

-        videos = []
-        train_transforms = transforms.Compose(
-            [
-                transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),
-            ]
+        progress_dataset_bar = tqdm(
+            range(0, len(self.instance_video_paths)),
+            desc="Resizing and cropping videos",
         )
+        videos = []

         for filename in self.instance_video_paths:
-            video_reader = decord.VideoReader(uri=filename.as_posix(), width=self.width, height=self.height)
+            video_reader = decord.VideoReader(uri=filename.as_posix())
             video_num_frames = len(video_reader)

             start_frame = min(self.skip_frames_start, video_num_frames)
@@ -576,10 +619,16 @@ def _preprocess_data(self):
             assert (selected_num_frames - 1) % 4 == 0

             # Training transforms
-            frames = frames.float()
-            frames = torch.stack([train_transforms(frame) for frame in frames], dim=0)
-            videos.append(frames.permute(0, 3, 1, 2).contiguous())  # [F, C, H, W]
+            frames = (frames - 127.5) / 127.5
+            frames = frames.permute(0, 3, 1, 2)  # [F, C, H, W]
+            progress_dataset_bar.set_description(
+                f"Resizing video from {frames.shape[2]}x{frames.shape[3]} to {self.height}x{self.width}"
+            )
+            frames = self._resize_for_rectangle_crop(frames)
+            videos.append(frames.contiguous())  # [F, C, H, W]
+            progress_dataset_bar.update(1)

+        progress_dataset_bar.close()
         return videos

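The inlined `(frames - 127.5) / 127.5` is the same `[0, 255] -> [-1, 1]` mapping the removed `transforms.Lambda` computed as `x / 255.0 * 2.0 - 1.0`; a quick, purely illustrative check:

```python
import torch

x = torch.arange(0, 256, dtype=torch.float32)  # every possible 8-bit pixel value
old = x / 255.0 * 2.0 - 1.0   # removed Compose/Lambda normalization
new = (x - 127.5) / 127.5     # inlined replacement used above
assert torch.allclose(old, new)
```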
@@ -694,8 +743,13 @@ def log_validation(

     videos = []
     for _ in range(args.num_validation_videos):
-        video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
-        videos.append(video)
+        pt_images = pipe(**pipeline_args, generator=generator, output_type="pt").frames[0]
+        pt_images = torch.stack([pt_images[i] for i in range(pt_images.shape[0])])
+
+        image_np = VaeImageProcessor.pt_to_numpy(pt_images)
+        image_pil = VaeImageProcessor.numpy_to_pil(image_np)
+
+        videos.append(image_pil)

     for tracker in accelerator.trackers:
         phase_name = "test" if is_final_validation else "validation"
@@ -1171,6 +1225,7 @@ def load_model_hook(models, input_dir):
         video_column=args.video_column,
         height=args.height,
         width=args.width,
+        video_reshape_mode=args.video_reshape_mode,
         fps=args.fps,
         max_num_frames=args.max_num_frames,
         skip_frames_start=args.skip_frames_start,
@@ -1179,13 +1234,21 @@ def load_model_hook(models, input_dir):
         id_token=args.id_token,
     )

-    def encode_video(video):
+    def encode_video(video, bar):
+        bar.update(1)
         video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
         video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
         latent_dist = vae.encode(video).latent_dist
         return latent_dist

-    train_dataset.instance_videos = [encode_video(video) for video in train_dataset.instance_videos]
+    progress_encode_bar = tqdm(
+        range(0, len(train_dataset.instance_videos)),
+        desc="Encoding videos",
+    )
+    train_dataset.instance_videos = [
+        encode_video(video, progress_encode_bar) for video in train_dataset.instance_videos
+    ]
+    progress_encode_bar.close()

     def collate_fn(examples):
         videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples]
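The diff threads the tqdm handle into `encode_video` so the bar ticks once per encoded clip. An equivalent, slightly more compact pattern (only a suggestion, not what this PR does) is to let tqdm wrap the iterable and keep `encode_video`'s original one-argument signature:

```python
from tqdm.auto import tqdm

# Hypothetical alternative to the bar-passing approach above.
train_dataset.instance_videos = [
    encode_video(video)  # assumes the original def encode_video(video) signature
    for video in tqdm(train_dataset.instance_videos, desc="Encoding videos")
]
```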