From 4e1d03eb2846959323932ef1f2f6aa28c900f118 Mon Sep 17 00:00:00 2001 From: "Alexandr \"MrSteyk\" German" Date: Mon, 13 Feb 2023 17:05:00 +0500 Subject: [PATCH] token preproc --- python/Cargo.lock | 2 +- python/Cargo.toml | 2 +- python/src/lib.rs | 6 ++++++ 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/python/Cargo.lock b/python/Cargo.lock index 8909cf1..f706dd8 100644 --- a/python/Cargo.lock +++ b/python/Cargo.lock @@ -227,7 +227,7 @@ dependencies = [ [[package]] name = "rwkv_rs" -version = "0.2.1" +version = "0.2.2" dependencies = [ "anyhow", "pyo3", diff --git a/python/Cargo.toml b/python/Cargo.toml index 4d12e0d..3a953d1 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rwkv_rs" -version = "0.2.1" +version = "0.2.2" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/python/src/lib.rs b/python/src/lib.rs index 3939b34..f74be92 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -80,6 +80,12 @@ impl Rwkv { Ok(self.inner.rwkv().forward_raw(tokens.last().unwrap(), &mut state.inner)) } + pub fn forward_token_preproc(&self, token: usize, state: &mut State) -> PyResult<()> { + let x = self.inner.rwkv().emb.get(token).ok_or(anyhow::anyhow!(RwkvError::InvalidToken(token)))?; + self.inner.rwkv().forward_raw_preproc(x, &mut state.inner); + Ok(()) + } + pub fn forward_token(&self, token: usize, state: &mut State) -> PyResult> { let x = self.inner.rwkv().emb.get(token).ok_or(anyhow::anyhow!(RwkvError::InvalidToken(token)))?; Ok(self.inner.rwkv().forward_raw(x, &mut state.inner))