diff --git a/export_state_dict_checkpoint.py b/export_state_dict_checkpoint.py index d0c9dbb7..c1d54f89 100644 --- a/export_state_dict_checkpoint.py +++ b/export_state_dict_checkpoint.py @@ -21,7 +21,12 @@ torch_dtype=torch.float16, ) -lora_model.eval() # merge weights +# merge weights +for layer in lora_model.base_model.model.model.layers: + layer.self_attn.q_proj.merge_weights = True + layer.self_attn.v_proj.merge_weights = True + +lora_model.train(False) lora_model_sd = lora_model.state_dict()