Skip to content

Commit 7687cea

Browse files
CRZbulabula authored and FerdinandZhong committed
[torch.compile] Adding torch compile annotations to some models (vllm-project#9614)
Signed-off-by: qishuai <ferdinandzhong@gmail.com>
1 parent ccfddc0 commit 7687cea

File tree

6 files changed

+12
-0
lines changed

6 files changed

+12
-0
lines changed

vllm/model_executor/models/baichuan.py

+2
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@
 26  26   from transformers import PretrainedConfig
 27  27
 28  28   from vllm.attention import Attention, AttentionMetadata
     29 + from vllm.compilation.decorators import support_torch_compile
 29  30   from vllm.config import CacheConfig, LoRAConfig
 30  31   from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
 31  32                                 get_tensor_model_parallel_world_size)
@@ -250,6 +251,7 @@ def forward(
250 251           return hidden_states, residual
251 252
252 253
    254 + @support_torch_compile
253 255   class BaiChuanModel(nn.Module):
254 256
255 257       def __init__(self,

vllm/model_executor/models/bloom.py

+2
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@
 24  24   from transformers import BloomConfig
 25  25
 26  26   from vllm.attention import Attention, AttentionMetadata
     27 + from vllm.compilation.decorators import support_torch_compile
 27  28   from vllm.config import CacheConfig
 28  29   from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
 29  30                                 get_tensor_model_parallel_world_size)
@@ -218,6 +219,7 @@ def forward(
218 219           return output
219 220
220 221
    222 + @support_torch_compile
221 223   class BloomModel(nn.Module):
222 224
223 225       def __init__(

vllm/model_executor/models/commandr.py

+2
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@
 28  28   from transformers import CohereConfig
 29  29
 30  30   from vllm.attention import Attention, AttentionMetadata
     31 + from vllm.compilation.decorators import support_torch_compile
 31  32   from vllm.config import CacheConfig, LoRAConfig
 32  33   from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 33  34   from vllm.model_executor.layers.activation import SiluAndMul
@@ -250,6 +251,7 @@ def forward(
250 251           return hidden_states, residual
251 252
252 253
    254 + @support_torch_compile
253 255   class CohereModel(nn.Module):
254 256
255 257       def __init__(

vllm/model_executor/models/exaone.py

+2
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@
 29  29   from torch import nn
 30  30
 31  31   from vllm.attention import Attention, AttentionMetadata
     32 + from vllm.compilation.decorators import support_torch_compile
 32  33   from vllm.config import CacheConfig, LoRAConfig
 33  34   from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
 34  35                                 get_tensor_model_parallel_world_size)
@@ -311,6 +312,7 @@ def forward(
311 312           return hidden_states, residual
312 313
313 314
    315 + @support_torch_compile
314 316   class ExaoneModel(nn.Module):
315 317
316 318       def __init__(

vllm/model_executor/models/gemma.py

+2
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
 22  22   from transformers import GemmaConfig
 23  23
 24  24   from vllm.attention import Attention, AttentionMetadata
     25 + from vllm.compilation.decorators import support_torch_compile
 25  26   from vllm.config import CacheConfig, LoRAConfig
 26  27   from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 27  28   from vllm.logger import init_logger
@@ -239,6 +240,7 @@ def forward(
239 240           return hidden_states, residual
240 241
241 242
    243 + @support_torch_compile
242 244   class GemmaModel(nn.Module):
243 245
244 246       def __init__(

vllm/model_executor/models/gpt2.py

+2
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@
 24  24   from transformers import GPT2Config
 25  25
 26  26   from vllm.attention import Attention, AttentionMetadata
     27 + from vllm.compilation.decorators import support_torch_compile
 27  28   from vllm.config import CacheConfig
 28  29   from vllm.distributed.parallel_state import (
 29  30       get_pp_group, get_tensor_model_parallel_world_size)
@@ -182,6 +183,7 @@ def forward(
182 183           return hidden_states
183 184
184 185
    186 + @support_torch_compile
185 187   class GPT2Model(nn.Module):
186 188
187 189       def __init__(

0 commit comments

Comments
 (0)