Commit d08cbf3

[Bugfix] Fix fully sharded LoRAs with Mixtral
- Changes ReplicatedLinearWithLoRA to always apply regardless of the fully sharded LoRA setting, since in both cases the layer needs to be replicated.
- Updates the existing Mixtral all-target-modules test to cover both values of fully_sharded_loras (the model includes a ReplicatedLinear layer [gate]).

Signed-off-by: Jason Greene <jason.greene@redhat.com>
1 parent 47a0b61 commit d08cbf3
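
For context, a minimal usage sketch (not part of this commit) of the configuration the fix targets, assuming vLLM's LLM / LoRARequest API; the base model name, adapter path, and prompt are placeholders, and the engine arguments mirror the updated test below.

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Engine arguments mirror the parametrized test: tensor parallelism plus
# fully sharded LoRAs, with a rank-32 adapter targeting all Mixtral modules.
llm = LLM(
    model="mistralai/Mixtral-8x7B-Instruct-v0.1",  # assumed base model
    enable_lora=True,
    max_loras=4,
    max_lora_rank=32,
    tensor_parallel_size=4,
    fully_sharded_loras=True,  # the setting this bugfix makes work with the gate layer
)

outputs = llm.generate(
    ["Give three facts about mixture-of-experts models."],
    SamplingParams(temperature=0, max_tokens=64),
    # Placeholder adapter path; the test uses the
    # mixtral_lora_files_all_target_modules fixture instead.
    lora_request=LoRARequest("mixtral-all-modules", 1, "/path/to/mixtral-lora"),
)
print(outputs[0].outputs[0].text)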

2 files changed: +5 −2 lines changed

tests/lora/test_mixtral.py (+3 −1)
@@ -62,8 +62,9 @@ def test_mixtral_lora(mixtral_lora_files, tp_size):
 
 
 @pytest.mark.parametrize("tp_size", [4])
+@pytest.mark.parametrize("fully_shard", [True, False])
 def test_mixtral_lora_all_target_modules(mixtral_lora_files_all_target_modules,
-                                         tp_size):
+                                         tp_size, fully_shard):
     """This LoRA model has all supported Mixtral target modules"""
 
     if torch.cuda.device_count() < tp_size:
@@ -82,6 +83,7 @@ def test_mixtral_lora_all_target_modules(mixtral_lora_files_all_target_modules,
         max_loras=4,
         distributed_executor_backend="ray",
         tensor_parallel_size=tp_size,
+        fully_sharded_loras=fully_shard,
         max_lora_rank=32,
     )
 

vllm/lora/layers.py (+2 −1)
@@ -425,8 +425,9 @@ def forward(self, input_):
                        if self.base_layer.skip_bias_add else None)
         return output, output_bias
 
+    # ReplicatedLinear should always be replaced, regardless of the fully
+    # sharded LoRAs setting, because it is, by definition, copied per GPU.
     @classmethod
-    @_not_fully_sharded_can_replace
     def can_replace_layer(
         cls,
         source_layer: nn.Module,
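
For reference, a simplified, assumed sketch of how a guard like _not_fully_sharded_can_replace behaves (illustrative only, not the verbatim vLLM implementation): it makes can_replace_layer() decline whenever fully sharded LoRAs are enabled, so that fully sharded layer variants are chosen instead. Because ReplicatedLinear is, as the added comment notes, copied per GPU, there is nothing to shard, so the decorator kept the gate layer from being wrapped once fully_sharded_loras was turned on.

# Simplified, assumed sketch of a guard like _not_fully_sharded_can_replace
# (illustrative reconstruction, not the exact vLLM code).
from functools import wraps


def _not_fully_sharded_can_replace(can_replace):
    """Wrap can_replace_layer() so it declines when fully sharded LoRAs are on."""

    @wraps(can_replace)
    def dec(*args, **kwargs):
        # lora_config is supplied by the layer-replacement machinery; when
        # fully_sharded_loras is enabled, the non-sharded LoRA layer steps aside.
        fully_sharded = kwargs["lora_config"].fully_sharded_loras
        return (not fully_sharded) and can_replace(*args, **kwargs)

    return dec

Removing the decorator from ReplicatedLinearWithLoRA.can_replace_layer (diff above) means the replicated gate layer is replaced whether or not fully_sharded_loras is set, which is exactly what the new fully_shard parametrization in test_mixtral.py exercises.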
