Skip to content

Commit

Permalink
MMBench supports Siglip
Browse files Browse the repository at this point in the history
  • Loading branch information
bobo0810 committed Mar 28, 2024
1 parent 62e7d80 commit a8cc1a6
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions xtuner/tools/mmbench.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
from torch.utils.data import Dataset
from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer,
BitsAndBytesConfig, CLIPImageProcessor,
CLIPVisionModel, GenerationConfig)
CLIPVisionModel, GenerationConfig,
SiglipImageProcessor, SiglipVisionModel
)

from xtuner.dataset.utils import decode_base64_to_image, expand2square
from xtuner.model.utils import LoadWoInit, prepare_inputs_labels_for_multimodal
Expand Down Expand Up @@ -330,10 +332,19 @@ def main():
'Please specify the `--visual-encoder`!')
visual_encoder_path = args.visual_encoder
with LoadWoInit():
    # Pick the vision tower implementation from the checkpoint path.
    # NOTE(review): substring matching on the path is fragile — it assumes the
    # checkpoint directory name contains 'clip' or 'siglip'; confirm against
    # the supported checkpoint naming conventions.
    if 'clip' in visual_encoder_path:
        visual_encoder = CLIPVisionModel.from_pretrained(
            visual_encoder_path,
            torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype])
        image_processor = CLIPImageProcessor.from_pretrained(
            visual_encoder_path)
    elif 'siglip' in visual_encoder_path:
        visual_encoder = SiglipVisionModel.from_pretrained(
            visual_encoder_path,
            torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype])
        image_processor = SiglipImageProcessor.from_pretrained(
            visual_encoder_path)
    else:
        # `raise <str>` is a TypeError in Python 3 (exceptions must derive
        # from BaseException); raise a proper exception type instead.
        raise ValueError(
            f'Visual encoders not supported : {visual_encoder_path}')

master_print(f'Load visual_encoder from {visual_encoder_path}')

# load adapter
Expand Down Expand Up @@ -506,5 +517,4 @@ def main():


if __name__ == '__main__':

main()

0 comments on commit a8cc1a6

Please sign in to comment.