add Moslora #9331
Conversation
Thanks for your contribution!
Codecov Report: all modified and coverable lines are covered by tests ✅

@@           Coverage Diff            @@
##           develop    #9331   +/-  ##
===========================================
+ Coverage    52.91%   53.10%   +0.19%
===========================================
  Files          679      685       +6
  Lines       108433   108855     +422
===========================================
+ Hits         57378    57810     +432
+ Misses       51055    51045      -10
llm/tools/merge_lora_params.py
Outdated
@@ -86,16 +86,25 @@ def lora_process(name, lora_config, state_dict, device, lora_state_dict=None):
        return

    weight = state_dict.pop(name + ".weight")
    lora_use_mixer = (lora_state_dict is not None and name + ".lora_AB" in lora_state_dict) or (
Suggestion: read lora_use_mixer directly from lora_config rather than inferring it from the keys of the state dict.
resolved
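A minimal sketch of the change the reviewer asked for, reusing the signature from the hunk above; `...` elides the unchanged merge logic, and the exact shape of `lora_config` is an assumption:

```python
def lora_process(name, lora_config, state_dict, device, lora_state_dict=None):
    weight = state_dict.pop(name + ".weight")
    # Before: lora_use_mixer was inferred from lora_state_dict key names.
    # After: the config is the single source of truth for the flag.
    lora_use_mixer = lora_config.lora_use_mixer
    ...
```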
llm/tools/merge_lora_params.py
Outdated
            lora_AB = lora_AB.astype("float32")
            out = (weight + lora_A @ lora_AB @ lora_B * scaling).astype("bfloat16")
        else:
            out = (weight + lora_A @ lora_B * scaling).astype("bfloat16")
Change the astype here to the dtype from lora_config; the hard-coded dtype in the original code was wrong.
resolved
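A hedged sketch of the dtype fix, written with numpy so the snippet is self-contained (the real code operates on paddle tensors, and assuming the target dtype is carried on the config, e.g. `lora_config.dtype`):

```python
import numpy as np

def merge_weight(weight, lora_A, lora_B, scaling, dtype, lora_AB=None):
    # Accumulate the low-rank update in float32 for numerical safety.
    if lora_AB is not None:
        # MosLoRA path: an r x r mixer sits between the LoRA factors.
        delta = lora_A.astype("float32") @ lora_AB.astype("float32") @ lora_B.astype("float32")
    else:
        # Vanilla LoRA path.
        delta = lora_A.astype("float32") @ lora_B.astype("float32")
    # Cast back to the configured dtype instead of a hard-coded "bfloat16".
    return (weight.astype("float32") + delta * scaling).astype(dtype)
```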
@@ -209,6 +209,9 @@ class ModelArgument:
    rslora: bool = field(default=False, metadata={"help": "Whether to use RsLoRA"})
    lora_plus_scale: float = field(default=1.0, metadata={"help": "Lora B scale in LoRA+ technique"})
    pissa: bool = field(default=False, metadata={"help": "Whether to use Pissa: https://arxiv.org/pdf/2404.02948.pdf"})
    lora_use_mixer: bool = field(
        default=False, metadata={"help": "Whether to use MosLoRA: https://arxiv.org/pdf/2406.11909"}
Please update the documentation along with this change.
resolved
@@ -443,6 +443,7 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora):
                pissa=lora_config.pissa,
                bias_attr=False if module.bias is None else None,
                use_quick_lora=lora_config.use_quick_lora,
                lora_use_mixer=lora_config.lora_use_mixer,
In LoRAModel's __init__, add a guard along the lines of: if (tensor_parallel_degree > 1 or pipeline_parallel_degree > 1) and lora_config.lora_use_mixer: raise NotImplementedError("xxx")
resolved
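A sketch of the requested guard; the config stub and the location of the parallel-degree attributes are assumptions made only so the snippet is self-contained:

```python
from dataclasses import dataclass

@dataclass
class LoRAConfig:
    lora_use_mixer: bool = False
    tensor_parallel_degree: int = 1
    pipeline_parallel_degree: int = 1

class LoRAModel:
    def __init__(self, model, lora_config: LoRAConfig):
        # MosLoRA's mixer matrix is not sharded, so refuse to run under
        # tensor or pipeline parallelism for now.
        if (
            lora_config.tensor_parallel_degree > 1 or lora_config.pipeline_parallel_degree > 1
        ) and lora_config.lora_use_mixer:
            raise NotImplementedError("lora_use_mixer is not supported with tensor or pipeline parallelism.")
        self.model = model
        self.lora_config = lora_config
```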
LGTM
PR types
New features
PR changes
Add MosLoRA at peft/lora.
Description
Add the MosLoRA method (https://arxiv.org/pdf/2406.11909).
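For context, a hedged numpy illustration of the MosLoRA update this PR implements: a small r x r mixer AB is inserted between the usual LoRA factors, so the merged weight becomes W + A @ AB @ B * scaling instead of W + A @ B * scaling. All shapes and the initialization below are illustrative, not the repository's actual defaults:

```python
import numpy as np

d, k, r, scaling = 64, 64, 8, 2.0
rng = np.random.default_rng(0)
W = rng.normal(size=(d, k)).astype("float32")   # frozen base weight
A = rng.normal(size=(d, r)).astype("float32")   # LoRA down-projection
B = np.zeros((r, k), dtype="float32")           # LoRA up-projection, zero init
AB = np.eye(r, dtype="float32")                 # MosLoRA mixer, identity init

x = rng.normal(size=(1, d)).astype("float32")
y = x @ (W + A @ AB @ B * scaling)              # merged-weight forward pass
# With B zero-initialized, the adapter is a no-op at the start of training.
assert np.allclose(y, x @ W)
```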