
I want to fine-tune a complete text encoder model, but it seems that the model trained by ft-B-train-OpenAI-CLIP-ViT-L-14.py is a visual encoder model. #16

Open
vxiaobai opened this issue Oct 23, 2024 · 8 comments


@vxiaobai

First of all, thank you for your work. I have a question.
I want to fine-tune a complete text encoder model, but the model trained by ft-B-train-OpenAI-CLIP-ViT-L-14.py seems to be a visual encoder model. How can I get a pure text encoder model like the ViT-L-14-TEXT-detail-improved-hiT-GmP-HF.safetensors you provide on HuggingFace?

@zer0int
Owner

zer0int commented Oct 23, 2024

The fine-tune is actually a text-vision model, consisting of a text transformer AND a vision transformer. For the "TE only" / text encoder only models on my HuggingFace, I fine-tuned the entire CLIP model (text + vision) and then simply "detached" the vision transformer (i.e. deleted the keys / associated parameters). CLIP's objective is in the name: Contrastive Language-Image Pretraining. The optimization goal is to learn text and image jointly, pushing the dot product of matching pairs high and that of negative (non-matching) examples low. By definition, it needs both image and text to be a "CLIP".
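A minimal pure-Python sketch of that objective: a symmetric contrastive (InfoNCE-style) loss over a batch of matching image/text embedding pairs, where every non-matching pairing in the batch serves as a negative. It illustrates the idea only and is not the repository's training code; the fixed `logit_scale=100.0` is an assumption, since CLIP learns this temperature during training.

```python
import math


def clip_contrastive_loss(image_embs, text_embs, logit_scale=100.0):
    """Symmetric contrastive loss for a batch of (image, text) pairs.

    Pair i is (image_embs[i], text_embs[i]); all other pairings are negatives.
    logit_scale is fixed here for illustration (an assumption; CLIP learns it).
    """
    def normalize(v):
        norm = math.sqrt(sum(x * x for x in v))
        return [x / norm for x in v]

    imgs = [normalize(v) for v in image_embs]
    txts = [normalize(v) for v in text_embs]
    n = len(imgs)
    # logits[i][j] = scaled cosine similarity (dot product of unit vectors)
    logits = [[logit_scale * sum(a * b for a, b in zip(imgs[i], txts[j]))
               for j in range(n)] for i in range(n)]

    def cross_entropy(rows):
        # Mean of -log softmax(row)[i]: the correct match for row i is column i
        total = 0.0
        for i, row in enumerate(rows):
            m = max(row)  # subtract the max for numerical stability
            log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
            total += log_sum_exp - row[i]
        return total / len(rows)

    columns = [list(col) for col in zip(*logits)]  # text-to-image direction
    return (cross_entropy(logits) + cross_entropy(columns)) / 2


aligned = clip_contrastive_loss([[1, 0], [0, 1]], [[1, 0], [0, 1]])
swapped = clip_contrastive_loss([[1, 0], [0, 1]], [[0, 1], [1, 0]])
```

With perfectly aligned pairs the loss is essentially zero; swapping the texts turns every diagonal pair into a mismatch and the loss jumps, which is exactly the signal that needs both an image side and a text side.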

So, the question is: what are you trying to achieve? Or do you mean that you only want to train the text encoder, with a frozen visual encoder (no parameter updates)? In that case:

The vision transformer is visual.transformer.resblocks[i] (plus visual.proj and so on); the text transformer is transformer.resblocks[i] (no 'visual').
So, to train only the text encoder parameters while keeping the visual encoder frozen (but still using a contrastive loss between text and image), you could use something like this:

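def freeze_clip_selectively(model):
    # Freeze every parameter whose name contains one of these substrings.
    # In OpenAI CLIP, all vision-tower parameters live under 'visual.';
    # everything else (text transformer, embeddings, projections, logit_scale) stays trainable.
    for name, param in model.named_parameters():
        if any(key in name for key in [
            'visual'
        ]):
            param.requires_grad = False
        else:
            param.requires_grad = True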

# in trainloop(), before "for epoch [...]":

freeze_clip_selectively(model)
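To double-check the split after freezing, you can count trainable parameters and hand only those to the optimizer, which makes the intent explicit and keeps the optimizer from ever holding the frozen vision weights. This sketch uses a hypothetical `ToyCLIP` module with CLIP-like parameter names instead of the real model, a compact but equivalent version of the function above, and AdamW with `lr=1e-5` as placeholder choices.

```python
import torch
import torch.nn as nn


class ToyCLIP(nn.Module):
    """Hypothetical stand-in with CLIP-like parameter names (not the real model)."""

    def __init__(self):
        super().__init__()
        self.visual = nn.Linear(8, 4)       # stands in for the vision tower
        self.transformer = nn.Linear(8, 4)  # stands in for the text transformer
        self.logit_scale = nn.Parameter(torch.ones([]))


def freeze_clip_selectively(model):
    # Same rule as above: anything under 'visual' is frozen, the rest trains
    for name, param in model.named_parameters():
        param.requires_grad = 'visual' not in name


model = ToyCLIP()
freeze_clip_selectively(model)

# Pass only the trainable parameters to the optimizer
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-5)

n_total = sum(p.numel() for p in model.parameters())
n_trainable = sum(p.numel() for p in trainable)
print(f"trainable: {n_trainable} / {n_total}")
```

On the real OpenAI CLIP model, the trainable set should then cover the text side (transformer.*, token_embedding, positional_embedding, ln_final, text_projection) plus logit_scale.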

@vxiaobai
Author

For the "TE only" / text encoder only models on my HuggingFace, I fine-tuned the entire CLIP model (text + vision) and then simply "detached" the vision transformer (i.e. delete the keys / associated parameters).

Can you please give me the code for this? I want to use it with the Flux model. I tested the text-encoder-only model you provided on HF and it works with Flux, and now I want to train the CLIP model as a multilingual model, but I am not familiar with the steps to "separate" the vision transformer. I would appreciate your help, thank you very much.
Also, thank you very much for your code; I learned a lot about multimodality from it.

@zer0int
Owner

zer0int commented Oct 23, 2024

I just committed Convert-for-HuggingFace-Spaces-etc - the folder contains all the scripts + documentation / how-to use. Please let me know if that works for you!
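The scripts in that folder are not reproduced here. As a rough illustration of the "detach" step only: deleting the vision keys amounts to filtering the state_dict by the `visual.` prefix described earlier in the thread. The function name `strip_vision_keys` is hypothetical, and the sketch assumes OpenAI CLIP key names; a HuggingFace-format checkpoint uses different names (e.g. `vision_model.` and `visual_projection` vs. `text_model.`), so the filter would need adjusting there.

```python
from collections import OrderedDict


def strip_vision_keys(state_dict, prefix="visual."):
    """Return a copy of an OpenAI-CLIP-style state_dict without the vision tower.

    Works on any mapping, e.g. the OrderedDict of tensors from model.state_dict();
    values are passed through untouched and the input is not modified.
    """
    return OrderedDict(
        (key, value) for key, value in state_dict.items()
        if not key.startswith(prefix)
    )
```

The filtered dict can then be saved like any other state_dict; whether a downstream tool accepts it depends on the key names that tool expects.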

@vxiaobai
Author


Thank you very much. I think the code you provided is what I want, but I ran into a problem when converting. The error message is below; have you encountered the same problem? I am training several of your training scripts separately and then trying each one:
  state_dict = torch.load(opened_file, map_location="cpu")
Traceback (most recent call last):
  File "/opt/conda/lib/python3.11/site-packages/clip/clip.py", line 129, in load
    model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
            ^ ... "/opt/conda/lib/python3.11/site-packages/torch/jit/_serialization.py", line 165, in load
    cpp_module = torch._C.import_ir_module_from_buffer(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: PytorchStreamReader failed locating file constants.pkl: file not found

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/workspace/finetune_CLIP/CLIP-fine-tune/Convert-for-HuggingFace-Spaces-etc/convert_clip_original_pytorch_to_hf.py", line 156, in <module>
    convert_clip_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path)
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/workspace/finetune_CLIP/CLIP-fine-tune/Convert-for-HuggingFace-Spaces-etc/convert_clip_original_pytorch_to_hf.py", line 120, in convert_clip_checkpoint
    pt_model, _ = load(checkpoint_path, device="cpu", jit=False)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/clip/clip.py", line 136, in load
    state_dict = torch.load(opened_file, map_location="cpu")
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/serialization.py", line 1114, in load
    return _legacy_load(
           ^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/serialization.py", line 1338, in _legacy_load
    magic_number = pickle_module.load(f, **pickle_load_args)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
EOFError: Ran out of input

@zer0int
Owner

zer0int commented Oct 24, 2024

Can you open /opt/conda/lib/python3.11/site-packages/clip/clip.py and edit line 129? Where it says:

model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()

Change that to:

#model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
model = torch.load(opened_file, map_location="cpu").eval()

I can't reproduce your error, but somebody else reported the same; I am assuming it might be related to the venv / conda, and trying to load a torch jit scripted archive. I don't use a venv.

However, torch.jit is just for "interoperability, speed and production environments", so it's not needed, and we can just put the map_location on CPU in any case.

If that doesn't work, my other random guess at a fix (as I can't reproduce the problem):
Can you use my ft-C-convert-for-SDXL-comfyUI-OpenAI-CLIP.py script (converts the full model to a state_dict), and try loading this converted model for the conversion instead?
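Rather than editing clip.py in place, a loader along these lines tries both formats. It is a sketch, not the openai/clip code: the name `load_clip_checkpoint` is hypothetical, and the `f.seek(0)` rewind rests on an assumption. The traceback shows torch.jit.load reading the whole file into a buffer (import_ir_module_from_buffer) before failing, so a shared file handle may already be at end-of-file when torch.load runs, which would be consistent with "EOFError: Ran out of input". The `weights_only` argument needs PyTorch 1.13 or newer.

```python
import torch


def load_clip_checkpoint(path, weights_only=False):
    """Try a TorchScript archive first, then fall back to a regular torch.save file."""
    with open(path, "rb") as f:
        try:
            return torch.jit.load(f, map_location="cpu").eval()
        except RuntimeError:
            # Assumption: the failed jit.load may have consumed the handle,
            # so rewind before reading the file again with torch.load.
            f.seek(0)
            # weights_only=False unpickles arbitrary objects: trusted files only
            obj = torch.load(f, map_location="cpu", weights_only=weights_only)
            return obj.eval() if isinstance(obj, torch.nn.Module) else obj
```

The result is either a module (full model object or TorchScript archive) or a plain state_dict, so the caller still has to handle both cases.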

@vxiaobai
Author


I found an explanation of this error on other forums and applied it. It worked, but I'm not sure it was the deciding factor.
If you save the model with torch.save(model, model_path), load it with model = torch.load(opened_file, map_location="cpu").eval(). If you need to load it with model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval(), then save the model with script_model = torch.jit.trace(model, (images, texts)) followed by script_model.save("model.jit.pt").
Hope this helps with your library.
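Both pairings can be sketched with a hypothetical two-input toy module standing in for CLIP's forward(images, texts); the file names follow the comment above, and the script writes them to the current directory. The `weights_only=False` flag is needed on PyTorch 2.6+, where torch.load defaults to weights-only loading and would otherwise refuse a pickled full module.

```python
import torch
import torch.nn as nn


class ToyTwoInput(nn.Module):
    """Hypothetical stand-in for CLIP: takes (images, texts), returns similarities."""

    def __init__(self):
        super().__init__()
        self.img = nn.Linear(4, 3)
        self.txt = nn.Linear(5, 3)

    def forward(self, images, texts):
        # Similarity of every image with every text in the batch
        return self.img(images) @ self.txt(texts).T


model = ToyTwoInput().eval()
images, texts = torch.randn(2, 4), torch.randn(2, 5)

# TorchScript archive: pairs with torch.jit.load
script_model = torch.jit.trace(model, (images, texts))
script_model.save("model.jit.pt")
reloaded = torch.jit.load("model.jit.pt", map_location="cpu").eval()

# Pickled full model object: pairs with torch.load
torch.save(model, "model.pt")
plain = torch.load("model.pt", map_location="cpu", weights_only=False).eval()
```

One caveat of torch.jit.trace: it records only the operations executed for the example inputs, so data-dependent Python control flow in forward() is not captured.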

@zer0int
Owner

zer0int commented Oct 26, 2024

Thank you for the suggestion, and glad you got it to work! I'll try it and consider implementing it as a Bool switch with my next update: True if you want to script the model, else a normal torch.save. 👍

@zer0int
Owner

zer0int commented Nov 11, 2024

I updated the code with a new model saver; you can now choose to either save as GmP (legacy behavior) or directly convert back to .weight (original OpenAI/CLIP; no extra script for conversion needed anymore!). Plus, you can save the model as 1. a full model object (legacy behavior), 2. a state_dict, or 3. a torch.jit.trace() archive, or any combination of those.
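Of the three formats, the state_dict is the most portable: it stores only tensors, so loading it does not depend on the original Python class being importable. A minimal sketch of converting a legacy full-model file to a state_dict and loading it back; the function names are hypothetical and this is not the repository's saver.

```python
import torch


def full_model_to_state_dict(src_path, dst_path):
    """Re-save a pickled full model object (legacy format) as a plain state_dict."""
    # weights_only=False is required to unpickle a whole nn.Module; trusted files only
    model = torch.load(src_path, map_location="cpu", weights_only=False)
    state_dict = model.state_dict()
    torch.save(state_dict, dst_path)
    return state_dict


def load_state_dict_into(model, path):
    """Load a tensor-only state_dict into an already constructed model."""
    # weights_only=True is enough for a state_dict and avoids arbitrary unpickling
    state_dict = torch.load(path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)  # strict by default: key sets must match
    return model.eval()
```

For OpenAI CLIP specifically, a state_dict can be turned back into a model with clip.model.build_model(state_dict), which is what clip.load does internally for non-JIT checkpoints.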

Hope it's useful to you! 👍
