
feature(whl): add rlhf pipeline. #748


Open · wants to merge 18 commits into main
Conversation

@kxzxvbk (Contributor) commented Nov 6, 2023

Description

Related Issue

TODO

Check List

  • merge the latest version of the source branch/repo and resolve all conflicts
  • pass the style check
  • pass all tests

@PaParaZz1 added the enhancement (New feature or request) and algo (Add new algorithm or improve old one) labels on Nov 6, 2023
@@ -18,6 +19,7 @@
from .model import PPOFModel
from .config import get_instance_config, get_instance_env, get_hybrid_shape
from ding.bonus.common import TrainingReturn, EvalReturn
from ..framework.middleware.collector import ChatCollector
Member: merge it into ding.framework

"""
Overview:
The class of the collector running by steps, including model inference and transition \
process. Use the `__call__` method to execute the whole collection process.
Member: why is there extra indentation here?


def top_p_logits(logits, topp=0.9, filter_value=0, min_topk=1):
"""
Filter a distribution of logits using nucleus (top-p) filtering
Member: polish the comments and add a unit test

if topp > 0:
logits_sorted, inds = torch.sort(logits, dim=-1, descending=True)
mask = (logits_sorted.cumsum(dim=-1) - logits_sorted) >= topp
mask[:, :min_topk] = False
Member: use `mask[..., :min_topk]` here
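A minimal sketch of how the whole filter could look with the suggested `...` indexing, plus a small self-check in the spirit of the "add a unit test" comment. This is an illustrative rewrite based on the snippet above, not the final implementation in this PR; note that the cumsum-vs-`topp` comparison implies the input is already a probability distribution.

```python
import torch


def top_p_logits(logits: torch.Tensor, topp: float = 0.9, filter_value: float = 0., min_topk: int = 1) -> torch.Tensor:
    """
    Nucleus (top-p) filtering: keep the smallest set of highest-probability tokens whose
    cumulative mass reaches ``topp`` and overwrite every other entry with ``filter_value``.
    """
    if topp <= 0:
        return logits
    logits_sorted, inds = torch.sort(logits, dim=-1, descending=True)
    # Mask tokens whose preceding cumulative mass already reaches topp.
    mask = (logits_sorted.cumsum(dim=-1) - logits_sorted) >= topp
    # Always keep at least ``min_topk`` tokens; ``...`` keeps this valid for any batch shape.
    mask[..., :min_topk] = False
    filtered_sorted = logits_sorted.masked_fill(mask, filter_value)
    # Undo the sort so the filtered values return to their original token positions.
    return torch.empty_like(logits).scatter_(-1, inds, filtered_sorted)


def test_top_p_logits():
    probs = torch.tensor([[0.5, 0.3, 0.1, 0.1]])
    out = top_p_logits(probs, topp=0.8)
    # 0.5 + 0.3 already covers the 0.8 nucleus, so the two tail tokens are zeroed.
    assert torch.allclose(out, torch.tensor([[0.5, 0.3, 0.0, 0.0]]))
```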

@@ -1,4 +1,7 @@
from typing import Union, Dict, Optional

Member: move these modifications into a single new file: lm_vac.py


def __init__(self, config, opt, tokenizer):
super().__init__(config)
self.opt = opt
Member: why is `opt` defined here?

else:
logits = self.reward_head(output.last_hidden_state).squeeze(-1)

return (logits, )
Member: why return a tuple here?

self._init_flag = False

def reset(self):
self.last_batch = next(self.generator)
Member: Do you need to restart the generator here?
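For reference, a hypothetical sketch of what re-creating the generator inside `reset()` could look like, so repeated resets restart from the top of the data rather than resuming the old iteration state. The `_dataloader` attribute and the standalone class are illustrative assumptions, not the collector in this PR.

```python
from typing import Any, Iterator

import torch
from torch.utils.data import DataLoader, TensorDataset


class CollectorResetSketch:
    """Minimal stand-in showing only the reset logic under discussion."""

    def __init__(self, dataloader: DataLoader) -> None:
        self._dataloader = dataloader
        self.generator: Iterator[Any] = iter(self._dataloader)
        self.last_batch: Any = None

    def reset(self) -> None:
        # Rebuild the generator on every reset so the collector does not keep
        # consuming the iterator left over from the previous collection round.
        self.generator = iter(self._dataloader)
        self.last_batch = next(self.generator)


# Repeated resets are safe: each call re-primes the generator from the start.
loader = DataLoader(TensorDataset(torch.arange(4)), batch_size=2)
collector = CollectorResetSketch(loader)
collector.reset()
collector.reset()
```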


class LlamaRewardModel(LlamaForCausalLM):

def __init__(self, config, opt, tokenizer):
Member: Should we move the creation of the tokenizer inside the constructor of the RM?
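A hypothetical sketch of that suggestion: construct the tokenizer inside the reward model from a pretrained path and return the per-token rewards as a plain tensor (which would also resolve the tuple-return question above). The `AutoTokenizer.from_pretrained` call and the `opt.model_path` field are assumptions for illustration, not the PR's actual interface.

```python
from typing import Optional

import torch
import torch.nn as nn
from transformers import AutoTokenizer, LlamaForCausalLM


class LlamaRewardModelSketch(LlamaForCausalLM):
    """Illustrative variant that builds its own tokenizer instead of taking one as an argument."""

    def __init__(self, config, opt):
        super().__init__(config)
        # Create the tokenizer here, as the review suggests; ``opt.model_path`` is an
        # assumed config field pointing at the pretrained checkpoint.
        self.tokenizer = AutoTokenizer.from_pretrained(opt.model_path)
        self.reward_head = nn.Linear(config.hidden_size, 1, bias=False)

    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # ``self.model`` is the Llama backbone provided by ``LlamaForCausalLM``.
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # Return a plain (batch, seq_len) tensor of per-token rewards rather than a one-element tuple.
        return self.reward_head(output.last_hidden_state).squeeze(-1)
```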

@@ -0,0 +1,50 @@
from easydict import EasyDict
Member: move it to dizoo/chat/entry

codecov bot commented Jan 3, 2024

Codecov Report

Attention: 252 lines in your changes are missing coverage. Please review.

Comparison is base (d7a61c2) 76.78% compared to head (f3a8245) 76.83%.

Files                                        Patch %   Lines
ding/model/template/lm_vac.py                 20.00%   92 Missing ⚠️
ding/policy/ppof.py                            5.74%   82 Missing ⚠️
ding/framework/middleware/collector.py        15.62%   27 Missing ⚠️
ding/rl_utils/gae.py                          11.11%   16 Missing ⚠️
ding/reward_model/language_reward_model.py    31.57%   13 Missing ⚠️
ding/bonus/ppof.py                             0.00%   12 Missing ⚠️
ding/bonus/config.py                           0.00%   10 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #748      +/-   ##
==========================================
+ Coverage   76.78%   76.83%   +0.04%     
==========================================
  Files         671      674       +3     
  Lines       53196    53935     +739     
==========================================
+ Hits        40847    41440     +593     
- Misses      12349    12495     +146     
Flag        Coverage Δ
unittests   76.83% <20.50%> (+0.04%) ⬆️


Labels: algo (Add new algorithm or improve old one), enhancement (New feature or request)
Projects: None yet
2 participants