Add GRPO multi-turn tool-call training #3503
base: main
Conversation
Could the dataset be uploaded to ModelScope and then referenced by its model_id? Also, could the files currently at the top level be moved into their own folder under examples/train/grpo, with a best-practice document written to introduce them?
Please run lint over the code; it will tidy things up.
OK.
The dataset has been uploaded to ModelScope, and a new best-practice document on multi-turn tool calling has been added.
examples/train/rft/rft.py (Outdated)

```diff
@@ -22,7 +22,8 @@ def do_sample(model: str, model_type: str, dataset: List[str], iter: int):
     for device in range(device_count):
         sample_cmd = (f'{conda_prefix} CUDA_VISIBLE_DEVICES={device} swift sample '
                       f'--model {model} --model_type {model_type} '
-                      f'--dataset {" ".join(dataset)} '
+                      f'--dataset {'
```
There is a syntax problem here, please check.
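For context on the flagged syntax problem: reusing the same quote character inside an f-string replacement field is a SyntaxError on Python versions before 3.12 (PEP 701 only relaxed this in 3.12). A minimal sketch of the safe pattern, using double quotes inside a single-quoted f-string as the pre-change line did (`ds1`/`ds2` are placeholder dataset names):

```python
# Placeholder dataset list; in rft.py this comes from the function argument.
dataset = ['ds1', 'ds2']

# Double quotes inside a single-quoted f-string work on all supported
# Python versions; f'--dataset {' '.join(dataset)} ' would be a
# SyntaxError before Python 3.12.
sample_cmd = f'--dataset {" ".join(dataset)} '
print(sample_cmd)  # --dataset ds1 ds2
```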
scripts/benchmark/exp_utils.py (Outdated)

```diff
@@ -122,7 +122,7 @@ def run(self, exp: Experiment):
         exp.runtime = runtime
         envs = deepcopy(runtime.get('env', {}))
         envs.update(os.environ)
-        logger.info(f'Running cmd: {runtime["running_cmd"]}, env: {runtime.get("env", {})}')
+        logger.info(f'Running cmd: {runtime['running_cmd']}, env: {runtime.get('env', {})}')
```
Same check here, +1.
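An alternative that sidesteps the quoting problem entirely is to pull the expressions out into local variables before formatting. A minimal sketch with a stand-in `runtime` dict (the real one comes from the Experiment):

```python
# Stand-in for the runtime dict built in exp_utils.py.
runtime = {'running_cmd': 'swift sample', 'env': {'CUDA_VISIBLE_DEVICES': '0'}}

# Binding the values first keeps the f-string free of nested quotes,
# so it parses on every Python version.
running_cmd = runtime['running_cmd']
env = runtime.get('env', {})
msg = f'Running cmd: {running_cmd}, env: {env}'
```

This also tends to read better in log statements than deeply nested subscript expressions.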
Looking forward to GRPO support for tool calling.
test_grpo_tool.py: the training test script.
math_tool.py: the tool used for testing; it defines a new operation. Its interface mainly decides whether to continue, assigns the format reward, and accepts the online result as input.
The related datasets are also placed in this directory, which is somewhat messy; the main changes are in grpo_trainer.py.
New parameters needed in the GRPO args:
is_reward_tool_call: whether to accumulate the format reward for each tool_call. An upper bound should be set, otherwise the model may learn to call tools indefinitely without ever emitting the correct answer.
tool_call_weight: the weight of the tool_call format reward.
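The capping idea above can be sketched as follows. This is not the PR's actual implementation: the function name and the `max_tool_calls` cap parameter are hypothetical; only `tool_call_weight` and the accumulate-with-upper-bound behavior come from the description.

```python
def tool_call_format_reward(num_well_formed_calls: int,
                            tool_call_weight: float,
                            max_tool_calls: int = 3) -> float:
    """Accumulate a per-tool-call format reward, capped so the policy
    cannot farm reward by calling tools forever, then scale by
    tool_call_weight. `max_tool_calls` is a hypothetical cap parameter."""
    capped = min(num_well_formed_calls, max_tool_calls)
    return tool_call_weight * capped
```

With a cap of 3 and weight 0.5, five well-formed calls earn the same 1.5 as three, removing the incentive to keep calling tools instead of answering.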