Skip to content

Commit

Permalink
allow bootstrap grpo
Browse files Browse the repository at this point in the history
  • Loading branch information
qgallouedec committed Feb 11, 2025
1 parent b9df810 commit 3680d55
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 4 deletions.
22 changes: 19 additions & 3 deletions trl/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,21 @@ def apply_chat_template(

# Apply the chat template to the prompt, adding the generation prompt
if "prompt" in example:
last_role = example["prompt"][-1]["role"]
if last_role == "user":
add_generation_prompt = True
continue_final_message = False
elif last_role == "assistant":
add_generation_prompt = False
continue_final_message = True
else:
raise ValueError(f"Invalid role in the last message: {last_role}")
prompt = tokenizer.apply_chat_template(
example["prompt"], tools=tools, tokenize=False, add_generation_prompt=True
example["prompt"],
tools=tools,
continue_final_message=continue_final_message,
tokenize=False,
add_generation_prompt=add_generation_prompt,
)

# Apply the chat template to the entire prompt + completion
Expand Down Expand Up @@ -180,10 +193,13 @@ def maybe_apply_chat_template(
Returns:
`dict[str, str]`: The formatted example with the chat template applied.
Note:
This function does not alter the keys, except for Language modeling dataset, where `"messages"` is replaced by
Notes:
- This function does not alter the keys, except for Language modeling dataset, where `"messages"` is replaced by
`"text"`.
- In case of prompt-only data, if the last role is `"user"`, the generation prompt is added to the prompt. Else,
if the last role is `"assistant"`, the final message is continued.
Example:
```python
Expand Down
5 changes: 4 additions & 1 deletion trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,10 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
# Decode the generated completions
completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
if is_conversational(inputs[0]):
completions = [[{"role": "assistant", "content": completion}] for completion in completions_text]
completions = []
for prompt, completion in zip(prompts, completions_text):
bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
completions.append([{"role": "assistant", "content": bootstrap + completion}])
else:
completions = completions_text

Expand Down

0 comments on commit 3680d55

Please sign in to comment.