feature(whl): add rlhf pipeline. #748
base: main
Conversation
ding/bonus/ppof.py
Outdated
@@ -18,6 +19,7 @@
from .model import PPOFModel
from .config import get_instance_config, get_instance_env, get_hybrid_shape
from ding.bonus.common import TrainingReturn, EvalReturn
from ..framework.middleware.collector import ChatCollector
merge it into ding.framework
"""
Overview:
The class of the collector running by steps, including model inference and transition \
process. Use the `__call__` method to execute the whole collection process.
why indent here
ding/model/common/utils.py
Outdated
def top_p_logits(logits, topp=0.9, filter_value=0, min_topk=1):
    """
    Filter a distribution of logits using nucleus (top-p) filtering
polish comments and add unittest
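A minimal sketch of what such a unit test could exercise, written in plain Python over a list of probabilities rather than a torch tensor. The helper `top_p_filter` is hypothetical and is not the PR's `top_p_logits`; it assumes the inputs are already probabilities and only mirrors the `topp`, `filter_value`, and `min_topk` parameters from the signature above.

```python
from typing import List


def top_p_filter(probs: List[float], topp: float = 0.9,
                 filter_value: float = 0.0, min_topk: int = 1) -> List[float]:
    """Keep the highest-probability entries until their mass reaches `topp`.

    Hypothetical helper for illustration; assumes `probs` are probabilities.
    """
    if topp <= 0:
        # Mirrors the `if topp > 0:` guard in the diff: no filtering is applied.
        return list(probs)
    # Indices ordered by descending probability.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    filtered = list(probs)
    mass_before = 0.0  # cumulative mass of the entries ranked above the current one
    for rank, idx in enumerate(order):
        # Same idea as `(cumsum - value) >= topp`: drop an entry once the mass
        # accumulated before it already reaches the threshold, but always keep
        # the first `min_topk` entries.
        if rank >= min_topk and mass_before >= topp:
            filtered[idx] = filter_value
        mass_before += probs[idx]
    return filtered
```

A test along these lines would probably want to cover three cases: the nucleus is cut at the expected rank, `min_topk` keeps entries that the threshold alone would drop, and `topp <= 0` leaves the input untouched.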
ding/model/common/utils.py
Outdated
if topp > 0:
    logits_sorted, inds = torch.sort(logits, dim=-1, descending=True)
    mask = (logits_sorted.cumsum(dim=-1) - logits_sorted) >= topp
    mask[:, :min_topk] = False
..., :min_topk
ding/model/template/vac.py
Outdated
@@ -1,4 +1,7 @@
from typing import Union, Dict, Optional
move these modifications to a new single file: lm_vac.py
    def __init__(self, config, opt, tokenizer):
        super().__init__(config)
        self.opt = opt
why define opt here
else:
    logits = self.reward_head(output.last_hidden_state).squeeze(-1)

return (logits, )
why return a tuple here
        self._init_flag = False

    def reset(self):
        self.last_batch = next(self.generator)
Do you need to restart the generator here?
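To make the question concrete, here is a small sketch of the two behaviours `reset` could have. `BatchSource`, `batches`, and the `restart` flag are hypothetical stand-ins, not names from the PR; the sketch assumes the object can hold a factory that builds a fresh generator.

```python
from typing import Callable, Iterator, List, Optional


class BatchSource:
    """Hypothetical stand-in for the object that owns `self.generator`."""

    def __init__(self, make_generator: Callable[[], Iterator[List[int]]]) -> None:
        # Keep the factory, not just one generator, so reset() can start over.
        self._make_generator = make_generator
        self.generator = make_generator()
        self.last_batch: Optional[List[int]] = None

    def reset(self, restart: bool = False) -> None:
        if restart:
            # Re-create the generator so iteration begins from the first batch.
            self.generator = self._make_generator()
        # Without a restart this continues where the generator left off and
        # raises StopIteration once the generator is exhausted.
        self.last_batch = next(self.generator)


def batches() -> Iterator[List[int]]:
    """Toy data source with two batches."""
    yield [1, 2]
    yield [3, 4]
```

With `reset()` as written in the diff, each call advances to the next batch and eventually fails on an exhausted generator; restarting would instead require access to whatever creates the generator.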
class LlamaRewardModel(LlamaForCausalLM):

    def __init__(self, config, opt, tokenizer):
Should we move the creation of the tokenizer inside the constructor of the RM?
@@ -0,0 +1,50 @@
from easydict import EasyDict
move it to dizoo/chat/entry
Codecov Report
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #748      +/-   ##
==========================================
+ Coverage   76.78%   76.83%   +0.04%
==========================================
  Files         671      674       +3
  Lines       53196    53935     +739
==========================================
+ Hits        40847    41440     +593
- Misses      12349    12495     +146
Description
Related Issue
TODO
Check List