
DDIM inversion on the pretrained T2V model #131

Open
qpc1611094 opened this issue Jun 27, 2024 · 0 comments

qpc1611094 commented Jun 27, 2024

I have tried DDIM inversion on the ModelScope T2V model, but I get abnormal results.
My procedure is as follows:

1) Get the original video latent:

```python
import cv2
import torch
from einops import rearrange
from PIL import Image

def load_video_frames(autoencoder, vid_path, train_trans, max_frames=16, double_frames_sr=False):
    capture = cv2.VideoCapture(vid_path)
    _fps = capture.get(cv2.CAP_PROP_FPS)
    sample_fps = _fps
    _total_frame_num = capture.get(cv2.CAP_PROP_FRAME_COUNT)
    # sample_fps equals the source fps, so stride is 1 and every frame is kept
    stride = round(_fps / sample_fps)
    # always start reading from the first frame
    start_frame = 0

    pointer = 0
    frame_list = []
    while len(frame_list) < max_frames:
        ret, frame = capture.read()
        pointer += 1
        if (not ret) or (frame is None): break
        if pointer < start_frame: continue
        if pointer >= _total_frame_num + 1: break
        if (pointer - start_frame) % stride == 0:
            # OpenCV decodes frames as BGR; convert to RGB PIL images
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = Image.fromarray(frame)
            if double_frames_sr:
                frame_list.append(frame)
            frame_list.append(frame)

    capture.release()
    video_data = train_trans(frame_list)  # (f, c, h, w)

    video_data = torch.nn.functional.interpolate(video_data, size=(256, 448), mode='bilinear')
    video_data = video_data.unsqueeze(0)  # (1, f, c, h, w)
    video_data = video_data.cuda()

    batch_size, frames_num, _, _, _ = video_data.shape
    video_data = rearrange(video_data, 'b f c h w -> (b f) c h w')
    video_data_list = torch.chunk(video_data, video_data.shape[0] // 2, dim=0)

    with torch.no_grad():
        latent_chunks = []
        for vd_data in video_data_list:
            # encode two frames at a time, passing cfg.scale_factor to the encoder
            tmp = autoencoder.encode_firsr_stage(vd_data, cfg.scale_factor).detach()
            latent_chunks.append(tmp)
        video_data_feature = torch.cat(latent_chunks, dim=0)
        video_data_feature = rearrange(video_data_feature, '(b f) c h w -> b c f h w', b=batch_size)
    return video_data_feature

# per-frame preprocessing: to tensor, then normalize with the model's mean/std
# (`data` is the project's transforms module, not shown here)
train_trans = data.Compose([
    data.ToTensor(),
    data.Normalize(mean=cfg.mean, std=cfg.std)])
video_data_feature = load_video_frames(autoencoder, cfg.test_video_path, train_trans)
```
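The layout handling in `load_video_frames` (fold frames into the batch axis, encode in small chunks, then move the frame axis behind the channel axis) can be sketched in NumPy as below. This is a hedged illustration, not the repository's code: `toy_encode` is a hypothetical stand-in for `autoencoder.encode_firsr_stage` (8x average pooling to 4 channels), and its multiplication by `scale_factor` is an assumption about what passing `cfg.scale_factor` to the encoder does. The reshapes emulate the two `rearrange` patterns.

```python
import numpy as np


def toy_encode(x, scale_factor):
    """Hypothetical VAE-encoder stand-in: 8x average pooling to 4 channels, then scaling."""
    n, c, h, w = x.shape
    pooled = x.reshape(n, c, h // 8, 8, w // 8, 8).mean(axis=(3, 5))
    z = np.repeat(pooled.mean(axis=1, keepdims=True), 4, axis=1)
    return scale_factor * z  # assumption: the real encoder applies scale_factor like this


def encode_video(video, encode_fn, scale_factor, frames_per_chunk=2):
    """Encode a (b, f, c, h, w) video frame by frame into (b, c', f, h', w') latents."""
    b, f, c, h, w = video.shape
    # 'b f c h w -> (b f) c h w': merge the leading batch and frame axes
    frames = video.reshape(b * f, c, h, w)
    # encode a few frames at a time to bound memory, then stitch the chunks together
    chunks = [
        encode_fn(frames[i:i + frames_per_chunk], scale_factor)
        for i in range(0, b * f, frames_per_chunk)
    ]
    latents = np.concatenate(chunks, axis=0)
    _, c2, h2, w2 = latents.shape
    # '(b f) c h w -> b c f h w': split the axes again and move frames behind channels
    return latents.reshape(b, f, c2, h2, w2).transpose(0, 2, 1, 3, 4)


if __name__ == "__main__":
    video = np.random.default_rng(0).standard_normal((1, 16, 3, 256, 448))
    latents = encode_video(video, toy_encode, scale_factor=0.5)  # 0.5 is arbitrary
    print(latents.shape)  # -> (1, 4, 16, 32, 56)
```

Because each frame is encoded independently, changing `frames_per_chunk` should only affect memory use, not the resulting latents.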

2) Obtain the noisy latent by DDIM inversion:

```python
# conditional (prompt) and unconditional (negative prompt) kwargs;
# together with guide_scale they drive classifier-free guidance
model_kwargs = [{'y': y_words, 'fps': fps_tensor},
                {'y': zero_y_negative, 'fps': fps_tensor}]
noised_vid_feat = diffusion.ddim_reverse_sample_loop(video_data_feature,
                                                     model.eval(),
                                                     model_kwargs=model_kwargs,
                                                     clamp=None,
                                                     percentile=None,
                                                     guide_scale=cfg.guide_scale,
                                                     ddim_timesteps=cfg.ddim_timesteps)
```
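For reference, the deterministic (eta = 0) DDIM update shared by inversion and sampling can be sketched as follows. This is the generic DDIM formulation written in NumPy, not the implementation of `diffusion.ddim_reverse_sample_loop`; `eps_fn` is a hypothetical stand-in for the model's (possibly guided) noise prediction, and `alpha_bars` is a toy cumulative-alpha schedule ordered from least to most noisy.

```python
import numpy as np


def ddim_step(x, eps, a_from, a_to):
    """One deterministic DDIM move between cumulative alphas a_from -> a_to.

    Moving to a smaller alpha-bar adds noise (inversion); moving to a larger
    one removes it (sampling). Both directions use the same formula.
    """
    x0_pred = (x - np.sqrt(1.0 - a_from) * eps) / np.sqrt(a_from)
    return np.sqrt(a_to) * x0_pred + np.sqrt(1.0 - a_to) * eps


def ddim_invert(x0, eps_fn, alpha_bars):
    """Clean -> noisy, walking alpha_bars in order (decreasing values)."""
    x = x0
    for a_cur, a_next in zip(alpha_bars[:-1], alpha_bars[1:]):
        # x at the next level is unknown, so the noise predicted at the
        # current level is reused; this is where inversion is approximate
        x = ddim_step(x, eps_fn(x, a_cur), a_cur, a_next)
    return x


def ddim_sample(xT, eps_fn, alpha_bars):
    """Noisy -> clean, walking the same schedule backwards."""
    x = xT
    rev = alpha_bars[::-1]
    for a_cur, a_prev in zip(rev[:-1], rev[1:]):
        x = ddim_step(x, eps_fn(x, a_cur), a_cur, a_prev)
    return x


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x0 = rng.standard_normal(1000)
    alpha_bars = np.linspace(0.999, 0.01, 50)  # toy schedule, clean -> noisy
    fixed = rng.standard_normal(1000)

    def max_error(eps_fn):
        return np.abs(ddim_sample(ddim_invert(x0, eps_fn, alpha_bars), eps_fn, alpha_bars) - x0).max()

    print("constant eps:", max_error(lambda x, a: fixed))
    print("x-dependent eps:", max_error(lambda x, a: 0.5 * np.tanh(x)))
```

If the predicted noise were identical at consecutive levels, sampling would undo inversion exactly; when the prediction changes between levels, the mismatch accumulates into reconstruction error.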

3) Reconstruct the original video:

```python
# sample back from the inverted latent with the same kwargs, guidance scale and
# DDIM schedule; eta=0.0 makes DDIM sampling deterministic
video_reconstruct = diffusion.ddim_sample_loop(
    noise=noised_vid_feat,
    model=model.eval(),
    model_kwargs=model_kwargs,
    guide_scale=cfg.guide_scale,
    ddim_timesteps=cfg.ddim_timesteps,
    eta=0.0)
# undo the latent scaling before decoding
video_reconstruct = 1. / cfg.scale_factor * video_reconstruct
video_reconstruct = rearrange(video_reconstruct, 'b c f h w -> (b f) c h w')
# decode the frames in chunks to limit memory use
chunk_size = min(cfg.decoder_bs, video_reconstruct.shape[0])
video_reconstruct_list = torch.chunk(video_reconstruct, video_reconstruct.shape[0] // chunk_size, dim=0)
decode_reconstruct = []
for vd_data in video_reconstruct_list:
    gen_frames = autoencoder.decode(vd_data)
    decode_reconstruct.append(gen_frames)
video_reconstruct = torch.cat(decode_reconstruct, dim=0)
video_reconstruct = rearrange(video_reconstruct, '(b f) c h w -> b c f h w', b=1)
save_i2vgen_video_safe(local_path, video_reconstruct.cpu(), captions, cfg.mean, cfg.std, text_size)
```
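The decode side of step 3 (undo the scaling once, fold frames into the batch axis, decode in chunks, unfold) is sketched below in NumPy, with a hypothetical `toy_decode` (8x nearest-neighbour upsampling of the first three channels) standing in for `autoencoder.decode`. One detail worth knowing: `torch.chunk(x, n // chunk_size)` returns `n // chunk_size` pieces of up to `ceil(n / (n // chunk_size))` frames, so when `chunk_size` does not divide `n` the pieces are larger than `chunk_size` (16 frames with `chunk_size = 6` gives two chunks of 8). Slicing by `chunk_size`, as in the sketch or with `torch.split(x, chunk_size)`, caps each chunk instead.

```python
import numpy as np


def toy_decode(z):
    """Hypothetical VAE-decoder stand-in: 8x nearest-neighbour upsampling of 3 channels."""
    return np.repeat(np.repeat(z[:, :3], 8, axis=2), 8, axis=3)


def decode_latents(latents, decode_fn, scale_factor, chunk_size):
    """Decode (b, c, f, h, w) latents frame by frame into (b, c', f, h', w') frames."""
    b, c, f, h, w = latents.shape
    # undo the scaling exactly once before the decoder sees the latents
    latents = latents / scale_factor
    # 'b c f h w -> (b f) c h w'
    frames = latents.transpose(0, 2, 1, 3, 4).reshape(b * f, c, h, w)
    # slicing by chunk_size keeps every chunk at most chunk_size frames long
    decoded = np.concatenate(
        [decode_fn(frames[i:i + chunk_size]) for i in range(0, b * f, chunk_size)],
        axis=0,
    )
    _, c2, h2, w2 = decoded.shape
    # '(b f) c h w -> b c f h w'
    return decoded.reshape(b, f, c2, h2, w2).transpose(0, 2, 1, 3, 4)


if __name__ == "__main__":
    latents = np.random.default_rng(0).standard_normal((1, 4, 16, 32, 56))
    video = decode_latents(latents, toy_decode, scale_factor=0.5, chunk_size=6)
    print(video.shape)  # -> (1, 3, 16, 256, 448)
```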

The original video is:
(attached video: rank_02_01_0003_A_horse_running_on_the_road)

However, the reconstructed video has completely collapsed:
(attached video)

I suspect the problem is related to scale_factor, because when I set cfg.scale_factor = 1.0 the result looks better:
(attached video: rank_01_00_reconstruct)
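One hedged way to check the scaling hypothesis is to look at the spread of the latents. For Stable-Diffusion-style VAEs the scale factor is chosen so that `scale_factor * z` has roughly unit variance, and the diffusion model is presumably trained on latents scaled that way. The sketch below uses synthetic data only (no model results); the value 0.18215 is an assumption about the VAE, and `scaling_report` is a hypothetical helper.

```python
import numpy as np


def scaling_report(raw_latents, scale_factor):
    """Return (std of raw latents, std of scaled latents, whether unscaling restores them)."""
    scaled = scale_factor * raw_latents  # what the diffusion model should receive
    restored = scaled / scale_factor     # what the decoder should receive
    return float(np.std(raw_latents)), float(np.std(scaled)), bool(np.allclose(restored, raw_latents))


if __name__ == "__main__":
    s = 0.18215  # assumption: a Stable-Diffusion-style VAE scale factor
    # synthetic "raw" latents with std of about 1 / s, standing in for unscaled VAE output
    raw = np.random.default_rng(0).standard_normal(200_000) / s
    raw_std, scaled_std, round_trip_ok = scaling_report(raw, s)
    print(f"raw std {raw_std:.2f}, scaled std {scaled_std:.2f}, round trip ok: {round_trip_ok}")
```

Applied to the real pipeline, printing `video_data_feature.std()` before inversion would show whether the latents sit near unit scale; a value several times larger (scaling skipped) or much smaller (scaling applied twice) would point at the scale-factor handling.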

Looking forward to your reply. Many thanks!
