Skip to content

Commit

Permalink
errors
Browse files Browse the repository at this point in the history
  • Loading branch information
mrsteyk committed Feb 18, 2023
1 parent 89afcf4 commit bfa05f6
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 5 deletions.
3 changes: 2 additions & 1 deletion python/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion python/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "rwkv_rs"
version = "0.2.0"
version = "0.2.1"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
Expand All @@ -9,5 +9,6 @@ name = "rwkv_rs"
crate-type = ["cdylib"]

[dependencies]
anyhow = "1.0"
pyo3 = { version = "0.18.0", features = ["extension-module", "anyhow"] }
rwkvk-rs = { path = "../" }
32 changes: 29 additions & 3 deletions python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,25 @@ use std::{pin::Pin, sync::Arc};

use pyo3::{prelude::*, types::PyList};

#[derive(Debug)]
enum RwkvError {
InvalidToken(usize),
EmptyArray,
}

impl std::fmt::Display for RwkvError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
RwkvError::InvalidToken(n) => f.write_fmt(format_args!("Invalid Token: {:?}", n)),
RwkvError::EmptyArray => f.write_str("Empty array"),
}
}
}

impl std::error::Error for RwkvError {
// ???
}

#[pyclass]
struct Rwkv {
inner: Arc<Pin<Box<rwkvk_rs::RwkvWrap<'static>>>>,
Expand Down Expand Up @@ -52,24 +71,31 @@ impl Rwkv {

pub fn forward_raw_batch(&self, tokens: &PyList, state: &mut State) -> PyResult<Vec<f32>> {
let tokens: Vec<Vec<f32>> = tokens.extract()?;
if tokens.len() == 0 {
return Err(anyhow::anyhow!(RwkvError::EmptyArray).into());
}
for token in tokens.iter().take(tokens.len()-1) {
self.inner.rwkv().forward_raw_preproc(token, &mut state.inner);
}
Ok(self.inner.rwkv().forward_raw(tokens.last().unwrap(), &mut state.inner))
}

pub fn forward_token(&self, token: usize, state: &mut State) -> PyResult<Vec<f32>> {
let x = self.inner.rwkv().emb.get(token).unwrap();
let x = self.inner.rwkv().emb.get(token).ok_or(anyhow::anyhow!(RwkvError::InvalidToken(token)))?;
Ok(self.inner.rwkv().forward_raw(x, &mut state.inner))
}

pub fn forward(&self, tokens: &PyList, state: &mut State) -> PyResult<Vec<f32>> {
let tokens: Vec<usize> = tokens.extract()?;
if tokens.len() == 0 {
return Err(anyhow::anyhow!(RwkvError::EmptyArray).into());
}
for &token in tokens.iter().take(tokens.len()-1) {
let token = self.inner.rwkv().emb.get(token).unwrap();
let token = self.inner.rwkv().emb.get(token).ok_or(anyhow::anyhow!(RwkvError::InvalidToken(token)))?;
self.inner.rwkv().forward_raw_preproc(token, &mut state.inner);
}
let token = self.inner.rwkv().emb.get(*tokens.last().unwrap()).unwrap();
let token = *tokens.last().unwrap();
let token = self.inner.rwkv().emb.get(token).ok_or(anyhow::anyhow!(RwkvError::InvalidToken(token)))?;
Ok(self.inner.rwkv().forward_raw(token, &mut state.inner))
}
}
Expand Down

0 comments on commit bfa05f6

Please # to comment.