
Commit 628e3fb

Merge pull request ggml-org#370 from Okabintaro/fix-state-pickle
fix: Make LLamaState pickleable for disk cache
2 parents 04d9218 + 5eb4ebb commit 628e3fb
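
Background: LlamaDiskCache stores LlamaState values in a diskcache.Cache, and diskcache pickles values before writing them to disk, so the state object has to survive a pickle round-trip. A hedged usage sketch of what this commit enables (the model path and prompt are illustrative, not from this commit):

import diskcache
from llama_cpp import Llama

llm = Llama(model_path="./models/7B/ggml-model.bin")  # hypothetical path
llm.eval(llm.tokenize(b"Hello"))

cache = diskcache.Cache("./llama_cache")
cache[b"hello"] = llm.save_state()  # the value is pickled to disk
llm.load_state(cache[b"hello"])     # unpickled and restored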

1 file changed (+9, -4)

Diff for: llama_cpp/llama.py

@@ -141,7 +141,9 @@ def __getitem__(self, key: Sequence[int]) -> "LlamaState":
         if _key is None:
             raise KeyError("Key not found")
         value: "LlamaState" = self.cache.pop(_key)  # type: ignore
-        self.cache.push(_key, side="front")  # type: ignore
+        # NOTE: This puts an integer as key in cache, which breaks
+        # Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens
+        # self.cache.push(_key, side="front")  # type: ignore
         return value

     def __contains__(self, key: Sequence[int]) -> bool:
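
The disabled push had re-inserted the popped entry under a fresh auto-assigned integer key (diskcache's push generates its own keys), so a later scan that compares cache keys token-by-token would hit an int where a token tuple was expected. A minimal sketch of that failure mode, assuming a prefix helper shaped like Llama.longest_token_prefix:

from typing import Sequence

def longest_token_prefix(a: Sequence[int], b: Sequence[int]) -> int:
    # Count matching leading tokens of two token sequences.
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n

print(longest_token_prefix((1, 2, 3), (1, 2, 9)))  # 2
# With an integer key in the cache the comparison blows up:
# longest_token_prefix((1, 2, 3), 500000000000000000)  -> TypeError (int is not iterable)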
@@ -168,7 +170,7 @@ def __init__(
         eval_logits: Deque[List[float]],
         input_ids: npt.NDArray[np.intc],
         scores: npt.NDArray[np.single],
-        llama_state,  # type: llama_cpp.Array[llama_cpp.c_uint8]
+        llama_state: bytes,
         llama_state_size: int,
     ):
         self.eval_tokens = eval_tokens
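
The annotation change from a ctypes array to bytes is the heart of the fix: ctypes array types are manufactured at runtime (e.g. c_ubyte_Array_4096), so pickle cannot resolve them by name, whereas bytes round-trips cleanly. A small self-contained sketch of the difference:

import ctypes
import pickle

buf = (ctypes.c_uint8 * 4)(1, 2, 3, 4)
try:
    pickle.dumps(buf)  # the runtime-generated array type cannot be looked up by pickle
except pickle.PicklingError:
    pass

data = bytes(buf)  # plain bytes pickle fine
assert pickle.loads(pickle.dumps(data)) == b"\x01\x02\x03\x04"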
@@ -1509,7 +1511,7 @@ def save_state(self) -> LlamaState:
             eval_logits=self.eval_logits.copy(),
             scores=self._scores.copy(),
             input_ids=self._input_ids.copy(),
-            llama_state=llama_state_compact,
+            llama_state=bytes(llama_state_compact),
             llama_state_size=n_bytes,
         )

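On the save side, bytes(llama_state_compact) copies the ctypes buffer into an immutable bytes object, detaching the saved state from any C memory. A sketch with a stand-in buffer (the size and contents here are illustrative):

import ctypes

n_bytes = 8  # stand-in for the size llama reports for its state
llama_state_compact = (ctypes.c_uint8 * n_bytes)(*range(n_bytes))
state_bytes = bytes(llama_state_compact)  # copies via the buffer protocol
assert isinstance(state_bytes, bytes) and len(state_bytes) == n_bytes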
@@ -1520,7 +1522,10 @@ def load_state(self, state: LlamaState) -> None:
         self._scores = state.scores.copy()
         self._input_ids = state.input_ids.copy()
         state_size = state.llama_state_size
-        if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
+        LLamaStateArrayType = (llama_cpp.c_uint8 * state_size)
+        llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state)
+
+        if llama_cpp.llama_set_state_data(self.ctx, llama_state) != state_size:
             raise RuntimeError("Failed to set llama state data")

     def n_ctx(self) -> int:
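
On the load side, the pickled bytes are copied back into a fresh ctypes array before being handed to the C API, which expects a writable buffer. from_buffer_copy is the right constructor here because bytes objects are read-only (from_buffer would raise "underlying buffer is not writable"). A standalone sketch:

import ctypes

data = b"\x01\x02\x03\x04"
ArrayType = ctypes.c_uint8 * len(data)
buf = ArrayType.from_buffer_copy(data)  # writable ctypes copy of read-only bytes
assert bytes(buf) == data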
