Skip to content

Commit

Permalink
add ckpt load hook
Browse files Browse the repository at this point in the history
  • Loading branch information
fangyixiao18 committed Jun 16, 2023
1 parent d679318 commit bded5de
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 9 deletions.
9 changes: 2 additions & 7 deletions configs/minigpt4/minigpt-4_vicuna-7b_caption.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,8 @@
'https://download.openmmlab.com/mmpretrain/v1.0/minigpt4/minigpt-4_qformer_20230615-1dfa889c.pth' # noqa
),
lang_encoder=dict(
type='AutoModelForCausalLM',
name_or_path= # noqa
'/mnt/petrelfs/share_data/liuyuan/llm_weights/vicuna_weights_7b'),
tokenizer=dict(
type='LlamaTokenizer',
name_or_path= # noqa
'/mnt/petrelfs/share_data/liuyuan/llm_weights/vicuna_weights_7b'),
type='AutoModelForCausalLM', name_or_path='YOUR_PATH_TO_VICUNA'),
tokenizer=dict(type='LlamaTokenizer', name_or_path='YOUR_PATH_TO_VICUNA'),
task='caption',
prompt_template='###Human: {} ###Assistant: ',
raw_prompts=[
Expand Down
21 changes: 19 additions & 2 deletions mmpretrain/models/multimodal/minigpt4/minigpt4.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import random
import re
from typing import List, Optional, Tuple

import torch
Expand Down Expand Up @@ -106,8 +107,7 @@ def __init__(self,
from mmengine.runner.checkpoint import CheckpointLoader
state_dict = CheckpointLoader.load_checkpoint(
q_former_model_weight)['state_dict']
incompatible_keys = self.load_state_dict(state_dict, strict=False)
logger.info(incompatible_keys)
self.load_state_dict(state_dict, strict=False)

if freeze_q_former:
for name, param in self.q_former.named_parameters():
Expand Down Expand Up @@ -154,6 +154,9 @@ def __init__(self,
temperature=1.0,
**generation_cfg)

if hasattr(self, 'register_load_state_dict_post_hook'):
self.register_load_state_dict_post_hook(self._load_llama_proj_hook)

def encode_img(self,
images: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""The function to encode the images."""
Expand Down Expand Up @@ -362,3 +365,17 @@ def forward(
return self.predict(images, data_samples, **kwargs)
else:
raise RuntimeError(f'Invalid mode "{mode}".')

@staticmethod
def _load_llama_proj_hook(module, incompatible_keys):
    """Filter expected entries out of the reported missing keys.

    This is registered as a ``load_state_dict`` post hook. Every missing
    key that matches one of the patterns below (vision encoder, vision
    layer norm, Q-Former, query tokens, LLaMA model) is dropped, so only
    the remaining keys -- e.g. the LLaMA projection keys -- are still
    listed as missing.

    Args:
        module: The module whose state dict was loaded. Unused.
        incompatible_keys: Object exposing a mutable ``missing_keys``
            list. The list is modified in place.
    """
    # Compile once per call instead of re-resolving each pattern for
    # every key. ``match`` anchors at the start of the key only.
    ignored_patterns = [
        re.compile(pattern) for pattern in (
            'vision_encoder.*',
            'ln_vision.*',
            'q_former.*',
            'query_tokens',
            'llama_model.*',
        )
    ]

    # Slice-assign so the very same list object is updated (the hook
    # returns nothing), in one pass rather than one ``list.remove``
    # scan per dropped key.
    missing_keys = incompatible_keys.missing_keys
    missing_keys[:] = [
        key for key in missing_keys
        if not any(pattern.match(key) for pattern in ignored_patterns)
    ]

0 comments on commit bded5de

Please sign in to comment.