Skip to content

Commit 4941376

Browse files
committed
Improve _find_list
Signed-off-by: Pooya Davoodi <pooya.davoodi@parasail.io>
1 parent 1aec4d3 commit 4941376

File tree

2 files changed

+20
-30
lines changed

2 files changed

+20
-30
lines changed

tests/models/embedding/language/test_gritlm.py

+6-6
Original file line number | Diff line number | Diff line change
@@ -26,16 +26,16 @@ def _arr(arr):
2626
return array("i", arr)
2727

2828

29-
def test_find_list():
29+
def test_find_array():
3030
arr = _arr([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
3131

32-
assert GritLMPooler._find_list(arr, _arr([3, 4, 5]), start_idx=0) == 3
33-
assert GritLMPooler._find_list(arr, _arr([3, 4, 5]), start_idx=1) == 3
34-
assert GritLMPooler._find_list(arr, _arr([3, 4, 5]), start_idx=5) == -1
35-
assert GritLMPooler._find_list(arr, _arr([3, 5]), start_idx=0) == -1
32+
assert GritLMPooler._find_array(arr, _arr([3, 4, 5]), start_idx=0) == 3
33+
assert GritLMPooler._find_array(arr, _arr([3, 4, 5]), start_idx=1) == 3
34+
assert GritLMPooler._find_array(arr, _arr([3, 4, 5]), start_idx=5) == -1
35+
assert GritLMPooler._find_array(arr, _arr([3, 5]), start_idx=0) == -1
3636

3737
with pytest.raises(ValueError):
38-
GritLMPooler._find_list(arr, _arr([3, 4, 5]), start_idx=-1)
38+
GritLMPooler._find_array(arr, _arr([3, 4, 5]), start_idx=-1)
3939

4040

4141
@pytest.fixture(scope="module")

vllm/model_executor/models/gritlm.py

+14-24
Original file line number | Diff line number | Diff line change
@@ -40,10 +40,9 @@ def __init__(self, model_config: ModelConfig):
4040
}
4141

4242
@staticmethod
43-
def _find_list(arr: array, target: array, start_idx: int) -> int:
43+
def _find_array(arr: array, target: array, start_idx: int) -> int:
4444
"""
45-
Find the first starting index where the search_list appears
46-
as a consecutive subsequence in main_list.
45+
Find the first occurrence of target in arr starting from start_idx.
4746
4847
Args:
4948
arr: The array to search within
@@ -55,25 +54,14 @@ def _find_list(arr: array, target: array, start_idx: int) -> int:
5554
"""
5655
if start_idx < 0:
5756
raise ValueError("start_idx must be non-negative")
58-
59-
found_index = -1
60-
61-
# Handle edge cases
6257
if not target or not arr:
63-
return found_index
58+
raise ValueError("Empty arr or target not allowed")
6459

65-
# Length of lists
66-
arr_len = len(arr)
6760
target_len = len(target)
68-
69-
# Iterate through possible starting positions
70-
for i in range(start_idx, arr_len - target_len + 1):
71-
# Check if the subsequence matches
61+
for i in range(start_idx, len(arr) - target_len + 1):
7262
if arr[i:i + target_len] == target:
73-
found_index = i
74-
break
75-
76-
return found_index
63+
return i
64+
return -1
7765

7866
def _get_instruction_len(self, prompt_token_ids: array) -> bool:
7967
"""
@@ -102,20 +90,22 @@ def tokens_to_ids(tokens: list[str]) -> List[int]:
10290

10391
# Find the user pattern in the prompt.
10492
user_token_ids = tokens_to_ids(["▁<", "|", "user", "|", ">", "<0x0A>"])
105-
found_user_pattern = (__class__._find_list(prompt_token_ids,
106-
user_token_ids,
107-
start_idx=1) == 1)
93+
found_user_pattern = (__class__._find_array(prompt_token_ids,
94+
user_token_ids,
95+
start_idx=1) == 1)
10896

10997
# Find the embed pattern in the prompt.
11098
if found_user_pattern:
99+
# If user pattern is found, that means there should be
100+
# a newline token before the embed pattern.
111101
embed_token_ids = tokens_to_ids(
112102
["<0x0A>", "<", "|", "embed", "|", ">", "<0x0A>"])
113103
else:
114104
embed_token_ids = tokens_to_ids(
115105
["▁<", "|", "embed", "|", ">", "<0x0A>"])
116-
found_embed_pattern_idx = __class__._find_list(prompt_token_ids,
117-
embed_token_ids,
118-
start_idx=1)
106+
found_embed_pattern_idx = __class__._find_array(prompt_token_ids,
107+
embed_token_ids,
108+
start_idx=1)
119109

120110
if found_embed_pattern_idx != -1:
121111
instruction_len = found_embed_pattern_idx + len(embed_token_ids)

0 commit comments

Comments
 (0)