Commit 790650f

[Bugfix] Fix fully sharded LoRA bug (vllm-project#10352)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
1 parent: 0376851

3 files changed (+21 −19 lines)

vllm/lora/fully_sharded_layers.py (+12 −11)
@@ -165,15 +165,14 @@ class MergedColumnParallelLinearWithShardedLoRA(
     def slice_lora_a(
         self, lora_a: List[Union[torch.Tensor, None]]
     ) -> List[Union[torch.Tensor, None]]:
-        if lora_a[0] is None or lora_a[1] is None:
-            return lora_a
+        #NOTE: lora_a contains 2 subloras, and each sublora could be None.
         output_shard_size = self.lora_a_stacked[0].shape[2]
         output_start_idx = self.tp_rank * output_shard_size
         lora_a = [
-            lora_a[0][:,
-                      output_start_idx:output_start_idx + output_shard_size],
-            lora_a[1][:,
-                      output_start_idx:output_start_idx + output_shard_size],
+            lora_a[0][:, output_start_idx:output_start_idx +
+                      output_shard_size] if lora_a[0] is not None else None,
+            lora_a[1][:, output_start_idx:output_start_idx +
+                      output_shard_size] if lora_a[1] is not None else None,
         ]
         return lora_a
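
The old guard returned early, with both tensors unsliced, whenever either sublora was None, so a merged layer with only one adapted sublora never got sharded. The new code slices each non-None sublora and lets None pass through. A minimal sketch of that pattern with toy tensors; the helper name slice_subloras and all shapes here are illustrative, not part of the commit:

import torch
from typing import List, Optional

def slice_subloras(subloras: List[Optional[torch.Tensor]], tp_rank: int,
                   shard_size: int) -> List[Optional[torch.Tensor]]:
    # Keep only this rank's shard along dim 1; a missing (None) sublora
    # now passes through instead of disabling slicing for the whole list.
    start = tp_rank * shard_size
    return [t[:, start:start + shard_size] if t is not None else None
            for t in subloras]

# Rank 1 of a 2-way tensor-parallel run; the second sublora is absent.
out = slice_subloras([torch.randn(8, 32), None], tp_rank=1, shard_size=16)
assert out[0].shape == (8, 16) and out[1] is None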

@@ -261,14 +260,16 @@ class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
     def slice_lora_a(
         self, lora_a: List[Union[torch.Tensor, None]]
     ) -> List[Union[torch.Tensor, None]]:
-        if lora_a[0] is None or lora_a[1] is None or lora_a[2] is None:
-            return lora_a
+        # NOTE: lora_a contains 3 subloras, and each sublora could be None.
         shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
         start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
         lora_a = [
-            lora_a[0][:, start_idx[0]:start_idx[0] + shard_size[0]],
-            lora_a[1][:, start_idx[1]:start_idx[1] + shard_size[1]],
-            lora_a[2][:, start_idx[2]:start_idx[2] + shard_size[2]],
+            lora_a[0][:, start_idx[0]:start_idx[0] +
+                      shard_size[0]] if lora_a[0] is not None else None,
+            lora_a[1][:, start_idx[1]:start_idx[1] +
+                      shard_size[1]] if lora_a[1] is not None else None,
+            lora_a[2][:, start_idx[2]:start_idx[2] +
+                      shard_size[2]] if lora_a[2] is not None else None,
         ]
         return lora_a
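
The QKV variant applies the same None-tolerant slicing with per-projection shard widths, since the three subloras can differ in width. A toy check of that indexing, with made-up shapes and a deliberately missing K sublora (none of these values come from the commit):

import torch

# Per-projection shard widths for Q, K, V (illustrative only).
shard_size = [16, 8, 8]
tp_rank = 0
lora_a = [torch.randn(4, 32), None, torch.randn(4, 16)]  # no K sublora
lora_a = [
    lora_a[i][:, tp_rank * shard_size[i]:(tp_rank + 1) * shard_size[i]]
    if lora_a[i] is not None else None for i in range(3)
]
assert lora_a[0].shape == (4, 16)
assert lora_a[1] is None
assert lora_a[2].shape == (4, 8)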

vllm/lora/layers.py (+8 −7)
@@ -685,26 +685,27 @@ def slice_lora_a(
     def slice_lora_b(
         self, lora_b: List[Union[torch.Tensor, None]]
     ) -> List[Union[torch.Tensor, None]]:
-        if lora_b[0] is None or lora_b[1] is None:
-            return lora_b
+        #NOTE: lora_b contains 2 subloras, and each sublora could be None.
         shard_size = self.output_dim
         start_idx = self.tp_rank * shard_size
         end_idx = (self.tp_rank + 1) * shard_size
         lora_b = [
-            lora_b[0][:, start_idx:end_idx],
-            lora_b[1][:, start_idx:end_idx],
+            lora_b[0][:, start_idx:end_idx] if lora_b[0] is not None else None,
+            lora_b[1][:, start_idx:end_idx] if lora_b[1] is not None else None,
         ]
         return lora_b
 
     def slice_bias(
         self, bias: List[Union[torch.Tensor,
                                None]]) -> List[Union[torch.Tensor, None]]:
-        if bias[0] is None or bias[1] is None:
-            return bias
+        # NOTE : each bias could be None.
         shard_size = self.output_dim
         start_idx = self.tp_rank * shard_size
         end_idx = (self.tp_rank + 1) * shard_size
-        bias = [bias[0][start_idx:end_idx], bias[1][start_idx:end_idx]]
+        bias = [
+            bias[0][start_idx:end_idx] if bias[0] is not None else None,
+            bias[1][start_idx:end_idx] if bias[1] is not None else None
+        ]
         return bias
 
     def set_lora(
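
The same fix on the lora_b and bias paths: lora_b shards along its output dimension (dim 1), while bias, being 1-D, shards along dim 0, and either entry may now be None. A short sketch with toy shapes (not the commit's tensors):

import torch

output_dim, tp_rank = 16, 1
start, end = tp_rank * output_dim, (tp_rank + 1) * output_dim

lora_b = [torch.randn(8, 64), None]       # second sublora has no B matrix
lora_b = [b[:, start:end] if b is not None else None for b in lora_b]

bias = [None, torch.randn(64)]            # first sublora has no bias
bias = [b[start:end] if b is not None else None for b in bias]

assert lora_b[0].shape == (8, 16) and lora_b[1] is None
assert bias[0] is None and bias[1].shape == (16,)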

vllm/worker/worker.py (+1 −1)
@@ -232,7 +232,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
         logger.info(
             "Memory profiling results: total_gpu_memory=%.2fGiB"
             " initial_memory_usage=%.2fGiB peak_torch_memory=%.2fGiB"
-            " memory_usage_post_profile=%.2fGib"
+            " memory_usage_post_profile=%.2fGiB"
             " non_torch_memory=%.2fGiB kv_cache_size=%.2fGiB"
             " gpu_memory_utilization=%.2f", total_gpu_memory / (1024**3),
             (total_gpu_memory - free_memory_pre_profile) / (1024**3),
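
This last hunk is only a unit-label typo: the post-profile field was printed as "Gib" while every other field used "GiB". A toy check of the corrected %-style format string (the values are made up; the real message comes from vLLM's logger):

msg = ("memory_usage_post_profile=%.2fGiB non_torch_memory=%.2fGiB"
       % (1.50, 0.75))
assert "Gib" not in msg and msg.count("GiB") == 2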
