@@ -40,10 +40,9 @@ def __init__(self, model_config: ModelConfig):
40
40
}
41
41
42
42
@staticmethod
43
- def _find_list (arr : array , target : array , start_idx : int ) -> int :
43
+ def _find_array (arr : array , target : array , start_idx : int ) -> int :
44
44
"""
45
- Find the first starting index where the search_list appears
46
- as a consecutive subsequence in main_list.
45
+ Find the first occurrence of target in arr starting from start_idx.
47
46
48
47
Args:
49
48
arr: The array to search within
@@ -55,25 +54,14 @@ def _find_list(arr: array, target: array, start_idx: int) -> int:
55
54
"""
56
55
if start_idx < 0 :
57
56
raise ValueError ("start_idx must be non-negative" )
58
-
59
- found_index = - 1
60
-
61
- # Handle edge cases
62
57
if not target or not arr :
63
- return found_index
58
+ raise ValueError ( "Empty arr or target not allowed" )
64
59
65
- # Length of lists
66
- arr_len = len (arr )
67
60
target_len = len (target )
68
-
69
- # Iterate through possible starting positions
70
- for i in range (start_idx , arr_len - target_len + 1 ):
71
- # Check if the subsequence matches
61
+ for i in range (start_idx , len (arr ) - target_len + 1 ):
72
62
if arr [i :i + target_len ] == target :
73
- found_index = i
74
- break
75
-
76
- return found_index
63
+ return i
64
+ return - 1
77
65
78
66
def _get_instruction_len (self , prompt_token_ids : array ) -> bool :
79
67
"""
@@ -102,20 +90,22 @@ def tokens_to_ids(tokens: list[str]) -> List[int]:
102
90
103
91
# Find the user pattern in the prompt.
104
92
user_token_ids = tokens_to_ids (["▁<" , "|" , "user" , "|" , ">" , "<0x0A>" ])
105
- found_user_pattern = (__class__ ._find_list (prompt_token_ids ,
106
- user_token_ids ,
107
- start_idx = 1 ) == 1 )
93
+ found_user_pattern = (__class__ ._find_array (prompt_token_ids ,
94
+ user_token_ids ,
95
+ start_idx = 1 ) == 1 )
108
96
109
97
# Find the embed pattern in the prompt.
110
98
if found_user_pattern :
99
+ # If user pattern is found, that means there should be
100
+ # a newline token before the embed pattern.
111
101
embed_token_ids = tokens_to_ids (
112
102
["<0x0A>" , "<" , "|" , "embed" , "|" , ">" , "<0x0A>" ])
113
103
else :
114
104
embed_token_ids = tokens_to_ids (
115
105
["▁<" , "|" , "embed" , "|" , ">" , "<0x0A>" ])
116
- found_embed_pattern_idx = __class__ ._find_list (prompt_token_ids ,
117
- embed_token_ids ,
118
- start_idx = 1 )
106
+ found_embed_pattern_idx = __class__ ._find_array (prompt_token_ids ,
107
+ embed_token_ids ,
108
+ start_idx = 1 )
119
109
120
110
if found_embed_pattern_idx != - 1 :
121
111
instruction_len = found_embed_pattern_idx + len (embed_token_ids )
0 commit comments