Skip to content

增加grpo多次工具调用训练 #3503

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ repos:
exclude: thirdparty/|tests/run.py
- id: requirements-txt-fixer
exclude: thirdparty/|tests/run.py
- id: double-quote-string-fixer
exclude: thirdparty/|tests/run.py
- id: check-merge-conflict
exclude: thirdparty/|tests/run.py
- id: fix-encoding-pragma
Expand Down
2 changes: 0 additions & 2 deletions .pre-commit-config_local.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ repos:
exclude: thirdparty/|tests/run.py
- id: end-of-file-fixer
exclude: thirdparty/
- id: requirements-txt-fixer
exclude: thirdparty/|tests/run.py
- id: double-quote-string-fixer
exclude: thirdparty/|tests/run.py
- id: check-merge-conflict
Expand Down
2 changes: 1 addition & 1 deletion swift/plugin/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,5 +383,5 @@ def __call__(self, completions, **kwargs) -> List[float]:
'format': Format,
'react_format': ReActFormat,
'cosine': CosineReward,
'repetition': RepetitionPenalty,
'repetition': RepetitionPenalty
}
10 changes: 10 additions & 0 deletions swift/plugin/tool_call.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from typing import Tuple, Any, Optional


class TOOL_CALL:

def __call__(self, completion: str) -> Tuple[Any, bool, Optional[float]]:
raise NotImplementedError


tools = {}
5 changes: 4 additions & 1 deletion swift/trainers/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
from dataclasses import dataclass
from functools import wraps
from typing import Any, Dict, Literal, Optional, Union
from typing import Any, Dict, Literal, Optional, Union, Callable

import torch
import torch.utils.checkpoint
Expand Down Expand Up @@ -104,6 +104,9 @@ class GRPOArgumentsMixin:
offload_optimizer: bool = False
offload_model: bool = False
gc_collect_after_offload: bool = False
is_reward_tool_call: bool = True #是否额外单独计算每个tool call的format得分
tool_call_weight: float = 1.0
tool_call: str = None


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion swift/trainers/rlhf_arguments.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import List, Optional
from typing import List, Optional, Callable

from trl import CPOConfig as HfCPOConfig
from trl import DPOConfig as HfDPOConfig
Expand Down
454 changes: 389 additions & 65 deletions swift/trainers/rlhf_trainer/grpo_trainer.py

Large diffs are not rendered by default.