py: update pyo3 and numpy to 0.21
djc committed Apr 3, 2024
1 parent 1744064 commit b3a7fd9
Showing 2 changed files with 13 additions and 10 deletions.
instant-clip-tokenizer-py/Cargo.toml (2 additions, 2 deletions)

@@ -15,5 +15,5 @@ crate-type = ["cdylib"]
 
 [dependencies]
 instant-clip-tokenizer = { version = "0.1", features = ["ndarray"], path = "../instant-clip-tokenizer" }
-numpy = "0.20"
-pyo3 = { version = "0.20.0", features = ["extension-module"] }
+numpy = "0.21"
+pyo3 = { version = "0.21", features = ["extension-module"] }
instant-clip-tokenizer-py/src/lib.rs (11 additions, 8 deletions)

@@ -4,10 +4,11 @@ use std::io::BufReader;
 use numpy::{IntoPyArray, PyArray2};
 use pyo3::exceptions::PyValueError;
 use pyo3::prelude::*;
+use pyo3::pybacked::PyBackedStr;
 
 #[pymodule]
 #[pyo3(name = "instant_clip_tokenizer")]
-fn instant_clip_tokenizer_py(_py: Python, m: &PyModule) -> PyResult<()> {
+fn instant_clip_tokenizer_py(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::<Tokenizer>()?;
     Ok(())
 }
@@ -78,16 +79,18 @@ impl Tokenizer {
         py: Python<'py>,
         input: TokenizeBatchInput,
         context_length: Option<usize>,
-    ) -> PyResult<&'py PyArray2<u16>> {
+    ) -> PyResult<Bound<'py, PyArray2<u16>>> {
         let context_length = context_length.unwrap_or(77);
         if context_length < 3 {
             return Err(PyValueError::new_err("context_length is less than 3"));
         }
         let result = match input {
-            TokenizeBatchInput::Single(text) => self.inner.tokenize_batch([text], context_length),
-            TokenizeBatchInput::Multiple(texts) => self.inner.tokenize_batch(texts, context_length),
+            TokenizeBatchInput::Single(text) => self.inner.tokenize_batch([&*text], context_length),
+            TokenizeBatchInput::Multiple(texts) => self
+                .inner
+                .tokenize_batch(texts.iter().map(|s| &**s), context_length),
         };
-        Ok(result.into_pyarray(py))
+        Ok(result.into_pyarray_bound(py))
     }
 
     /// Encode a `text` input as a sequence of tokens.
@@ -136,9 +139,9 @@ impl Tokenizer {
 }
 
 #[derive(FromPyObject)]
-enum TokenizeBatchInput<'a> {
+enum TokenizeBatchInput {
     #[pyo3(transparent, annotation = "str")]
-    Single(&'a str),
+    Single(PyBackedStr),
     #[pyo3(transparent, annotation = "list[str]")]
-    Multiple(Vec<&'a str>),
+    Multiple(Vec<PyBackedStr>),
 }
