Add phi3 conversation template support #788

Merged · 1 commit · Apr 24, 2024
1 change: 1 addition & 0 deletions docs/source/examples/DATASETS.md
@@ -145,6 +145,7 @@ Conversations should be formatted before feeding into the model. As of now, we've
| `chatml` | `<\|im_start\|>system`<br>`You are a chatbot developed by LMFlow team.<\|im_end\|>`<br>`<\|im_start\|>user`<br>`Who are you?<\|im_end\|>`<br>`<\|im_start\|>assistant`<br>`I am a chatbot developed by LMFlow team.<\|im_end\|>`<br>`<\|im_start\|>user`<br>`How old are you?<\|im_end\|>`<br>`<\|im_start\|>assistant`<br>`I don't age like humans do. I exist as a piece of software, so I don't have a concept of age in the traditional sense.<\|im_end\|>`<br> | [Link](./supported_conversation_template.md#chatml) |
| `llama3` | `<\|begin_of_text\|><\|start_header_id\|>system<\|end_header_id\|>`<br><br>`You are a chatbot developed by LMFlow team.<\|eot_id\|><\|start_header_id\|>user<\|end_header_id\|>`<br><br>`Who are you?<\|eot_id\|><\|start_header_id\|>assistant<\|end_header_id\|>`<br><br>`I am a chatbot developed by LMFlow team.<\|eot_id\|><\|start_header_id\|>user<\|end_header_id\|>`<br><br>`How old are you?<\|eot_id\|><\|start_header_id\|>assistant<\|end_header_id\|>`<br><br>`I don't age like humans do. I exist as a piece of software, so I don't have a concept of age in the traditional sense.<\|eot_id\|>` | [Link](./supported_conversation_template.md#llama-3) |
| `llama2` | `<s>[INST] <<SYS>>`<br>`You are a chatbot developed by LMFlow team.`<br>`<</SYS>>`<br><br>`Who are you? [/INST] I am a chatbot developed by LMFlow team.</s><s>[INST] How old are you? [/INST] I don't age like humans do. I exist as a piece of software, so I don't have a concept of age in the traditional sense.</s>` | [Link](./supported_conversation_template.md#llama-2) |
| `phi3` | `<s><\|system\|>`<br>`You are a chatbot developed by LMFlow team.<\|end\|>`<br>`<\|user\|>`<br>`Who are you?<\|end\|>`<br>`<\|assistant\|>`<br>`I am a chatbot developed by LMFlow team.<\|end\|>`<br>`<\|user\|>`<br>`How old are you?<\|end\|>`<br>`<\|assistant\|>`<br>`I don't age like humans do. I exist as a piece of software, so I don't have a concept of age in the traditional sense.<\|end\|>`<br>`<\|endoftext\|>` | [Link](./supported_conversation_template.md#phi-3) |
| `qwen2` | `<\|im_start\|>system`<br>`You are a chatbot developed by LMFlow team.<\|im_end\|>`<br>`<\|im_start\|>user`<br>`Who are you?<\|im_end\|>`<br>`<\|im_start\|>assistant`<br>`I am a chatbot developed by LMFlow team.<\|im_end\|>`<br>`<\|im_start\|>user`<br>`How old are you?<\|im_end\|>`<br>`<\|im_start\|>assistant`<br>`I don't age like humans do. I exist as a piece of software, so I don't have a concept of age in the traditional sense.<\|im_end\|>`<br> | [Link](./supported_conversation_template.md#qwen-2) |

Pass the template name to the `--conversation_template` argument to apply the corresponding conversation template:
34 changes: 34 additions & 0 deletions docs/source/examples/supported_conversation_template.md
@@ -8,6 +8,7 @@
- [Mixtral 8x22B](#mixtral-8x22b)
- [Mixtral 8x7B](#mixtral-8x7b)
- [Phi-3](#phi-3)
- [Qwen-2](#qwen-2)
- [Yi](#yi)


@@ -174,6 +175,39 @@ The conversation template for Mixtral 8x7B is slightly different from the template


## Phi-3
**With a system message**
```
<s><|system|>\n{{system_message}}<|end|>\n<|user|>\n{{user_message_0}}<|end|>\n<|endoftext|>
```

**Without a system message**
```
<s><|user|>\n{{user_message_0}}<|end|>\n<|endoftext|>
```

**A complete conversation**
```
<s><|system|>\n{{system_message}}<|end|>\n<|user|>\n{{user_message_0}}<|end|>\n<|assistant|>\n{{assistant_reply_0}}<|end|>\n<|endoftext|>
```

**Multiple rounds**
```
<s><|system|>\n{{system_message}}<|end|>\n<|user|>\n{{user_message_0}}<|end|>\n<|assistant|>\n{{assistant_reply_0}}<|end|>\n<|user|>\n{{user_message_1}}<|end|>\n<|assistant|>\n{{assistant_reply_1}}<|end|>\n<|endoftext|>
```

**jinja template**
[[Reference]](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/3a811845d89f3c1b3f41b341d0f9f05104769f35/tokenizer_config.json#L338)
```
{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}
```

**Filled Example**
```
<s><|system|>\nYou are a chatbot developed by LMFlow team.<|end|>\n<|user|>\nWho are you?<|end|>\n<|assistant|>\nI am a chatbot developed by LMFlow team.<|end|>\n<|user|>\nHow old are you?<|end|>\n<|assistant|>\nI don't age like humans do. I exist as a piece of software, so I don't have a concept of age in the traditional sense.<|end|>\n<|endoftext|>
```
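
To sanity-check the format, the jinja template above can be applied directly with `transformers`. A minimal sketch, assuming the `microsoft/Phi-3-mini-4k-instruct` tokenizer is available and that the checkpoint's bundled chat template matches the jinja shown above:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct", trust_remote_code=True
)

messages = [
    {"role": "system", "content": "You are a chatbot developed by LMFlow team."},
    {"role": "user", "content": "Who are you?"},
    {"role": "assistant", "content": "I am a chatbot developed by LMFlow team."},
]

# With add_generation_prompt=False, the template appends eos_token
# (<|endoftext|>) instead of an open <|assistant|> header, so the output
# should match the "complete conversation" form above.
print(tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False))
```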


## Qwen-2
3 changes: 3 additions & 0 deletions src/lmflow/models/hf_decoder_model.py
@@ -65,6 +65,7 @@
ChatMLConversationTemplate,
Llama2ConversationTemplate,
Llama3ConversationTemplate,
Phi3ConversationTemplate,
Qwen2ConversationTemplate,
EmptyConversationTemplateWithoutSpecialTokens
)
@@ -476,6 +477,8 @@ def tokenize(self, dataset, add_special_tokens=True, *args, **kwargs):
conversation_template = ChatMLConversationTemplate()
elif data_args.conversation_template == 'qwen2':
conversation_template = Qwen2ConversationTemplate()
elif data_args.conversation_template == 'phi3':
conversation_template = Phi3ConversationTemplate()
elif data_args.conversation_template == 'empty':
conversation_template = EmptyConversationTemplate()
elif data_args.conversation_template == 'empty_no_special_tokens':
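For reference, template selection here is a plain `elif` chain keyed on the `--conversation_template` value. A registry-style equivalent is sketched below; this is a hypothetical refactor for illustration, not part of this PR:

```python
from lmflow.utils.conversation_template import (
    ChatMLConversationTemplate,
    Phi3ConversationTemplate,
    Qwen2ConversationTemplate,
)

# Hypothetical registry: one entry per supported --conversation_template value
# (extend with the remaining templates as needed).
CONVERSATION_TEMPLATES = {
    "chatml": ChatMLConversationTemplate,
    "qwen2": Qwen2ConversationTemplate,
    "phi3": Phi3ConversationTemplate,  # the entry this PR adds
}

def get_conversation_template(name: str):
    if name not in CONVERSATION_TEMPLATES:
        raise ValueError(f"Unsupported conversation template: {name}")
    return CONVERSATION_TEMPLATES[name]()
```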
54 changes: 53 additions & 1 deletion src/lmflow/utils/conversation_template.py
@@ -18,6 +18,7 @@ class ConversationTemplate:
tools_formatter: Optional[Formatter] = None
separator: Optional[TemplateComponent] = None
special_starter: Optional[TemplateComponent] = None
special_stopper: Optional[TemplateComponent] = None

def __post_init__(self):
if self.separator:
@@ -91,6 +92,8 @@ def encode_conversation(
# llama-3: <|begin_of_text|> only at the beginning of a session
encoded_pairs = self.add_special_starter(encoded_pairs, tokenizer)

if self.special_stopper:
encoded_pairs = self.add_special_stopper(encoded_pairs, tokenizer)

return encoded_pairs

@@ -199,7 +202,12 @@ def add_special_starter(
if self.special_starter.type == 'string':
special_starter_ids = tokenizer.encode(self.special_starter.content, add_special_tokens=False)
elif self.special_starter.type == 'token':
special_starter_ids = self._ensure_id_list(tokenizer.convert_tokens_to_ids(self.special_starter.content))
if self.special_starter.content == 'bos_token':
special_starter_ids = [tokenizer.bos_token_id]
elif self.special_starter.content == 'eos_token':
special_starter_ids = [tokenizer.eos_token_id]
else:
special_starter_ids = self._ensure_id_list(tokenizer.convert_tokens_to_ids(self.special_starter.content))
elif self.special_starter.type == 'token_id':
special_starter_ids = self._ensure_id_list(self.special_starter.content)
else:
@@ -209,6 +217,29 @@

return encoded_pairs

def add_special_stopper(
self,
encoded_pairs: Sequence[Tuple[List[int], List[int]]],
tokenizer: PreTrainedTokenizer
) -> Sequence[Tuple[List[int], List[int]]]:
if self.special_stopper.type == 'string':
special_stopper_ids = tokenizer.encode(self.special_stopper.content, add_special_tokens=False)
elif self.special_stopper.type == 'token':
if self.special_stopper.content == 'bos_token':
special_stopper_ids = [tokenizer.bos_token_id]
elif self.special_stopper.content == 'eos_token':
special_stopper_ids = [tokenizer.eos_token_id]
else:
special_stopper_ids = self._ensure_id_list(tokenizer.convert_tokens_to_ids(self.special_stopper.content))
elif self.special_stopper.type == 'token_id':
special_stopper_ids = self._ensure_id_list(self.special_stopper.content)
else:
raise ValueError(f"Component type {self.special_stopper.type} cannot be used as a special stopper.")

encoded_pairs[-1] = (encoded_pairs[-1][0], encoded_pairs[-1][1] + special_stopper_ids)

return encoded_pairs

def _ensure_id_list(self, obj: Union[int, List[int]]) -> List[int]:
'''Make sure the object is a list of integers. Useful for handling token ids.
'''
@@ -345,6 +376,27 @@ def _encode(
return res_all


@dataclass
class Phi3ConversationTemplate(ConversationTemplate):
user_formatter: Formatter = StringFormatter(
template=[
TemplateComponent(type='string', content='<|user|>\n{{content}}<|end|>\n')
]
)
assistant_formatter: Formatter = StringFormatter(
template=[
TemplateComponent(type='string', content='<|assistant|>\n{{content}}<|end|>\n')
]
)
system_formatter: Formatter = StringFormatter(
template=[
TemplateComponent(type='string', content='<|system|>\n{{content}}<|end|>\n')
]
)
special_starter: TemplateComponent = TemplateComponent(type='token', content='bos_token')
special_stopper: TemplateComponent = TemplateComponent(type='token', content='eos_token')


@dataclass
class Qwen2ConversationTemplate(ChatMLConversationTemplate):
separator: TemplateComponent = TemplateComponent(type='string', content='\n')
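
Together, `special_starter` and `special_stopper` bracket the whole session rather than each turn: `bos_token` is prepended once to the first prompt, and `eos_token` is appended once to the last reply. A minimal standalone sketch of that mechanic, independent of LMFlow's exact signatures (the token ids are Phi-3's, borrowed from the test fixtures below):

```python
from typing import List, Tuple

def bracket_session(
    encoded_pairs: List[Tuple[List[int], List[int]]],
    bos_token_id: int,
    eos_token_id: int,
) -> List[Tuple[List[int], List[int]]]:
    """Prepend bos to the first prompt; append eos to the last reply."""
    first_prompt, first_reply = encoded_pairs[0]
    encoded_pairs[0] = ([bos_token_id] + first_prompt, first_reply)
    last_prompt, last_reply = encoded_pairs[-1]
    encoded_pairs[-1] = (last_prompt, last_reply + [eos_token_id])
    return encoded_pairs

# Phi-3: 1 = <s> (bos), 32000 = <|endoftext|> (eos).
pairs = [([32010, 15043, 32007], [32001, 6324, 29991, 32007])]
print(bracket_session(pairs, bos_token_id=1, eos_token_id=32000))
# -> [([1, 32010, 15043, 32007], [32001, 6324, 29991, 32007, 32000])]
```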
40 changes: 39 additions & 1 deletion tests/models/test_hf_decoder_model.py
@@ -31,6 +31,7 @@
EmptyConversationTemplate,
Llama2ConversationTemplate,
Llama3ConversationTemplate,
Phi3ConversationTemplate,
EmptyConversationTemplateWithoutSpecialTokens,
)

@@ -87,6 +88,13 @@
)
]

CONVERSATION_SINGLETURN_PHI3_IDS = [
(
[1, 32006, 10876, 3888, 32007, 32010, 15043, 32007],
[32001, 6324, 29991, 32007, 32000]
)
]

CONVERSATION_MULTITURN = {
"system": "sysinfo",
"messages": [
@@ -153,6 +161,17 @@
)
]

CONVERSATION_MULTITURN_PHI3_IDS = [
(
[1, 32006, 10876, 3888, 32007, 32010, 15043, 32007],
[32001, 6324, 29991, 32007]
),
(
[32010, 1128, 526, 366, 29973, 32007],
[32001, 306, 29915, 29885, 1781, 29892, 3969, 29991, 32007, 32000]
)
]
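
These expected ids can be spot-checked by decoding them back through the Phi-3 tokenizer. A quick sketch, assuming the tokenizer is available (exact whitespace in the decoded strings may differ):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct", trust_remote_code=True
)

# Uses the CONVERSATION_MULTITURN_PHI3_IDS fixture defined above.
for prompt_ids, reply_ids in CONVERSATION_MULTITURN_PHI3_IDS:
    # Special tokens such as <|user|> and <|end|> stay visible by default.
    print(tok.decode(prompt_ids))
    print(tok.decode(reply_ids))
```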

test_encode_input = "Question: Which of the following is not true for myelinated nerve fibers: (A) Impulse through myelinated fibers is slower than non-myelinated fibers (B) Membrane currents are generated at nodes of Ranvier (C) Saltatory conduction of impulses is seen (D) Local anesthesia is effective only when the nerve is not covered by myelin sheath."
test_encode_output = [24361, 25, 9022, 286, 262, 1708, 318, 407, 2081, 329, 616, 417, 3898, 16384, 26742, 25, 357, 32, 8, 9855, 9615, 832, 616, 417, 3898, 26742, 318, 13611, 621, 1729, 12, 1820, 417, 3898, 26742, 357, 33, 8, 4942, 1671, 1531, 28629, 389, 7560, 379, 13760, 286, 23075, 49663, 357, 34, 8, 13754, 2870, 369, 11124, 286, 37505, 318, 1775, 357, 35, 8, 10714, 49592, 318, 4050, 691, 618, 262, 16384, 318, 407, 5017, 416, 616, 27176, 673, 776, 13]
test_decode_input = [24361, 25, 9022, 286, 262, 1708, 318, 407, 2081, 329, 616, 417, 3898, 16384, 26742, 25, 357, 32, 8, 9855, 9615, 832, 616, 417, 3898, 26742, 318, 13611, 621, 1729, 12, 1820, 417, 3898, 26742, 357, 33, 8, 4942, 1671, 1531, 28629, 389, 7560, 379, 13760, 286, 23075, 49663, 357, 34, 8, 13754, 2870, 369, 11124, 286, 37505, 318, 1775, 357, 35, 8, 10714, 49592, 318, 4050, 691, 618, 262, 16384, 318, 407, 5017, 416, 616, 27176, 673, 776, 13]
@@ -199,7 +218,10 @@ def _test_tokenize(

self.assertEqual(dataset.to_dict(), groundtruth_dataset)

model_args = ModelArguments(model_name_or_path=model_name)
model_args = ModelArguments(
model_name_or_path=model_name,
trust_remote_code=kwargs.get("trust_remote_code", False)
)
model = HFDecoderModel(model_args)

tokenized_dataset = model.tokenize(dataset, **kwargs)
@@ -326,6 +348,14 @@ def test_tokenize_conversation(self):
groundtruth_tokenized_dataset=make_gt_from_conversation_ids_batch([CONVERSATION_SINGLETURN_LLAMA3_IDS]),
conversation_template=Llama3ConversationTemplate()
)

self._test_tokenize(
model_name='microsoft/Phi-3-mini-4k-instruct',
groundtruth_dataset={"type": "conversation", "instances": [CONVERSATION_SINGLETURN]},
groundtruth_tokenized_dataset=make_gt_from_conversation_ids_batch([CONVERSATION_SINGLETURN_PHI3_IDS]),
conversation_template=Phi3ConversationTemplate(),
trust_remote_code=True
)


def test_tokenize_conversation_multiple(self):
@@ -392,6 +422,14 @@ def test_tokenize_conversation_multiple(self):
conversation_template=Llama3ConversationTemplate()
)

self._test_tokenize(
model_name='microsoft/Phi-3-mini-4k-instruct',
groundtruth_dataset={"type": "conversation", "instances": [CONVERSATION_MULTITURN]},
groundtruth_tokenized_dataset=make_gt_from_conversation_ids_batch([CONVERSATION_MULTITURN_PHI3_IDS]),
conversation_template=Phi3ConversationTemplate(),
trust_remote_code=True
)


def test_encode(self):
model_name = 'gpt2'