enable LoRA for embedding models #821

Open

wants to merge 39 commits into base: habana_main

Commits (39)
56b42c3  Initial draft to enable embedding task.  (libinta, Jan 24, 2025)
b62d611  remove ENCODER_ONLY  (libinta, Jan 24, 2025)
a647baa  Added support for embedding model with self attention without causal …  (libinta, Jan 27, 2025)
46e1aad  Change set_attn_bias padding element from -math.inf to -3e38 as -math…  (libinta, Jan 27, 2025)
2f74e6b  rewrite is_causal and add dbg msg  (libinta, Jan 27, 2025)
99947c8  update maskoff value  (libinta, Jan 27, 2025)
094294c  fix wrong base mask  (libinta, Jan 28, 2025)
1c7416f  cleanup code  (libinta, Jan 29, 2025)
c6cdae1  cleanup code  (libinta, Jan 29, 2025)
8ac281b  cleanup code  (libinta, Jan 29, 2025)
e72c2f0  Add pooler support for padded batch inputs for hpu with CLSPoll, Last…  (libinta, Jan 30, 2025)
7c1c74b  add meanpool for padded input  (libinta, Jan 30, 2025)
5c49ca1  revert bert change  (libinta, Jan 30, 2025)
ae6fbe0  modify meanpool for padded input  (libinta, Jan 30, 2025)
d65340a  write is_pooler function  (libinta, Jan 30, 2025)
0c28519  fix is_causal logic  (libinta, Jan 31, 2025)
1fe398f  Set is_causal based on attn_type  (libinta, Feb 1, 2025)
c3a92f3  Set is_causal based on attn_type  (libinta, Feb 1, 2025)
afe8bb3  enable lora embedding models on hpu  (skaulintel, Feb 3, 2025)
55ae676  fix with warmup issue  (libinta, Feb 4, 2025)
787700b  fix cpu test issue and format  (libinta, Feb 5, 2025)
6f02b86  fix code format  (libinta, Feb 5, 2025)
b97f7c6  Merge branch 'habana_main' into dev/enable_embedding_ace  (libinta, Feb 5, 2025)
593ded0  fix hpu attn coding issue  (libinta, Feb 5, 2025)
30f43b5  fix hpu_pooling_model_runner.py code format and add requirement-hpu w…  (libinta, Feb 5, 2025)
7dc5239  Merge branch 'dev/enable_embedding_ace' into dev/skaul_enable_lora_embed  (skaulintel, Feb 6, 2025)
c636da7  move create lora mask  (skaulintel, Feb 6, 2025)
1185c2e  add support for batch padding  (libinta, Feb 5, 2025)
82a6e70  Merge branch 'dev/enable_embedding_ace' into dev/skaul_enable_lora_embed  (skaulintel, Feb 7, 2025)
53f94e0  Merge branch 'habana_main' into dev/enable_embedding_ace  (kzawora-intel, Feb 12, 2025)
05ecf57  Merge branch 'dev/enable_embedding_ace' into dev/skaul_enable_lora_embed  (skaulintel, Feb 12, 2025)
8d8f1b2  Merge branch 'habana_main' into dev/skaul_enable_lora_embed  (skaulintel, Feb 19, 2025)
3dd63db  Update requirements-hpu.txt  (skaulintel, Feb 20, 2025)
43ae76f  Merge branch 'habana_main' into dev/skaul_enable_lora_embed  (skaulintel, Feb 20, 2025)
7bba2f3  restore requirements-hpu  (skaulintel, Feb 20, 2025)
55695e3  remove intermediate tensor  (skaulintel, Feb 21, 2025)
ed9b4b2  Update hpu_pooling_model_runner.py  (skaulintel, Feb 21, 2025)
e673fbe  add back intermediate tensor  (skaulintel, Feb 21, 2025)
665be55  Merge branch 'habana_main' into dev/skaul_enable_lora_embed  (skaulintel, Feb 24, 2025)
140 changes: 70 additions & 70 deletions vllm/worker/hpu_model_runner.py
@@ -1588,6 +1588,76 @@ def prepare_input_tensors(
                lora_ids=lora_ids), \
            sampling_metadata

Review comment: There are still yapf errors in precommit, please fix.

    def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],
                         is_prompt: bool):
        '''
        This is a helper function to create the mask for lora computations.
        Lora Mask is needed to ensure we match the correct lora weights
        for the request.
        For Prompt phase we have
        lora_mask with shape (batch_size * seq_len, max_loras * max_rank)
        lora_logits_mask with shape (batch_size, max_loras * max_rank)
        For Decode phase we have both
        lora_mask and lora_logits_mask with shape
        (batch_size, max_loras * max_rank)
        '''
        lora_mask: torch.Tensor = None
        lora_logits_mask: torch.Tensor = None
        lora_index = 0

        if self.lora_config:
            if is_prompt:
                lora_mask = torch.zeros(
                    input_tokens.shape[0] * input_tokens.shape[1],
                    (self.lora_config.max_loras) *
                    self.lora_config.max_lora_rank,
                    dtype=self.lora_config.lora_dtype)
                lora_logits_mask = torch.zeros(
                    input_tokens.shape[0], (self.lora_config.max_loras) *
                    self.lora_config.max_lora_rank,
                    dtype=self.lora_config.lora_dtype)

                ones = torch.ones(input_tokens.shape[1],
                                  self.lora_config.max_lora_rank,
                                  dtype=self.lora_config.lora_dtype)
                logit_ones = torch.ones(1,
                                        self.lora_config.max_lora_rank,
                                        dtype=self.lora_config.lora_dtype)

                for i in range(len(lora_ids)):
                    if lora_ids[i] == 0:
                        continue
                    lora_index = self.lora_manager._adapter_manager.\
                        lora_index_to_id.index(lora_ids[i])
                    start_row = i * input_tokens.shape[1]
                    end_row = start_row + input_tokens.shape[1]
                    start_col = lora_index * self.lora_config.max_lora_rank
                    end_col = start_col + self.lora_config.max_lora_rank
                    lora_mask[start_row:end_row, start_col:end_col] = ones
                    lora_logits_mask[i, start_col:end_col] = logit_ones
                lora_mask = lora_mask.to('hpu')
                lora_logits_mask = lora_logits_mask.to('hpu')
            else:
                lora_mask = torch.zeros(input_tokens.shape[0],
                                        (self.lora_config.max_loras) *
                                        self.lora_config.max_lora_rank,
                                        dtype=self.lora_config.lora_dtype)
                ones = torch.ones(1,
                                  self.lora_config.max_lora_rank,
                                  dtype=self.lora_config.lora_dtype)
                for i in range(len(lora_ids)):
                    if lora_ids[i] == 0:
                        continue
                    lora_index = self.lora_manager._adapter_manager.\
                        lora_index_to_id.index(lora_ids[i])
                    start_pos = lora_index * self.lora_config.max_lora_rank
                    end_pos = start_pos + self.lora_config.max_lora_rank
                    lora_mask[i, start_pos:end_pos] = ones
                lora_mask = lora_mask.to('hpu')
                lora_logits_mask = lora_mask

        return lora_mask, lora_logits_mask

    def _seq_len(self, attn_metadata):
        if attn_metadata.num_prefills != 0:
            return attn_metadata.slot_mapping.size(1)

@@ -2218,75 +2288,6 @@ def prepare_model_input(
            is_prompt=is_prompt,
            virtual_engine=virtual_engine)

    def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],
                         is_prompt: bool):
        '''
        This is a helper function to create the mask for lora computations.
        Lora Mask is needed to ensure we match the correct lora weights
        for the request.
        For Prompt phase we have
        lora_mask with shape (batch_size * seq_len, max_loras * max_rank)
        lora_logits_mask with shape (batch_size, max_loras * max_rank)
        For Decode phase we have both
        lora_mask and lora_logits_mask with shape
        (batch_size, max_loras * max_rank)
        '''
        lora_mask: torch.Tensor = None
        lora_logits_mask: torch.Tensor = None
        lora_index = 0

        if self.lora_config:
            if is_prompt:
                lora_mask = torch.zeros(
                    input_tokens.shape[0] * input_tokens.shape[1],
                    (self.lora_config.max_loras) *
                    self.lora_config.max_lora_rank,
                    dtype=self.lora_config.lora_dtype)
                lora_logits_mask = torch.zeros(
                    input_tokens.shape[0], (self.lora_config.max_loras) *
                    self.lora_config.max_lora_rank,
                    dtype=self.lora_config.lora_dtype)

                ones = torch.ones(input_tokens.shape[1],
                                  self.lora_config.max_lora_rank,
                                  dtype=self.lora_config.lora_dtype)
                logit_ones = torch.ones(1,
                                        self.lora_config.max_lora_rank,
                                        dtype=self.lora_config.lora_dtype)

                for i in range(len(lora_ids)):
                    if lora_ids[i] == 0:
                        continue
                    lora_index = self.lora_manager._adapter_manager.\
                        lora_index_to_id.index(lora_ids[i])
                    start_row = i * input_tokens.shape[1]
                    end_row = start_row + input_tokens.shape[1]
                    start_col = lora_index * self.lora_config.max_lora_rank
                    end_col = start_col + self.lora_config.max_lora_rank
                    lora_mask[start_row:end_row, start_col:end_col] = ones
                    lora_logits_mask[i, start_col:end_col] = logit_ones
                lora_mask = lora_mask.to('hpu')
                lora_logits_mask = lora_logits_mask.to('hpu')
            else:
                lora_mask = torch.zeros(input_tokens.shape[0],
                                        (self.lora_config.max_loras) *
                                        self.lora_config.max_lora_rank,
                                        dtype=self.lora_config.lora_dtype)
                ones = torch.ones(1,
                                  self.lora_config.max_lora_rank,
                                  dtype=self.lora_config.lora_dtype)
                for i in range(len(lora_ids)):
                    if lora_ids[i] == 0:
                        continue
                    lora_index = self.lora_manager._adapter_manager.\
                        lora_index_to_id.index(lora_ids[i])
                    start_pos = lora_index * self.lora_config.max_lora_rank
                    end_pos = start_pos + self.lora_config.max_lora_rank
                    lora_mask[i, start_pos:end_pos] = ones
                lora_mask = lora_mask.to('hpu')
                lora_logits_mask = lora_mask

        return lora_mask, lora_logits_mask

    @torch.inference_mode()
    def execute_model(

@@ -2338,7 +2339,6 @@ def execute_model(
            lora_mask, lora_logits_mask = self.create_lora_mask(
                input_tokens, model_input.lora_ids,
                attn_metadata.is_prompt)

        execute_model_kwargs = {
            "input_ids": input_tokens,
            "positions": input_positions,
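For readers following the diff above: the prompt-phase branch of create_lora_mask builds a (batch_size * seq_len, max_loras * max_rank) mask with one row per token, while the decode-phase branch builds one row per request. The standalone sketch below reproduces only the decode-phase layout with plain torch on CPU; batch_size, max_loras, max_lora_rank, and slot_per_request are made-up illustrative values, not taken from the PR.

import torch

# Assumed toy configuration (illustrative values only).
batch_size = 3
max_loras = 2
max_lora_rank = 4

# 0-based adapter slot per request; None means the request runs without LoRA,
# mirroring the `lora_ids[i] == 0` early-continue in create_lora_mask above.
slot_per_request = [0, None, 1]

# Decode-phase mask: one row per request, one rank-wide column block per adapter.
lora_mask = torch.zeros(batch_size, max_loras * max_lora_rank)
for i, slot in enumerate(slot_per_request):
    if slot is None:
        continue
    start = slot * max_lora_rank
    lora_mask[i, start:start + max_lora_rank] = 1.0

print(lora_mask)
# tensor([[1., 1., 1., 1., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 1., 1., 1., 1.]])

The non-zero block in each row selects the rank columns of that request's adapter, which is how the mask routes each request to the correct LoRA weights during batched computation.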
16 changes: 16 additions & 0 deletions vllm/worker/hpu_pooling_model_runner.py
@@ -43,6 +43,12 @@ def execute_model(
        if num_steps > 1:
            raise ValueError(
                "HPUPoolingModelRunner does not support multi-step execution.")
        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        input_tokens = model_input.input_tokens
        input_positions = model_input.input_positions
        attn_metadata = model_input.attn_metadata
@@ -60,6 +66,15 @@
        use_graphs = self._use_graphs(batch_size, seq_len, is_prompt)
        super()._check_config(batch_size, seq_len, is_prompt, warmup_mode)

        lora_mask: torch.Tensor = None
        lora_logits_mask: torch.Tensor = None
        if self.lora_config:
            assert model_input.lora_ids is not None
            lora_mask, lora_logits_mask = self.create_lora_mask(
                input_tokens, model_input.lora_ids,
                attn_metadata.is_prompt)


        num_layers = self.model_config.get_num_layers(self.parallel_config)
        # use an empty tensor instead of `None`` to force Dynamo to pass
        # it by reference, rather by specializing on the value ``None``.
@@ -77,6 +92,7 @@
"attn_metadata":
super().trim_attn_metadata(model_input.attn_metadata),
"intermediate_tensors": intermediate_tensors,
"lora_mask": lora_mask,
}

        if htorch.utils.internal.is_lazy():
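With the pooling runner now activating adapters and passing lora_mask into the model, an embedding model can in principle be served with a LoRA adapter. Below is a hypothetical end-to-end usage sketch of the offline API; the model name, adapter path, the task flag, and whether encode() accepts lora_request in this fork are assumptions rather than something this PR documents.

from vllm import LLM
from vllm.lora.request import LoRARequest

# Assumed embedding-capable base model and a hypothetical local adapter path.
llm = LLM(
    model="intfloat/e5-mistral-7b-instruct",
    task="embed",          # assumed flag for selecting the pooling/embedding runner
    enable_lora=True,
    max_loras=1,
    max_lora_rank=16,
)

lora_request = LoRARequest("my_embed_adapter", 1, "/path/to/lora_adapter")

# Assumed to be supported once this PR lands: pass the adapter per request.
outputs = llm.encode(["query: what does LoRA do?"], lora_request=lora_request)
print(outputs[0])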