Commit 7687cea (parent: ccfddc0)

6 files changed in vllm/model_executor/models: 12 insertions(+), 0 deletions(-)
vllm/model_executor/models/baichuan.py

@@ -26,6 +26,7 @@
 from transformers import PretrainedConfig

 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
@@ -250,6 +251,7 @@ def forward(
         return hidden_states, residual


+@support_torch_compile
 class BaiChuanModel(nn.Module):

     def __init__(self,
vllm/model_executor/models/bloom.py

@@ -24,6 +24,7 @@
 from transformers import BloomConfig

 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
@@ -218,6 +219,7 @@ def forward(
         return output


+@support_torch_compile
 class BloomModel(nn.Module):

     def __init__(
vllm/model_executor/models/commandr.py

@@ -28,6 +28,7 @@
 from transformers import CohereConfig

 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -250,6 +251,7 @@ def forward(
         return hidden_states, residual


+@support_torch_compile
 class CohereModel(nn.Module):

     def __init__(
vllm/model_executor/models/exaone.py

@@ -29,6 +29,7 @@
 from torch import nn

 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
@@ -311,6 +312,7 @@ def forward(
         return hidden_states, residual


+@support_torch_compile
 class ExaoneModel(nn.Module):

     def __init__(
vllm/model_executor/models/gemma.py

@@ -22,6 +22,7 @@
 from transformers import GemmaConfig

 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
@@ -239,6 +240,7 @@ def forward(
         return hidden_states, residual


+@support_torch_compile
 class GemmaModel(nn.Module):

     def __init__(
vllm/model_executor/models/gpt2.py

@@ -24,6 +24,7 @@
 from transformers import GPT2Config

 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed.parallel_state import (
     get_pp_group, get_tensor_model_parallel_world_size)
@@ -182,6 +183,7 @@ def forward(
         return hidden_states


+@support_torch_compile
 class GPT2Model(nn.Module):

     def __init__(
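
Taken together, each file receives the same two-line change: importing the decorator and applying it to the inner model class, presumably so vLLM's compilation machinery can handle that class's forward pass. A minimal sketch of the resulting pattern, shown on GPT2Model; the class body is elided here and is untouched by this commit, and only the import and the decorator application come from the diffs above.

# Sketch only: the pattern this commit applies in each of the six files.
from torch import nn

from vllm.compilation.decorators import support_torch_compile


@support_torch_compile
class GPT2Model(nn.Module):
    ...  # existing embeddings, decoder blocks and forward() are unchanged

The other five classes (BaiChuanModel, BloomModel, CohereModel, ExaoneModel, GemmaModel) receive the identical decoration.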