
Commit cda8c40

janimo authored and Ubuntu committed
[Model] Support Cohere2ForCausalLM (Cohere R7B) (vllm-project#11203)
1 parent e1e22d8 commit cda8c40

File tree

5 files changed: +26 -4 lines changed

docs/source/models/supported_models.rst (+2, -2)
tests/models/registry.py (+2)
tests/models/test_initialization.py (+4)
vllm/model_executor/models/commandr.py (+17, -2)
vllm/model_executor/models/registry.py (+1)

docs/source/models/supported_models.rst (+2, -2)

@@ -118,9 +118,9 @@ Text Generation (``--task generate``)
     - :code:`THUDM/chatglm2-6b`, :code:`THUDM/chatglm3-6b`, etc.
     - ✅︎
     - ✅︎
-  * - :code:`CohereForCausalLM`
+  * - :code:`CohereForCausalLM`, :code:`Cohere2ForCausalLM`
     - Command-R
-    - :code:`CohereForAI/c4ai-command-r-v01`, etc.
+    - :code:`CohereForAI/c4ai-command-r-v01`, :code:`CohereForAI/c4ai-command-r7b-12-2024`, etc.
     - ✅︎
     - ✅︎
   * - :code:`DbrxForCausalLM`
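
For reference, the newly documented checkpoint loads through the usual offline-inference path. A minimal sketch, assuming the gated CohereForAI weights are accessible from the Hugging Face Hub; the prompt and sampling values are illustrative:

    from vllm import LLM, SamplingParams

    # Cohere2ForCausalLM is resolved from the checkpoint's config.json,
    # so the model name alone is enough.
    llm = LLM(model="CohereForAI/c4ai-command-r7b-12-2024")
    params = SamplingParams(temperature=0.7, max_tokens=64)
    outputs = llm.generate(["Explain sliding-window attention briefly."],
                           params)
    print(outputs[0].outputs[0].text)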

tests/models/registry.py (+2)

@@ -53,6 +53,8 @@ class _HfExamplesInfo:
     # ChatGLMModel supports multimodal
     "CohereForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r-v01",
                                          trust_remote_code=True),
+    "Cohere2ForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r7b-12-2024",  # noqa: E501
+                                          trust_remote_code=True),
     "DbrxForCausalLM": _HfExamplesInfo("databricks/dbrx-instruct"),
     "DeciLMForCausalLM": _HfExamplesInfo("Deci/DeciLM-7B-instruct",
                                          trust_remote_code=True),
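
The entry above is what test_can_initialize (next file) retrieves via get_hf_info. A sketch of the equivalent direct Hugging Face lookup, assuming transformers >= 4.48.0 and Hub access; the architectures value is what the checkpoint's config.json is expected to declare:

    from transformers import AutoConfig

    # trust_remote_code mirrors the _HfExamplesInfo entry above.
    cfg = AutoConfig.from_pretrained("CohereForAI/c4ai-command-r7b-12-2024",
                                     trust_remote_code=True)
    print(cfg.architectures)  # expected: ['Cohere2ForCausalLM']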

tests/models/test_initialization.py (+4)

@@ -1,6 +1,7 @@
 from unittest.mock import patch

 import pytest
+import transformers
 from transformers import PretrainedConfig

 from vllm import LLM
@@ -11,6 +12,9 @@
 @pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
 def test_can_initialize(model_arch):
     model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
+    if (model_arch == "Cohere2ForCausalLM"
+            and transformers.__version__ < "4.48.0"):
+        pytest.skip(reason="Model introduced in HF >= 4.48.0")
     if not model_info.is_available_online:
         pytest.skip("Model is not available online")

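
One caveat: comparing transformers.__version__ as a plain string is lexicographic, so for example "4.9.0" < "4.48.0" evaluates to False even though 4.9 predates 4.48, and the skip would not fire there. A version-aware variant of the guard, sketched with packaging (already a transformers dependency; not part of this commit):

    import pytest
    import transformers
    from packaging import version

    # Skip only on releases that genuinely predate Cohere2 support.
    if version.parse(transformers.__version__) < version.parse("4.48.0"):
        pytest.skip(reason="Model introduced in HF >= 4.48.0")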

vllm/model_executor/models/commandr.py (+17, -2)

@@ -48,7 +48,7 @@
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (extract_layer_index, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)

@@ -171,12 +171,26 @@ def __init__(
             rope_scaling=self.rope_scaling,
             is_neox_style=False,
         )
+
+        sliding_window = getattr(config, "sliding_window", None)
+        # Model v2 has sliding windows, v1 does not
+        self.v1 = sliding_window is None
+
+        layer_idx = extract_layer_index(prefix)
+        layer_has_sliding_window = (
+            getattr(config, "sliding_window_pattern", False)
+            and (layer_idx + 1) % self.config.sliding_window_pattern != 0)
+
+        self.sliding_window = (sliding_window
+                               if layer_has_sliding_window else None)
+
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               self.scaling,
                               num_kv_heads=self.num_kv_heads,
                               cache_config=cache_config,
                               quant_config=quant_config,
+                              per_layer_sliding_window=self.sliding_window,
                               prefix=f"{prefix}.attn")
         if self.use_qk_norm:
             self.q_norm = LayerNorm(param_shape=(self.num_heads,
@@ -206,7 +220,8 @@ def forward(
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         if self.use_qk_norm:
             q, k = self._apply_qk_norm(q, k)
-        q, k = self.rotary_emb(positions, q, k)
+        if self.v1 or self.sliding_window:
+            q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         output, _ = self.o_proj(attn_output)
         return output
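
The net effect is an interleaved attention schedule: in the v2 model every sliding_window_pattern-th layer attends globally (and skips rotary embeddings, per the forward() change), while all other layers use a bounded window. A standalone illustration, assuming the values published in the R7B config (sliding_window=4096, sliding_window_pattern=4):

    # Illustration only: classify layers the same way the diff does.
    sliding_window = 4096        # window size from the model config
    sliding_window_pattern = 4   # every 4th layer is global

    for layer_idx in range(8):
        has_window = (layer_idx + 1) % sliding_window_pattern != 0
        kind = f"sliding({sliding_window})" if has_window else "global"
        print(f"layer {layer_idx}: {kind}")
    # layers 0-2 sliding, 3 global, 4-6 sliding, 7 global, ...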

vllm/model_executor/models/registry.py (+1)

@@ -41,6 +41,7 @@
     "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
     # ChatGLMModel supports multimodal
     "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
+    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
     "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
     "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
     "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
