From fbff46074bdcd605a60becb87fcb2acca7db2f53 Mon Sep 17 00:00:00 2001 From: ivila <390810839@qq.com> Date: Fri, 21 Feb 2025 17:52:19 +0800 Subject: [PATCH] examples: add mnist-no-std --- Cargo.lock | 131 ++++++++++++++++---- examples/mnist-no-std/.gitignore | 1 + examples/mnist-no-std/Cargo.toml | 18 +++ examples/mnist-no-std/README.md | 86 +++++++++++++ examples/mnist-no-std/examples/infer.rs | 62 ++++++++++ examples/mnist-no-std/examples/train.rs | 89 ++++++++++++++ examples/mnist-no-std/samples/0.bin | Bin 0 -> 784 bytes examples/mnist-no-std/samples/1.bin | Bin 0 -> 784 bytes examples/mnist-no-std/samples/2.bin | Bin 0 -> 784 bytes examples/mnist-no-std/samples/3.bin | Bin 0 -> 784 bytes examples/mnist-no-std/samples/4.bin | Bin 0 -> 784 bytes examples/mnist-no-std/samples/5.bin | Bin 0 -> 784 bytes examples/mnist-no-std/samples/6.bin | Bin 0 -> 784 bytes examples/mnist-no-std/samples/7.bin | Bin 0 -> 784 bytes examples/mnist-no-std/samples/8.bin | Bin 0 -> 784 bytes examples/mnist-no-std/samples/9.bin | Bin 0 -> 784 bytes examples/mnist-no-std/src/inference.rs | 59 +++++++++ examples/mnist-no-std/src/lib.rs | 18 +++ examples/mnist-no-std/src/model/conv.rs | 48 ++++++++ examples/mnist-no-std/src/model/mlp.rs | 67 +++++++++++ examples/mnist-no-std/src/model/mod.rs | 67 +++++++++++ examples/mnist-no-std/src/train.rs | 154 ++++++++++++++++++++++++ examples/mnist-no-std/src/util.rs | 29 +++++ 23 files changed, 803 insertions(+), 26 deletions(-) create mode 100644 examples/mnist-no-std/.gitignore create mode 100644 examples/mnist-no-std/Cargo.toml create mode 100644 examples/mnist-no-std/README.md create mode 100644 examples/mnist-no-std/examples/infer.rs create mode 100644 examples/mnist-no-std/examples/train.rs create mode 100644 examples/mnist-no-std/samples/0.bin create mode 100644 examples/mnist-no-std/samples/1.bin create mode 100644 examples/mnist-no-std/samples/2.bin create mode 100644 examples/mnist-no-std/samples/3.bin create mode 100644 examples/mnist-no-std/samples/4.bin create mode 100644 examples/mnist-no-std/samples/5.bin create mode 100644 examples/mnist-no-std/samples/6.bin create mode 100644 examples/mnist-no-std/samples/7.bin create mode 100644 examples/mnist-no-std/samples/8.bin create mode 100644 examples/mnist-no-std/samples/9.bin create mode 100644 examples/mnist-no-std/src/inference.rs create mode 100644 examples/mnist-no-std/src/lib.rs create mode 100644 examples/mnist-no-std/src/model/conv.rs create mode 100644 examples/mnist-no-std/src/model/mlp.rs create mode 100644 examples/mnist-no-std/src/model/mod.rs create mode 100644 examples/mnist-no-std/src/train.rs create mode 100644 examples/mnist-no-std/src/util.rs diff --git a/Cargo.lock b/Cargo.lock index 11cd146ec4..6961a3a36a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1118,9 +1118,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.14" +version = "1.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c3d1b2e905a3a7b00a6141adb0e4c0bb941d11caf55349d863942a1cc44e3c9" +checksum = "c736e259eea577f443d5c86c304f9f4ae0295c43f3ba05c21f1d66b5f06001af" dependencies = [ "jobserver", "libc", @@ -1816,6 +1816,36 @@ dependencies = [ "libloading", ] +[[package]] +name = "curl" +version = "0.4.47" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9fb4d13a1be2b58f14d60adba57c9834b78c62fd86c3e76a148f732686e9265" +dependencies = [ + "curl-sys", + "libc", + "openssl-probe", + "openssl-sys", + "schannel", + "socket2", + "windows-sys 0.52.0", +] + +[[package]] 
+name = "curl-sys" +version = "0.4.80+curl-8.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55f7df2eac63200c3ab25bde3b2268ef2ee56af3d238e76d61f01c3c49bff734" +dependencies = [ + "cc", + "libc", + "libz-sys", + "openssl-sys", + "pkg-config", + "vcpkg", + "windows-sys 0.52.0", +] + [[package]] name = "custom-csv-dataset" version = "0.17.0" @@ -2151,9 +2181,9 @@ dependencies = [ [[package]] name = "either" -version = "1.13.0" +version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" +checksum = "b7914353092ddf589ad78f25c5c1c21b7f80b0ff8621e7c814c3485b5306da9d" [[package]] name = "embassy-futures" @@ -2356,9 +2386,9 @@ checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" [[package]] name = "flate2" -version = "1.0.35" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c936bfdafb507ebbf50b8074c54fa31c5be9a1e7e5f467dd659697041407d07c" +checksum = "11faaf5a5236997af9848be0bef4db95824b1d534ebc64d0f0c6cf3e67bd38dc" dependencies = [ "crc32fast", "miniz_oxide", @@ -3543,9 +3573,9 @@ checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" [[package]] name = "inout" -version = "0.1.3" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0c10553d664a4d0bcff9f4215d0aac67a639cc68ef660840afe309b807bc9f5" +checksum = "879f10e63c20629ecabbb64a8010319738c66a5cd0c29b02d63d272b03751d01" dependencies = [ "generic-array", ] @@ -3728,7 +3758,7 @@ checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" dependencies = [ "bitflags 2.8.0", "libc", - "redox_syscall 0.5.8", + "redox_syscall 0.5.9", ] [[package]] @@ -3742,6 +3772,18 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "libz-sys" +version = "1.1.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df9b68e50e6e0b26f672573834882eb57759f6db9b3be2ea3c35c91188bb4eaa" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "linux-raw-sys" version = "0.4.15" @@ -3778,9 +3820,9 @@ checksum = "9374ef4228402d4b7e403e5838cb880d9ee663314b0a900d5a6aabf0c213552e" [[package]] name = "log" -version = "0.4.25" +version = "0.4.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" +checksum = "30bde2b3dc3671ae49d8e2e9f044c7c005836e7a023ee57cffa25ab82764bb9e" [[package]] name = "loop9" @@ -3973,9 +4015,9 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.8.4" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3b1c9bd4fe1f0f8b387f6eb9eb3b4a1aa26185e5750efb9140301703f62cd1b" +checksum = "8e3e04debbb59698c15bacbb6d93584a8c0ca9cc3213cb423d31f760d8843ce5" dependencies = [ "adler2", "simd-adler32", @@ -3993,6 +4035,19 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "mnist" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06eb594c0749d3a505304db3b7a9c9891630d32daad059e81bfd0d51846f95c" +dependencies = [ + "byteorder", + "curl", + "flate2", + "log", + "pbr", +] + [[package]] name = "mnist" version = "0.17.0" @@ -4014,6 +4069,19 @@ dependencies = [ "wasm-bindgen-futures", ] +[[package]] +name = "mnist-no-std" +version = "0.17.0" +dependencies = [ 
+ "burn", + "bytemuck", + "clap", + "image", + "mnist 0.6.0", + "rand 0.9.0", + "spin", +] + [[package]] name = "model" version = "0.6.0" @@ -4698,7 +4766,7 @@ checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" dependencies = [ "cfg-if", "libc", - "redox_syscall 0.5.8", + "redox_syscall 0.5.9", "smallvec", "windows-targets 0.52.6", ] @@ -4751,6 +4819,17 @@ dependencies = [ "hmac", ] +[[package]] +name = "pbr" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed5827dfa0d69b6c92493d6c38e633bbaa5937c153d0d7c28bf12313f8c6d514" +dependencies = [ + "crossbeam-channel", + "libc", + "winapi", +] + [[package]] name = "percent-encoding" version = "2.3.1" @@ -5651,7 +5730,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94" dependencies = [ "rand_chacha 0.9.0", - "rand_core 0.9.1", + "rand_core 0.9.2", "zerocopy 0.8.20", ] @@ -5672,7 +5751,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" dependencies = [ "ppv-lite86", - "rand_core 0.9.1", + "rand_core 0.9.2", ] [[package]] @@ -5686,9 +5765,9 @@ dependencies = [ [[package]] name = "rand_core" -version = "0.9.1" +version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a88e0da7a2c97baa202165137c158d0a2e824ac465d13d81046727b34cb247d3" +checksum = "7a509b1a2ffbe92afab0e55c8fd99dea1c280e8171bd2d88682bb20bc41cbc2c" dependencies = [ "getrandom 0.3.1", "zerocopy 0.8.20", @@ -5890,9 +5969,9 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.5.8" +version = "0.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03a862b389f93e68874fbf580b9de08dd02facb9a788ebadaf4a3fd33cf58834" +checksum = "82b568323e98e49e2a0899dcee453dd679fae22d69adf9b11dd508d1549b7e2f" dependencies = [ "bitflags 2.8.0", ] @@ -6019,9 +6098,9 @@ dependencies = [ [[package]] name = "ring" -version = "0.17.9" +version = "0.17.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e75ec5e92c4d8aede845126adc388046234541629e76029599ed35a003c7ed24" +checksum = "da5349ae27d3887ca812fb375b45a4fbb36d8d12d2df394968cd86e35683fe73" dependencies = [ "cc", "cfg-if", @@ -6677,9 +6756,9 @@ checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" [[package]] name = "stacker" -version = "0.1.18" +version = "0.1.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d08feb8f695b465baed819b03c128dc23f57a694510ab1f06c77f763975685e" +checksum = "d9156ebd5870ef293bfb43f91c7a74528d363ec0d424afe24160ed5a4343d08a" dependencies = [ "cc", "cfg-if", @@ -6862,9 +6941,9 @@ dependencies = [ [[package]] name = "tar" -version = "0.4.43" +version = "0.4.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c65998313f8e17d0d553d28f91a0df93e4dbbbf770279c7bc21ca0f09ea1a1f6" +checksum = "1d863878d212c87a19c1a610eb53bb01fe12951c0501cf5a0d65f724914a667a" dependencies = [ "filetime", "libc", diff --git a/examples/mnist-no-std/.gitignore b/examples/mnist-no-std/.gitignore new file mode 100644 index 0000000000..b5d14f8f1d --- /dev/null +++ b/examples/mnist-no-std/.gitignore @@ -0,0 +1 @@ +model.bin diff --git a/examples/mnist-no-std/Cargo.toml b/examples/mnist-no-std/Cargo.toml new file mode 100644 index 0000000000..3cdbdc013f --- /dev/null +++ b/examples/mnist-no-std/Cargo.toml @@ 
-0,0 +1,18 @@
+[package]
+name = "mnist-no-std"
+edition.workspace = true
+license.workspace = true
+version.workspace = true
+publish = false
+
+[dependencies]
+burn = { path = "../../crates/burn", default-features = false, features = ["ndarray", "autodiff"] }
+spin = { workspace = true }
+bytemuck = { workspace = true, features = ["min_const_generics"] }
+
+[dev-dependencies]
+mnist = { version = "0.6.0", features = ["download"] }
+clap.workspace = true
+rand = { workspace = true, default-features = true }
+bytemuck = { workspace = true, features = ["min_const_generics"] }
+image = { workspace = true }
diff --git a/examples/mnist-no-std/README.md b/examples/mnist-no-std/README.md
new file mode 100644
index 0000000000..97344148ec
--- /dev/null
+++ b/examples/mnist-no-std/README.md
@@ -0,0 +1,86 @@
+# MNIST no-std
+
+This example demonstrates how to train and perform inference in a `no-std`
+environment.
+
+## Running
+
+There are two examples in this crate:
+
+1. Training
+
+   Trains a new model and exports it to the given path.
+
+   ```shell
+   cargo run --release --example train
+   ```
+
+   This example downloads the MNIST dataset, trains a new model, and outputs
+   the model to the given path (default: `model.bin`).
+
+   You can run `cargo run --release --example train -- --help` for detailed
+   usage.
+
+2. Inference
+
+   Loads a model from the given path, tests it with a given image, and prints
+   the inference result.
+
+   ```shell
+   # for example: cargo run --release --example infer -- --input=samples/8.bin
+   cargo run --release --example infer -- -i ${binary_path}
+   ```
+
+   This command loads the model from the given path (default: `model.bin`),
+   tests it with the given binary, and prints the inference result. For
+   convenience, you can use the sample binaries in the `samples` folder.
+
+   You can run `cargo run --release --example infer -- --help` for detailed
+   usage.
+
+## Design
+
+The crate is `no-std` and contains only logic related to training and inference.
+It provides APIs that accept only primitive types as parameters to ensure
+portability. It is the caller's responsibility to provide the data and control
+the workflow.
+
+The crate consists of 3 modules:
+
+1. proto
+
+   A module that contains the proto definitions shared between the crate and
+   its caller. It only includes primitive types to demonstrate portability.
+
+2. train
+
+   A module that contains a simple `Trainer` and a public module named
+   `no_std_world`, which simulates a `no-std` environment and can be called
+   externally.
+
+   It exports the following APIs:
+
+   * initialize: Initializes a global trainer with a given random seed and
+     learning rate.
+   * train: Trains the model with the given data and returns the loss and
+     accuracy as feedback.
+   * valid: Validates the model with the given data and returns the loss and
+     accuracy as feedback.
+   * export: Exports the model as bytes so it can be persisted.
+
+   You can refer to `examples/train.rs` for usage.
+
+3. inference
+
+   A module that contains a simple `Model` and a public module named
+   `no_std_world`, which simulates a `no-std` environment and can be called
+   externally.
+
+   It exports the following APIs:
+
+   * initialize: Initializes a global model with the provided record bytes.
+   * infer: Uses the global model to perform inference on the given image and
+     returns the result.
+
+   You can refer to `examples/infer.rs` for usage.
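For orientation, this is roughly how a host-side caller drives the two `no_std_world` modules end to end. This is a minimal sketch based only on the APIs listed above; the zero-filled dummy batch and the hard-coded batch size of 64 are illustrative, and error handling is omitted:

```rust
use mnist_no_std::{inference, proto::MNIST_IMAGE_SIZE, train};

fn main() {
    // The caller owns the workflow: initialize, then feed raw bytes in batches.
    train::no_std_world::initialize(42, 1e-4);

    // One dummy batch of 64 images (784 bytes each) and 64 labels.
    let images = vec![0u8; MNIST_IMAGE_SIZE * 64];
    let labels = vec![0u8; 64];
    let out = train::no_std_world::train(&images, &labels);
    println!("loss {:.3} | accuracy {:.3} %", out.loss, out.accuracy);

    // Export the record bytes, then hand them to the inference module.
    let record = train::no_std_world::export();
    inference::no_std_world::initialize(&record);
    let digit = inference::no_std_world::infer(&images[..MNIST_IMAGE_SIZE]);
    println!("predicted digit: {digit}");
}
```

In a real deployment the training and inference halves would typically run in separate environments, with the exported record bytes moved between them by whatever transport the platform provides.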
diff --git a/examples/mnist-no-std/examples/infer.rs b/examples/mnist-no-std/examples/infer.rs
new file mode 100644
index 0000000000..661339a8a3
--- /dev/null
+++ b/examples/mnist-no-std/examples/infer.rs
@@ -0,0 +1,62 @@
+use clap::{Parser, ValueEnum};
+use mnist_no_std::{inference::no_std_world, proto::*};
+
+#[derive(Clone, ValueEnum, Debug)]
+enum InputType {
+    /// The input file is a raw binary (must be 784 bytes)
+    Binary,
+    /// The input file is an image (must be 28x28)
+    Image,
+}
+
+/// Loads a model from the given path, tests it with a given binary, and prints
+/// the inference result.
+#[derive(Parser, Debug)]
+#[command(version, about)]
+struct Args {
+    /// The path of the model.
+    #[arg(short, long, default_value = "model.bin")]
+    model: String,
+    /// The type of the input file.
+    #[arg(short, long, value_enum, default_value_t = InputType::Binary)]
+    r#type: InputType,
+    /// The path of the input file.
+    #[arg(short, long)]
+    input: String,
+}
+
+fn read_image_as_binary(args: &Args) -> Vec<u8> {
+    let path = std::path::absolute(&args.input).unwrap();
+    match args.r#type {
+        InputType::Binary => {
+            println!("Load binary from \"{}\"", path.display());
+            std::fs::read(path).unwrap()
+        }
+        InputType::Image => {
+            println!("Load image from \"{}\"", path.display());
+            let img = image::open(&path)
+                .unwrap()
+                .resize_exact(
+                    MNIST_IMAGE_WIDTH as u32,
+                    MNIST_IMAGE_HEIGHT as u32,
+                    image::imageops::FilterType::Nearest,
+                )
+                .into_luma8();
+            img.to_vec()
+        }
+    }
+}
+
+fn main() {
+    let args = Args::parse();
+
+    let model_path = std::path::absolute(&args.model).unwrap();
+    println!("Load model from \"{}\"", model_path.display());
+    let record = std::fs::read(&model_path).unwrap();
+    no_std_world::initialize(&record);
+
+    let binary = read_image_as_binary(&args);
+    assert_eq!(binary.len(), MNIST_IMAGE_SIZE);
+    let result = no_std_world::infer(&binary);
+    println!("Inference result is: {}", result);
+}
diff --git a/examples/mnist-no-std/examples/train.rs b/examples/mnist-no-std/examples/train.rs
new file mode 100644
index 0000000000..ba3dc816a7
--- /dev/null
+++ b/examples/mnist-no-std/examples/train.rs
@@ -0,0 +1,89 @@
+use clap::Parser;
+use mnist_no_std::{proto::*, train::no_std_world};
+use rand::{seq::SliceRandom, Rng};
+
+/// Trains a new model and exports it to the given path.
+#[derive(Parser, Debug)]
+#[command(version, about)]
+struct Args {
+    #[arg(short, long, default_value_t = 6)]
+    num_epochs: usize,
+    #[arg(short, long, default_value_t = 64)]
+    batch_size: usize,
+    #[arg(short, long, default_value_t = 0.0001)]
+    learning_rate: f64,
+    #[arg(short, long, default_value = "model.bin")]
+    output: String,
+}
+
+fn convert_datasets(images: &[u8], labels: &[u8]) -> Vec<(MnistImage, u8)> {
+    let mut datasets: Vec<(MnistImage, u8)> = images
+        .chunks_exact(MNIST_IMAGE_SIZE)
+        .map(|v| v.try_into().unwrap())
+        .zip(labels.iter().copied())
+        .collect();
+    datasets.shuffle(&mut rand::rng());
+    datasets
+}
+
+fn main() {
+    let args = Args::parse();
+    // Download the MNIST data, keeping the same URL as burn.
+ // Originally copy from burn/crates/burn-dataset/src/vision/mnist.rs + const BASE_URL: &str = "https://storage.googleapis.com/cvdf-datasets/mnist/"; + let data = mnist::MnistBuilder::new() + .base_url(BASE_URL) + .base_path( + std::env::temp_dir() + .join("example_mnist_no_std/") + .as_path() + .to_str() + .unwrap(), + ) + .download_and_extract() + .training_set_length(60_000) + .validation_set_length(10_000) + .test_set_length(0) + .finalize(); + // Initialize trainer + let seed: u64 = rand::rng().random(); + no_std_world::initialize(seed, args.learning_rate); + // Prepare datasets + let train_datasets = convert_datasets(&data.trn_img, &data.trn_lbl); + let valid_datasets = convert_datasets(&data.val_img, &data.val_lbl); + // Training loop, Originally inspired by burn/crates/custom-training-loop + // + // Normally there is no println in no_std, the caller must invoke functions + // step by step and receive feedback from no_std_world. For example, in + // TrustZone, there are two systems running on the same machine: one is + // Linux, and the other is TEEOS (the no_std world, bare metal env). The + // caller from Linux invokes functions via SMC (Secure Monitor Call) + // repeatedly, receives output through shared memory, and prints it to the + // screen. + for epoch in 1..args.num_epochs + 1 { + for (iteration, data) in train_datasets.chunks(args.batch_size).enumerate() { + let images: Vec = data.iter().map(|v| v.0).collect(); + let labels: Vec = data.iter().map(|v| v.1).collect(); + let output = no_std_world::train(bytemuck::cast_slice(images.as_slice()), &labels); + println!( + "[Train - Epoch {} - Iteration {}] Loss {:.3} | Accuracy {:.3} %", + epoch, iteration, output.loss, output.accuracy, + ); + } + + for (iteration, data) in valid_datasets.chunks(args.batch_size).enumerate() { + let images: Vec = data.iter().map(|v| v.0).collect(); + let labels: Vec = data.iter().map(|v| v.1).collect(); + let output = no_std_world::valid(bytemuck::cast_slice(images.as_slice()), &labels); + println!( + "[Valid - Epoch {} - Iteration {}] Loss {:.3} | Accuracy {:.3} %", + epoch, iteration, output.loss, output.accuracy, + ); + } + } + // Export the model to the given path + let record = no_std_world::export(); + let output_path = std::path::absolute(&args.output).unwrap(); + println!("Export record to \"{}\"", output_path.display()); + std::fs::write(&output_path, &record).unwrap(); +} diff --git a/examples/mnist-no-std/samples/0.bin b/examples/mnist-no-std/samples/0.bin new file mode 100644 index 0000000000000000000000000000000000000000..6d4d6d9787efc0567fe877b13b2b92be2875a95c GIT binary patch literal 784 zcmZQz7&yQ^?e8gh9M-G;{{1%_yPW7zoN{^legoxJ=GtM=`4=el8%W$S#*p0h|1S`{ zx(6iCnI^3Q&ME@VZP1v!6@2v{T+s7e`4?xY+< zK>h@}9Kr<(r~dz4f~>FS_iq>%!_1p7IZ=C{tDZrPhbV{3O#u4jt}Vru7SoCg3mt&o=h literal 0 HcmV?d00001 diff --git a/examples/mnist-no-std/samples/1.bin b/examples/mnist-no-std/samples/1.bin new file mode 100644 index 0000000000000000000000000000000000000000..e4727fea3730b4a23cf88bac56dc6d0d1a877df9 GIT binary patch literal 784 zcmZQz7+9e8FBhi+!v6%}lxzQpQ%d*CG@NoF|Lk$fox0DBOHSe64O~(T@&69vl#;*w%m{Vx!goWzUkxTF{={>{fJH~n89PPr%loN&q6{>ClW`|m6> W4t-qL|5f6UVqiVCfeV)$eZ&EutR|2E literal 0 HcmV?d00001 diff --git a/examples/mnist-no-std/samples/2.bin b/examples/mnist-no-std/samples/2.bin new file mode 100644 index 0000000000000000000000000000000000000000..e9820a4466bd8d943b96a5afa7a811cb84d91eb1 GIT binary patch literal 784 zcmZQzpd~1&UH$+6v{-Cbto#cGx8yJ-K~h)yHvIitg&}A1>+dB69tO5^e|s?Ge1Bh( 
z1PWID`0I!v#~?2R61(^pn;ckm>G!|KxG?pBM0~#geQu9kilO1}-w8ORHhlX#p9i~~ zi{VlJ0G0du_e&(YYrPk4hRa?AI}KSovw(oQnt;H>&wu~!L=8X@ zE0{W~i+}%ahjBo(@7?eIFh1|?zaKnc93UO_{qG-37~k~w-~BK)h%Wj2*BZu8{_}St zj18js_WZp&PYuM%zx?elsvHB)-EV(?{{#Y%S563nR3U-*?N1vN^0n}A)vKY>*d=3?wvI1R>fc=(`A@By?^(z!q{MX<)17lU*W|sClo3Ehks!Fwm)zaAVOk)mcZmz z{#g!VLTJu27lojz{{GEIm19`>=XfX(G%h*(2W~S=iRP-We}Le{^WT5|aKZ%Pv|$Jk zFfc6q19uf%3dX4a^G6rPMyDJ8{zVsp^MU%{Y;;EZpD*Y_aQ=%o`EWKmV>4=CqDxcE F2LRI1OC|sS literal 0 HcmV?d00001 diff --git a/examples/mnist-no-std/samples/5.bin b/examples/mnist-no-std/samples/5.bin new file mode 100644 index 0000000000000000000000000000000000000000..86165766adbc4669b9b2b6893647b1a3b048886b GIT binary patch literal 784 zcmZQz7+@f%udgrA7^;uy4DXpY|Ni}Z{oCKaXmZMZU;Y4PAOKCy9x8lw6%ZiFiL`hy zSiXDwZbgl&2se^AkjZ`SuO9=00)reAsyL9xw()NZ7ltTUfT#8EOMdK93^9N2%ixe= znD%!aE-8jrf4|im;*k9h^w{so(UqZETA}EMcSAi5qF$^(HGtXl*1g{_`9C$Nl0crb z#|fzHdJH)RhKj#GXBaQ}`xl#>5m4VAU_?y8qK|Vq$k1Qw(SjPHQhdwvzjqt3=mHDH LP8Gu?MF(*JxnOt? literal 0 HcmV?d00001 diff --git a/examples/mnist-no-std/samples/6.bin b/examples/mnist-no-std/samples/6.bin new file mode 100644 index 0000000000000000000000000000000000000000..9b747bfc08c950f74b7fa81029e3d40e9d030e63 GIT binary patch literal 784 zcmZQzpe0~G^LGOd^Vz5U`BRHSuJq3?3mj4my?@GZNiod&ZiGut>hBR;QVi4Ja+aZ? zYN*N{{!IW2On>zG&!2ZF4(59L2q?v3_Vn-nSL+}Ce1iz2rZD5v)J?{qe>S-qzqG{oRze~TC<|6L5>Fw8m%Wq_%p zKz(ce?12bL{auYNC;a?-l`u%~#1~|zNZkXuMe_8YgS-qrCw`)Mk74KE&0IhoKFYGQ z{(eD~VvxD`_ehYXW%JQLf3F21TLxlF-uUz9&tD*(1GW-XS|GjqLr-_N2CC?w;Q;`3 Cylhwi literal 0 HcmV?d00001 diff --git a/examples/mnist-no-std/samples/7.bin b/examples/mnist-no-std/samples/7.bin new file mode 100644 index 0000000000000000000000000000000000000000..6e67157fcf0d51926673c7ea31a4b5050fd5d0e0 GIT binary patch literal 784 zcmZQz7%m`W=ltn5D%gEUUKGM1C-(AxFb*k(#((d0aL8Hz{AY(lilO!2ZdM#}oM(Sn;gDje z`?n33l=trsmN?{u@BUkcLyGzIzuU?<iz?{Q(M$n3wD(d7i80$v}{cD z3#wZfgpd6@&#C0``tRQd-ROQ`5WM{N`iFmiUp2t&MKU)HXwTm#z8HEK+*Sdt0fKTQ z#bBmX(=%Xr>@0S;{P%|_s^qd4fB$}8q{qu1{q^rZ3skuVpgosj>*t>V*~gD6$I$li f-+zd8Pn6K47#O6Bd(NFZcm7;8x-JZrgHZqgPlA0? 
literal 0
HcmV?d00001

diff --git a/examples/mnist-no-std/samples/9.bin b/examples/mnist-no-std/samples/9.bin
new file mode 100644
index 0000000000000000000000000000000000000000..bd51b5abf86bc4ee26d957b4770a969224c2437b
GIT binary patch
literal 784
zcmZQz7;=Db>OochHgH3@L`NzZWz@m;HYG?~N9Q
zq|iH%J9nzw_&Wnbj=|-_|9=xX7~1~bSH>bauMg-9&h>v6V#+~9BK|(Y^aWIoY1!X8
voEiSTz=5geyYB~Ml0yRkEs|14

literal 0
HcmV?d00001

diff --git a/examples/mnist-no-std/src/inference.rs b/examples/mnist-no-std/src/inference.rs
new file mode 100644
index 0000000000..486370d333
--- /dev/null
+++ b/examples/mnist-no-std/src/inference.rs
@@ -0,0 +1,59 @@
+use crate::{
+    model::{mlp::MlpConfig, MnistConfig},
+    proto::*,
+    util,
+};
+use alloc::vec::Vec;
+use burn::{
+    prelude::*,
+    record::{FullPrecisionSettings, Recorder, RecorderError},
+    tensor::cast::ToElement,
+};
+
+struct Model<B: Backend>(crate::model::Model<B>);
+
+impl<B: Backend> Model<B> {
+    fn new(device: &B::Device, data: Vec<u8>) -> Result<Self, RecorderError> {
+        let mlp_config = MlpConfig::new();
+        let mnist_config = MnistConfig::new(mlp_config);
+        let model = crate::model::Model::new(&mnist_config, device);
+
+        let recorder = burn::record::BinBytesRecorder::<FullPrecisionSettings>::new();
+        let record = recorder.load(data, device)?;
+        Ok(Self(model.load_record(record)))
+    }
+    fn infer(&self, device: &B::Device, img: &MnistImage) -> u8 {
+        let tensor = util::image_to_tensor(device, img);
+        let output = self.0.forward(tensor);
+        let output = burn::tensor::activation::softmax(output, 1);
+        output.argmax(1).into_scalar().to_u8()
+    }
+}
+
+pub mod no_std_world {
+    use super::Model;
+    use burn::backend::{ndarray::NdArrayDevice, NdArray};
+    use spin::Mutex;
+
+    type NoStdModel = Model<NdArray>;
+
+    const DEVICE: NdArrayDevice = NdArrayDevice::Cpu;
+    static MODEL: Mutex<Option<NoStdModel>> = Mutex::new(Option::None);
+
+    pub fn initialize(record: &[u8]) {
+        let mut model = MODEL.lock();
+        assert!(model.is_none(), "Model has been initialized");
+
+        model.replace(NoStdModel::new(&DEVICE, record.to_vec()).unwrap());
+    }
+
+    pub fn infer(image: &[u8]) -> u8 {
+        let model = MODEL.lock();
+        assert!(!model.is_none());
+
+        model
+            .as_ref()
+            .expect("Model has not been initialized")
+            .infer(&DEVICE, image.try_into().unwrap())
+    }
+}
diff --git a/examples/mnist-no-std/src/lib.rs b/examples/mnist-no-std/src/lib.rs
new file mode 100644
index 0000000000..a3d8bbd85d
--- /dev/null
+++ b/examples/mnist-no-std/src/lib.rs
@@ -0,0 +1,18 @@
+#![no_std]
+extern crate alloc;
+mod util;
+
+pub mod proto {
+    pub const MNIST_IMAGE_HEIGHT: usize = 28;
+    pub const MNIST_IMAGE_WIDTH: usize = 28;
+    pub const MNIST_IMAGE_SIZE: usize = MNIST_IMAGE_WIDTH * MNIST_IMAGE_HEIGHT;
+    pub type MnistImage = [u8; MNIST_IMAGE_SIZE];
+
+    pub struct Output {
+        pub loss: f32,
+        pub accuracy: f32,
+    }
+}
+pub mod inference;
+mod model;
+pub mod train;
diff --git a/examples/mnist-no-std/src/model/conv.rs b/examples/mnist-no-std/src/model/conv.rs
new file mode 100644
index 0000000000..1909e0be6e
--- /dev/null
+++ b/examples/mnist-no-std/src/model/conv.rs
@@ -0,0 +1,48 @@
+// Originally copied from the burn/crates/burn-no-std-tests package
+
+use burn::{
+    config::Config,
+    module::Module,
+    nn,
+    tensor::{backend::Backend, Tensor},
+};
+
+#[derive(Module, Debug)]
+pub struct ConvBlock<B: Backend> {
+    conv: nn::conv::Conv2d<B>,
+    pool: nn::pool::MaxPool2d,
+    activation: nn::Gelu,
+}
+
+#[derive(Config)]
+pub struct ConvBlockConfig {
+    channels: [usize; 2],
+    #[config(default = "[3, 3]")]
+    kernel_size: [usize; 2],
+}
+
+impl<B: Backend> ConvBlock<B> {
+    pub fn new(config: &ConvBlockConfig, device: &B::Device) -> Self {
+        let conv = nn::conv::Conv2dConfig::new(config.channels, config.kernel_size)
+            .with_padding(nn::PaddingConfig2d::Same)
+            .init(device);
+        let pool = nn::pool::MaxPool2dConfig::new(config.kernel_size)
+            .with_padding(nn::PaddingConfig2d::Same)
+            .init();
+        let activation = nn::Gelu::new();
+
+        Self {
+            conv,
+            pool,
+            activation,
+        }
+    }
+
+    pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
+        let x = self.conv.forward(input.clone());
+        let x = self.pool.forward(x);
+        let x = self.activation.forward(x);
+
+        (x + input) / 2.0
+    }
+}
diff --git a/examples/mnist-no-std/src/model/mlp.rs b/examples/mnist-no-std/src/model/mlp.rs
new file mode 100644
index 0000000000..8574d33733
--- /dev/null
+++ b/examples/mnist-no-std/src/model/mlp.rs
@@ -0,0 +1,67 @@
+// Originally copied from burn/crates/burn-no-std-tests package
+
+use alloc::vec::Vec;
+
+use burn::{
+    config::Config,
+    module::Module,
+    nn,
+    tensor::{backend::Backend, Tensor},
+};
+
+/// Configuration to create a [Multilayer Perceptron](Mlp) layer.
+#[derive(Config)]
+pub struct MlpConfig {
+    /// The number of layers.
+    #[config(default = 3)]
+    pub num_layers: usize,
+    /// The dropout rate.
+    #[config(default = 0.5)]
+    pub dropout: f64,
+    /// The size of each layer.
+    #[config(default = 256)]
+    pub d_model: usize,
+}
+
+/// Multilayer Perceptron module.
+#[derive(Module, Debug)]
+pub struct Mlp<B: Backend> {
+    linears: Vec<nn::Linear<B>>,
+    dropout: nn::Dropout,
+    activation: nn::Relu,
+}
+
+impl<B: Backend> Mlp<B> {
+    /// Create the module from the given configuration.
+    pub fn new(config: &MlpConfig, device: &B::Device) -> Self {
+        let mut linears = Vec::with_capacity(config.num_layers);
+
+        for _ in 0..config.num_layers {
+            linears.push(nn::LinearConfig::new(config.d_model, config.d_model).init(device));
+        }
+
+        Self {
+            linears,
+            dropout: nn::DropoutConfig::new(0.3).init(),
+            activation: nn::Relu::new(),
+        }
+    }
+
+    /// Applies the forward pass on the input tensor.
+    ///
+    /// # Shapes
+    ///
+    ///   - input: `[batch_size, d_model]`
+    ///   - output: `[batch_size, d_model]`
+    pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
+        let mut x = input;
+
+        for linear in self.linears.iter() {
+            x = linear.forward(x);
+            x = self.dropout.forward(x);
+            x = self.activation.forward(x);
+        }
+
+        x
+    }
+}
diff --git a/examples/mnist-no-std/src/model/mod.rs b/examples/mnist-no-std/src/model/mod.rs
new file mode 100644
index 0000000000..e4a0abad03
--- /dev/null
+++ b/examples/mnist-no-std/src/model/mod.rs
@@ -0,0 +1,67 @@
+// Originally copied from burn/crates/burn-no-std-tests package
+
+pub mod conv;
+pub mod mlp;
+
+use conv::{ConvBlock, ConvBlockConfig};
+use mlp::{Mlp, MlpConfig};
+
+use burn::{
+    config::Config,
+    module::Module,
+    nn,
+    tensor::{backend::Backend, Tensor},
+};
+
+#[derive(Config)]
+pub struct MnistConfig {
+    #[config(default = 42)]
+    pub seed: u64,
+
+    pub mlp: MlpConfig,
+
+    #[config(default = 784)]
+    pub input_size: usize,
+
+    #[config(default = 10)]
+    pub output_size: usize,
+}
+
+#[derive(Module, Debug)]
+pub struct Model<B: Backend> {
+    mlp: Mlp<B>,
+    conv: ConvBlock<B>,
+    input: nn::Linear<B>,
+    output: nn::Linear<B>,
+    num_classes: usize,
+}
+
+impl<B: Backend> Model<B> {
+    pub fn new(config: &MnistConfig, device: &B::Device) -> Self {
+        let mlp = Mlp::new(&config.mlp, device);
+        let input = nn::LinearConfig::new(config.input_size, config.mlp.d_model).init(device);
+        let output = nn::LinearConfig::new(config.mlp.d_model, config.output_size).init(device);
+        let conv = ConvBlock::new(&ConvBlockConfig::new([1, 1]), device);
+
+        Self {
+            mlp,
+            conv,
+            output,
+            input,
+            num_classes: config.output_size,
+        }
+    }
+
+    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 2> {
+        let [batch_size, height, width] = input.dims();
+
+        let x = input.reshape([batch_size, 1, height, width]).detach();
+        let x = self.conv.forward(x);
+        let x = x.reshape([batch_size, height * width]);
+
+        let x = self.input.forward(x);
+        let x = self.mlp.forward(x);
+
+        self.output.forward(x)
+    }
+}
diff --git a/examples/mnist-no-std/src/train.rs b/examples/mnist-no-std/src/train.rs
new file mode 100644
index 0000000000..9c499a9e06
--- /dev/null
+++ b/examples/mnist-no-std/src/train.rs
@@ -0,0 +1,154 @@
+use crate::{
+    model::{mlp::MlpConfig, MnistConfig, Model},
+    proto::*,
+    util::{images_to_tensors, labels_to_tensors},
+};
+use alloc::vec::Vec;
+use burn::{
+    module::AutodiffModule,
+    nn::loss::CrossEntropyLoss,
+    optim::{adaptor::OptimizerAdaptor, Adam, AdamConfig, GradientsParams, Optimizer},
+    prelude::*,
+    record::{FullPrecisionSettings, Recorder, RecorderError},
+    tensor::{backend::AutodiffBackend, cast::ToElement},
+};
+
+struct Trainer<B: AutodiffBackend> {
+    model: Model<B>,
+    device: B::Device,
+    optim: OptimizerAdaptor<Adam, Model<B>, B>,
+    lr: f64,
+}
+
+impl<B: AutodiffBackend> Trainer<B> {
+    fn new(device: B::Device, seed: u64) -> Self {
+        let config_optimizer = AdamConfig::new();
+        let model_config = MnistConfig::new(MlpConfig::new()).with_seed(seed);
+
+        B::seed(model_config.seed);
+
+        Self {
+            optim: config_optimizer.init(),
+            model: Model::new(&model_config, &device),
+            device,
+            lr: 1e-4,
+        }
+    }
+
+    fn with_learning_rate(mut self, lr: f64) -> Self {
+        self.lr = lr;
+        self
+    }
+
+    // Originally inspired by burn/examples/custom-training-loop
+    fn train(&mut self, images: &[MnistImage], labels: &[u8]) -> Output {
+        let images = images_to_tensors(&self.device, images);
+        let targets = labels_to_tensors(&self.device, labels);
+        let model = self.model.clone();
+
+        let output = model.forward(images);
+        let loss =
+            CrossEntropyLoss::new(None, &output.device()).forward(output.clone(), targets.clone());
+        let accuracy = accuracy(output, targets);
+
+        // Gradients for the current backward pass
+        let grads = loss.backward();
+        // Gradients linked to each parameter of the model.
+        let grads = GradientsParams::from_grads(grads, &model);
+        // Update the model using the optimizer.
+        self.model = self.optim.step(self.lr, model, grads);
+
+        Output {
+            loss: loss.into_scalar().to_f32(),
+            accuracy,
+        }
+    }
+
+    // Originally inspired by burn/examples/custom-training-loop
+    fn valid(&self, images: &[MnistImage], labels: &[u8]) -> Output {
+        // Get the model without autodiff.
+        let model_valid = self.model.valid();
+
+        let images = images_to_tensors(&self.device, images);
+        let targets = labels_to_tensors(&self.device, labels);
+
+        let output = model_valid.forward(images);
+        let loss =
+            CrossEntropyLoss::new(None, &output.device()).forward(output.clone(), targets.clone());
+        let accuracy = accuracy(output, targets);
+
+        Output {
+            loss: loss.into_scalar().to_f32(),
+            accuracy,
+        }
+    }
+
+    fn export(&self) -> Result<Vec<u8>, RecorderError> {
+        let recorder = burn::record::BinBytesRecorder::<FullPrecisionSettings>::new();
+        recorder.record(self.model.clone().into_record(), ())
+    }
+}
+
+// Originally copied from burn/examples/custom-training-loop
+/// Creates our own accuracy metric calculation.
+fn accuracy<B: Backend>(output: Tensor<B, 2>, targets: Tensor<B, 1, Int>) -> f32 {
+    let predictions = output.argmax(1).squeeze(1);
+    let num_predictions: usize = targets.dims().iter().product();
+    let num_corrects = predictions.equal(targets).int().sum().into_scalar();
+
+    num_corrects.elem::<f32>() / num_predictions as f32 * 100.0
+}
+
+pub mod no_std_world {
+    use super::Trainer;
+    use crate::proto::*;
+    use alloc::vec::Vec;
+    use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
+    use spin::Mutex;
+
+    type NoStdTrainer = Trainer<Autodiff<NdArray>>;
+
+    const DEVICE: NdArrayDevice = NdArrayDevice::Cpu;
+    static TRAINER: Mutex<Option<NoStdTrainer>> = Mutex::new(Option::None);
+
+    pub fn initialize(seed: u64, lr: f64) {
+        let mut trainer = TRAINER.lock();
+        assert!(trainer.is_none(), "Trainer has been initialized");
+
+        trainer.replace(NoStdTrainer::new(DEVICE, seed).with_learning_rate(lr));
+    }
+
+    pub fn train(images: &[u8], labels: &[u8]) -> Output {
+        assert!(images.len() % MNIST_IMAGE_SIZE == 0);
+        let images: &[MnistImage] = bytemuck::cast_slice(images);
+
+        let mut trainer = TRAINER.lock();
+
+        trainer
+            .as_mut()
+            .expect("Trainer has not been initialized")
+            .train(images, labels)
+    }
+
+    pub fn valid(images: &[u8], labels: &[u8]) -> Output {
+        assert!(images.len() % MNIST_IMAGE_SIZE == 0);
+        let images: &[MnistImage] = bytemuck::cast_slice(images);
+
+        let trainer = TRAINER.lock();
+
+        trainer
+            .as_ref()
+            .expect("Trainer has not been initialized")
+            .valid(images, labels)
+    }
+
+    pub fn export() -> Vec<u8> {
+        let trainer = TRAINER.lock();
+
+        trainer
+            .as_ref()
+            .expect("Trainer has not been initialized")
+            .export()
+            .unwrap()
+    }
+}
diff --git a/examples/mnist-no-std/src/util.rs b/examples/mnist-no-std/src/util.rs
new file mode 100644
index 0000000000..baf004faf6
--- /dev/null
+++ b/examples/mnist-no-std/src/util.rs
@@ -0,0 +1,29 @@
+use crate::proto::{MnistImage, MNIST_IMAGE_HEIGHT, MNIST_IMAGE_WIDTH};
+
+use burn::prelude::*;
+
+// Converts an image into a Tensor.
+// Originally copied from burn/examples/mnist-inference-web
+pub fn image_to_tensor<B: Backend>(device: &B::Device, image: &MnistImage) -> Tensor<B, 3> {
+    let tensor = TensorData::from(image.as_slice()).convert::<B::FloatElem>();
+    let tensor = Tensor::<B, 1>::from_data(tensor, device);
+    let tensor = tensor.reshape([1, MNIST_IMAGE_WIDTH, MNIST_IMAGE_HEIGHT]);
+
+    // Normalize the input: scale to [0, 1], then shift to mean=0 and std=1.
+    // The values mean=0.1307, std=0.3081 were copied from the PyTorch MNIST example:
+    // https://github.com/pytorch/examples/blob/54f4572509891883a947411fd7239237dd2a39c3/mnist/main.py#L122
+    ((tensor / 255) - 0.1307) / 0.3081
+}
+
+pub fn images_to_tensors<B: Backend>(device: &B::Device, images: &[MnistImage]) -> Tensor<B, 3> {
+    let tensors = images.iter().map(|v| image_to_tensor(device, v)).collect();
+    Tensor::cat(tensors, 0)
+}
+
+pub fn labels_to_tensors<B: Backend>(device: &B::Device, labels: &[u8]) -> Tensor<B, 1, Int> {
+    let targets = labels
+        .iter()
+        .map(|item| Tensor::<B, 1, Int>::from_data([(*item as i64).elem::<B::IntElem>()], device))
+        .collect();
+    Tensor::cat(targets, 0)
+}
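The `samples/*.bin` files added above are simply raw 28x28 grayscale buffers (784 bytes, one byte per pixel), which is exactly what the infer example's `--type binary` mode expects. A hypothetical helper for generating more of them, mirroring the resize and luma conversion already performed in `examples/infer.rs` (the function name and paths here are illustrative, not part of this patch):

```rust
// Convert an arbitrary image file into the raw 784-byte grayscale layout
// used by the files in `samples/`.
fn image_to_sample_bin(input: &str, output: &str) {
    let img = image::open(input)
        .unwrap()
        .resize_exact(28, 28, image::imageops::FilterType::Nearest)
        .into_luma8();
    assert_eq!(img.as_raw().len(), 784);
    std::fs::write(output, img.as_raw()).unwrap();
}
```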