@@ -141,7 +141,9 @@ def __getitem__(self, key: Sequence[int]) -> "LlamaState":
         if _key is None:
             raise KeyError("Key not found")
         value: "LlamaState" = self.cache.pop(_key)  # type: ignore
-        self.cache.push(_key, side="front")  # type: ignore
+        # NOTE: This puts an integer as key in cache, which breaks,
+        # Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens
+        # self.cache.push(_key, side="front")  # type: ignore
         return value

     def __contains__(self, key: Sequence[int]) -> bool:
@@ -168,7 +170,7 @@ def __init__(
         eval_logits: Deque[List[float]],
         input_ids: npt.NDArray[np.intc],
         scores: npt.NDArray[np.single],
-        llama_state,  # type: llama_cpp.Array[llama_cpp.c_uint8]
+        llama_state: bytes,
         llama_state_size: int,
     ):
         self.eval_tokens = eval_tokens
@@ -1509,7 +1511,7 @@ def save_state(self) -> LlamaState:
             eval_logits=self.eval_logits.copy(),
             scores=self._scores.copy(),
             input_ids=self._input_ids.copy(),
-            llama_state=llama_state_compact,
+            llama_state=bytes(llama_state_compact),
             llama_state_size=n_bytes,
         )

@@ -1520,7 +1522,10 @@ def load_state(self, state: LlamaState) -> None:
         self._scores = state.scores.copy()
         self._input_ids = state.input_ids.copy()
         state_size = state.llama_state_size
-        if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
+        LLamaStateArrayType = (llama_cpp.c_uint8 * state_size)
+        llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state)
+
+        if llama_cpp.llama_set_state_data(self.ctx, llama_state) != state_size:
             raise RuntimeError("Failed to set llama state data")

     def n_ctx(self) -> int:
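
Note on the save_state/load_state change: llama_cpp.llama_set_state_data expects a C-compatible buffer, while LlamaState.llama_state is now stored as plain bytes, which is easier to copy and pickle. The sketch below shows the same bytes <-> ctypes round-trip in isolation, using only the standard ctypes module; the buffer contents and sizes are dummy values for illustration, and none of this is llama-cpp-python code itself.

    import ctypes

    # Pretend this is the C-side state buffer filled by a call such as
    # llama_copy_state_data (contents here are arbitrary dummy values).
    n_bytes = 4
    state_array = (ctypes.c_uint8 * n_bytes)(1, 2, 3, 4)

    # save_state direction: snapshot the ctypes array as immutable bytes.
    saved = bytes(state_array)

    # load_state direction: rebuild a ctypes uint8 array of the right size
    # from the stored bytes before handing it back to a C function.
    restored = (ctypes.c_uint8 * len(saved)).from_buffer_copy(saved)

    assert list(restored) == [1, 2, 3, 4]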