Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

implement bincode for tensor #209

Merged
merged 2 commits into from
Jan 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ kornia = { path = "crates/kornia", version = "0.1.8-rc.1" }
# dev dependencies for workspace
argh = "0.1"
approx = "0.5"
bincode = { version = "2.0.0-rc.3", features = ["serde"] }
criterion = "0.5"
env_logger = "0.11"
faer = "0.20.1"
Expand Down
10 changes: 9 additions & 1 deletion crates/kornia-tensor/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,13 @@ version.workspace = true

[dependencies]
num-traits = { workspace = true }
serde = { workspace = true }
serde = { workspace = true, optional = true }
bincode = { workspace = true, optional = true }
thiserror = { workspace = true }

[features]
serde = ["dep:serde"]
bincode = ["dep:bincode"]

[dev-dependencies]
serde_json = "1"
59 changes: 59 additions & 0 deletions crates/kornia-tensor/src/bincode.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
use crate::{
allocator::{CpuAllocator, TensorAllocator},
storage::TensorStorage,
Tensor,
};

impl<T, const N: usize, A: TensorAllocator + 'static> bincode::enc::Encode for Tensor<T, N, A>
where
T: bincode::enc::Encode + 'static,
{
fn encode<E: bincode::enc::Encoder>(
&self,
encoder: &mut E,
) -> Result<(), bincode::error::EncodeError> {
bincode::Encode::encode(&self.shape, encoder)?;
bincode::Encode::encode(&self.strides, encoder)?;
bincode::Encode::encode(&self.storage.as_slice(), encoder)?;
Ok(())
}
}

impl<T, const N: usize> bincode::de::Decode for Tensor<T, N, CpuAllocator>
where
T: bincode::de::Decode + 'static,
{
fn decode<D: bincode::de::Decoder>(
decoder: &mut D,
) -> Result<Self, bincode::error::DecodeError> {
let shape = bincode::Decode::decode(decoder)?;
let strides = bincode::Decode::decode(decoder)?;
let data = bincode::Decode::decode(decoder)?;
Ok(Self {
shape,
strides,
storage: TensorStorage::from_vec(data, CpuAllocator),
})
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_bincode() -> Result<(), Box<dyn std::error::Error>> {
let tensor = Tensor::<u8, 2, CpuAllocator>::from_shape_vec(
[2, 3],
vec![1, 2, 3, 4, 5, 6],
CpuAllocator,
)?;
let mut serialized = vec![0u8; 100];
let config = bincode::config::standard();
let length = bincode::encode_into_slice(&tensor, &mut serialized, config)?;
let deserialized: (Tensor<u8, 2, CpuAllocator>, usize) =
bincode::decode_from_slice(&serialized[..length], config)?;
assert_eq!(tensor.as_slice(), deserialized.0.as_slice());
Ok(())
}
}
5 changes: 5 additions & 0 deletions crates/kornia-tensor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@
/// allocator module containing the memory management utilities.
pub mod allocator;

/// bincode module containing the serialization and deserialization utilities.
#[cfg(feature = "bincode")]
pub mod bincode;

/// tensor module containing the tensor and storage implementations.
pub mod tensor;

/// serde module containing the serialization and deserialization utilities.
#[cfg(feature = "serde")]
pub mod serde;

/// storage module containing the storage implementations.
Expand Down
16 changes: 16 additions & 0 deletions crates/kornia-tensor/src/serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,19 @@ where
})
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::allocator::CpuAllocator;

#[test]
fn test_serde() -> Result<(), Box<dyn std::error::Error>> {
let data = vec![1, 2, 3, 4, 5, 6];
let tensor = Tensor::<u8, 2, CpuAllocator>::from_shape_vec([2, 3], data, CpuAllocator)?;
let serialized = serde_json::to_string(&tensor)?;
let deserialized: Tensor<u8, 2, CpuAllocator> = serde_json::from_str(&serialized)?;
assert_eq!(tensor.as_slice(), deserialized.as_slice());
Ok(())
}
}
5 changes: 5 additions & 0 deletions justfile
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ clean:
test name='':
@cargo test {{ name }}

# Test the code with all features
test-all:
@cargo test --all-features


# ------------------------------------------------------------------------------
# Recipes for the kornia-py project
# ------------------------------------------------------------------------------
Expand Down
Loading