diff --git a/python/Cargo.lock b/python/Cargo.lock index f706dd8..0d6e7d3 100644 --- a/python/Cargo.lock +++ b/python/Cargo.lock @@ -227,7 +227,7 @@ dependencies = [ [[package]] name = "rwkv_rs" -version = "0.2.2" +version = "0.2.3" dependencies = [ "anyhow", "pyo3", diff --git a/python/Cargo.toml b/python/Cargo.toml index 3a953d1..3634e74 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rwkv_rs" -version = "0.2.2" +version = "0.2.3" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/python/src/lib.rs b/python/src/lib.rs index f74be92..3ae875b 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -49,13 +49,21 @@ impl State { ]; Ok(Self { inner: state }) } + + pub fn copy(&self) -> Self { + Self { + inner: self.inner.clone(), + } + } } #[pymethods] impl Rwkv { #[new] pub fn from_path(path: &str) -> PyResult { - Ok(Self { inner: Arc::new(rwkvk_rs::RwkvWrap::new_from_path(path)?) }) + Ok(Self { + inner: Arc::new(rwkvk_rs::RwkvWrap::new_from_path(path)?), + }) } pub fn forward_raw_preproc(&self, x: &PyList, state: &mut State) -> PyResult<()> { @@ -74,20 +82,35 @@ impl Rwkv { if tokens.len() == 0 { return Err(anyhow::anyhow!(RwkvError::EmptyArray).into()); } - for token in tokens.iter().take(tokens.len()-1) { - self.inner.rwkv().forward_raw_preproc(token, &mut state.inner); + for token in tokens.iter().take(tokens.len() - 1) { + self.inner + .rwkv() + .forward_raw_preproc(token, &mut state.inner); } - Ok(self.inner.rwkv().forward_raw(tokens.last().unwrap(), &mut state.inner)) + Ok(self + .inner + .rwkv() + .forward_raw(tokens.last().unwrap(), &mut state.inner)) } pub fn forward_token_preproc(&self, token: usize, state: &mut State) -> PyResult<()> { - let x = self.inner.rwkv().emb.get(token).ok_or(anyhow::anyhow!(RwkvError::InvalidToken(token)))?; + let x = self + .inner + .rwkv() + .emb + .get(token) + .ok_or(anyhow::anyhow!(RwkvError::InvalidToken(token)))?; self.inner.rwkv().forward_raw_preproc(x, &mut state.inner); Ok(()) } pub fn forward_token(&self, token: usize, state: &mut State) -> PyResult> { - let x = self.inner.rwkv().emb.get(token).ok_or(anyhow::anyhow!(RwkvError::InvalidToken(token)))?; + let x = self + .inner + .rwkv() + .emb + .get(token) + .ok_or(anyhow::anyhow!(RwkvError::InvalidToken(token)))?; Ok(self.inner.rwkv().forward_raw(x, &mut state.inner)) } @@ -96,12 +119,24 @@ impl Rwkv { if tokens.len() == 0 { return Err(anyhow::anyhow!(RwkvError::EmptyArray).into()); } - for &token in tokens.iter().take(tokens.len()-1) { - let token = self.inner.rwkv().emb.get(token).ok_or(anyhow::anyhow!(RwkvError::InvalidToken(token)))?; - self.inner.rwkv().forward_raw_preproc(token, &mut state.inner); + for &token in tokens.iter().take(tokens.len() - 1) { + let token = self + .inner + .rwkv() + .emb + .get(token) + .ok_or(anyhow::anyhow!(RwkvError::InvalidToken(token)))?; + self.inner + .rwkv() + .forward_raw_preproc(token, &mut state.inner); } let token = *tokens.last().unwrap(); - let token = self.inner.rwkv().emb.get(token).ok_or(anyhow::anyhow!(RwkvError::InvalidToken(token)))?; + let token = self + .inner + .rwkv() + .emb + .get(token) + .ok_or(anyhow::anyhow!(RwkvError::InvalidToken(token)))?; Ok(self.inner.rwkv().forward_raw(token, &mut state.inner)) } }