@@ -165,15 +165,14 @@ class MergedColumnParallelLinearWithShardedLoRA(
165
165
def slice_lora_a (
166
166
self , lora_a : List [Union [torch .Tensor , None ]]
167
167
) -> List [Union [torch .Tensor , None ]]:
168
- if lora_a [0 ] is None or lora_a [1 ] is None :
169
- return lora_a
168
+ #NOTE: lora_a contains 2 subloras, and each sublora could be None.
170
169
output_shard_size = self .lora_a_stacked [0 ].shape [2 ]
171
170
output_start_idx = self .tp_rank * output_shard_size
172
171
lora_a = [
173
- lora_a [0 ][:,
174
- output_start_idx : output_start_idx + output_shard_size ] ,
175
- lora_a [1 ][:,
176
- output_start_idx : output_start_idx + output_shard_size ] ,
172
+ lora_a [0 ][:, output_start_idx : output_start_idx +
173
+ output_shard_size ] if lora_a [ 0 ] is not None else None ,
174
+ lora_a [1 ][:, output_start_idx : output_start_idx +
175
+ output_shard_size ] if lora_a [ 1 ] is not None else None ,
177
176
]
178
177
return lora_a
179
178
@@ -261,14 +260,16 @@ class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
261
260
def slice_lora_a (
262
261
self , lora_a : List [Union [torch .Tensor , None ]]
263
262
) -> List [Union [torch .Tensor , None ]]:
264
- if lora_a [0 ] is None or lora_a [1 ] is None or lora_a [2 ] is None :
265
- return lora_a
263
+ # NOTE: lora_a contains 3 subloras, and each sublora could be None.
266
264
shard_size = [self .lora_a_stacked [i ].shape [2 ] for i in range (3 )]
267
265
start_idx = [self .tp_rank * shard_size [i ] for i in range (3 )]
268
266
lora_a = [
269
- lora_a [0 ][:, start_idx [0 ]:start_idx [0 ] + shard_size [0 ]],
270
- lora_a [1 ][:, start_idx [1 ]:start_idx [1 ] + shard_size [1 ]],
271
- lora_a [2 ][:, start_idx [2 ]:start_idx [2 ] + shard_size [2 ]],
267
+ lora_a [0 ][:, start_idx [0 ]:start_idx [0 ] +
268
+ shard_size [0 ]] if lora_a [0 ] is not None else None ,
269
+ lora_a [1 ][:, start_idx [1 ]:start_idx [1 ] +
270
+ shard_size [1 ]] if lora_a [1 ] is not None else None ,
271
+ lora_a [2 ][:, start_idx [2 ]:start_idx [2 ] +
272
+ shard_size [2 ]] if lora_a [2 ] is not None else None ,
272
273
]
273
274
return lora_a
274
275
0 commit comments