From 0452053372fe375660ce42dbaa76608803725e1b Mon Sep 17 00:00:00 2001
From: nathaniel
Date: Mon, 5 Sep 2022 11:37:32 -0400
Subject: [PATCH 1/5] Add docs to the Tensor struct

---
 burn-tensor/src/tensor/tensor.rs | 236 +++++++++++++++++++++++++------
 1 file changed, 192 insertions(+), 44 deletions(-)

diff --git a/burn-tensor/src/tensor/tensor.rs b/burn-tensor/src/tensor/tensor.rs
index 9a01d007ae..24513681f8 100644
--- a/burn-tensor/src/tensor/tensor.rs
+++ b/burn-tensor/src/tensor/tensor.rs
@@ -8,6 +8,7 @@ use crate::BoolTensor;
 use crate::Element;
 use std::convert::TryInto;
 
+/// A tensor, or an *n-dimensional* array.
 #[derive(Debug, Clone)]
 pub struct Tensor<B: Backend, const D: usize> {
     pub(crate) value: B::TensorPrimitive<D>,
 }
@@ -17,46 +18,102 @@ impl<B, const D: usize> Tensor<B, D>
 where
     B: Backend,
 {
-    pub fn new(tensor: B::TensorPrimitive<D>) -> Self {
+    pub(crate) fn new(tensor: B::TensorPrimitive<D>) -> Self {
         Self { value: tensor }
     }
 
+    /// Reshape the tensor to have the given shape.
+    ///
+    /// # Panics
+    ///
+    /// If the tensor cannot be reshaped to the given shape.
     pub fn reshape<const D2: usize>(&self, shape: Shape<D2>) -> Tensor<B, D2> {
         Tensor::new(self.value.reshape(shape))
     }
 
+    /// Returns a new tensor on the given device.
     pub fn to_device(&self, device: B::Device) -> Self {
         Self::new(self.value.to_device(device))
     }
 
+    /// Returns the device of the current tensor.
+    pub fn device(&self) -> B::Device {
+        self.value.device()
+    }
+
+    /// Apply element-wise exponential operation.
+    ///
+    /// `y = e^x`
     pub fn exp(&self) -> Self {
         Self::new(self.value.exp())
     }
 
+    /// Apply element-wise natural logarithm operation *ln*.
+    ///
+    /// `y = ln(x)`
     pub fn log(&self) -> Self {
         Self::new(self.value.log())
     }
 
-    pub fn device(&self) -> B::Device {
-        self.value.device()
-    }
-
+    /// Returns the shape of the current tensor.
     pub fn shape(&self) -> &Shape<D> {
         self.value.shape()
     }
 
+    /// Returns the data of the current tensor.
     pub fn into_data(self) -> Data<B::Elem, D> {
         self.value.into_data()
    }
 
+    /// Returns the data of the current tensor without taking ownership.
     pub fn to_data(&self) -> Data<B::Elem, D> {
         self.value.to_data()
     }
 
+    /// Create a tensor from the given data.
+    pub fn from_data(data: Data<B::Elem, D>) -> Self {
+        let tensor = B::from_data(data, B::Device::default());
+        Tensor::new(tensor)
+    }
+
+    /// Create a tensor from the given data on the given device.
+    pub fn from_data_device(data: Data<B::Elem, D>, device: B::Device) -> Self {
+        let tensor = B::from_data(data, device);
+        Tensor::new(tensor)
+    }
+
+    /// Returns a new tensor with the same shape and device as the current tensor filled with zeros.
     pub fn zeros_like(&self) -> Self {
         Tensor::new(B::zeros(self.shape().clone(), self.value.device()))
     }
 
+    /// Returns a new tensor with the same shape and device as the current tensor filled with ones.
+    pub fn ones_like(&self) -> Self {
+        Tensor::new(B::ones(self.shape().clone(), self.value.device()))
+    }
+
+    /// Returns a new tensor with the same shape and device as the current tensor filled with random
+    /// values sampled from the given distribution.
+    pub fn random_like(&self, distribution: Distribution<B::Elem>) -> Self {
+        Tensor::new(B::random(
+            self.shape().clone(),
+            distribution,
+            self.value.device(),
+        ))
+    }
+
+    /// Create a one-hot tensor.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use burn::backend::NdArrayBackend;
+    /// use burn::tensor::Tensor;
+    ///
+    /// let one_hot = Tensor::<NdArrayBackend<f32>, 1>::one_hot(2, 10);
+    /// println!("{}", one_hot.to_data());
+    /// // [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+    /// ```
     pub fn one_hot(index: usize, num_classes: usize) -> Self {
         let mut dims = [1; D];
         dims[D - 1] = num_classes;
@@ -71,147 +128,238 @@ where
         tensor
     }
 
-    pub fn ones_like(&self) -> Self {
-        Tensor::new(B::ones(self.shape().clone(), self.value.device()))
-    }
-
-    pub fn random_like(&self, distribution: Distribution<B::Elem>) -> Self {
-        Tensor::new(B::random(
-            self.shape().clone(),
-            distribution,
-            self.value.device(),
-        ))
-    }
-
+    /// Apply element-wise addition operation.
+    ///
+    /// `y = x2 + x1`
     pub fn add(&self, other: &Self) -> Self {
         Self::new(self.value.add(&other.value))
     }
 
+    /// Apply element-wise addition operation with a scalar.
+    ///
+    /// `y = x + s`
     pub fn add_scalar(&self, other: &B::Elem) -> Self {
         Self::new(self.value.add_scalar(&other))
     }
 
+    /// Apply element-wise subtraction operation.
+    ///
+    /// `y = x2 - x1`
     pub fn sub(&self, other: &Self) -> Self {
         Self::new(self.value.sub(&other.value))
     }
 
+    /// Apply element-wise subtraction operation with a scalar.
+    ///
+    /// `y = x - s`
     pub fn sub_scalar(&self, other: &B::Elem) -> Self {
         Self::new(self.value.sub_scalar(&other))
     }
 
+    /// Apply the transpose operation.
+    ///
+    /// On matrices and higher-dimensional tensors, it swaps the last two dimensions.
+    ///
+    /// # Panics
+    ///
+    /// If the tensor has fewer than 2 dimensions.
     pub fn transpose(&self) -> Self {
         Self::new(self.value.transpose())
     }
 
+    /// Apply the matrix multiplication operation.
+    ///
+    /// `C = AB`
+    ///
+    /// # Panics
+    ///
+    /// If the two tensors don't have compatible shapes.
     pub fn matmul(&self, other: &Self) -> Self {
         Self::new(self.value.matmul(&other.value))
     }
 
+    /// Switch the sign of each element in the tensor.
+    ///
+    /// `y = -x`
     pub fn neg(&self) -> Self {
         Self::new(self.value.neg())
     }
 
+    /// Apply element-wise multiplication operation.
+    ///
+    /// `y = x2 * x1`
     pub fn mul(&self, other: &Self) -> Self {
         Self::new(self.value.mul(&other.value))
     }
 
+    /// Apply element-wise multiplication operation with a scalar.
+    ///
+    /// `y = x * s`
     pub fn mul_scalar(&self, other: &B::Elem) -> Self {
         Self::new(self.value.mul_scalar(&other))
     }
 
+    /// Apply element-wise division operation.
+    ///
+    /// `y = x2 / x1`
     pub fn div(&self, other: &Self) -> Self {
         Self::new(self.value.div(&other.value))
     }
 
+    /// Apply element-wise division operation with a scalar.
+    ///
+    /// `y = x / s`
     pub fn div_scalar(&self, other: &B::Elem) -> Self {
         Self::new(self.value.div_scalar(&other))
     }
 
-    pub fn random(shape: Shape<D>, distribution: Distribution<B::Elem>) -> Self {
-        let tensor = B::random(shape, distribution, B::Device::default());
-        Self::new(tensor)
-    }
-
+    /// Aggregate all elements in the tensor with the mean operation.
     pub fn mean(&self) -> Tensor<B, 1> {
         Tensor::new(self.value.mean())
     }
 
+    /// Aggregate all elements in the tensor with the sum operation.
     pub fn sum(&self) -> Tensor<B, 1> {
         Tensor::new(self.value.sum())
     }
 
+    /// Aggregate all elements along the given *dimension* or *axis* in the tensor with the mean operation.
     pub fn mean_dim(&self, dim: usize) -> Self {
         Self::new(self.value.mean_dim(dim))
     }
 
+    /// Aggregate all elements along the given *dimension* or *axis* in the tensor with the sum operation.
     pub fn sum_dim(&self, dim: usize) -> Self {
         Self::new(self.value.sum_dim(dim))
     }
 
+    /// Apply element-wise equal comparison and return a boolean tensor.
+    ///
+    /// # Panics
+    ///
+    /// If the two tensors don't have the same shape.
     pub fn equal(&self, other: &Self) -> BoolTensor<B, D> {
         BoolTensor::new(self.value.equal(&other.value))
     }
 
-    pub fn equal_scalar(&self, other: &B::Elem) -> BoolTensor<B, D> {
-        BoolTensor::new(self.value.equal_scalar(other))
-    }
-
+    /// Apply element-wise greater comparison and return a boolean tensor.
+    ///
+    /// # Panics
+    ///
+    /// If the two tensors don't have the same shape.
     pub fn greater(&self, other: &Self) -> BoolTensor<B, D> {
         BoolTensor::new(self.value.greater(&other.value))
     }
 
+    /// Apply element-wise greater-equal comparison and return a boolean tensor.
+    ///
+    /// # Panics
+    ///
+    /// If the two tensors don't have the same shape.
     pub fn greater_equal(&self, other: &Self) -> BoolTensor<B, D> {
         BoolTensor::new(self.value.greater_equal(&other.value))
     }
 
-    pub fn greater_scalar(&self, other: &B::Elem) -> BoolTensor<B, D> {
-        BoolTensor::new(self.value.greater_scalar(other))
-    }
-
-    pub fn greater_equal_scalar(&self, other: &B::Elem) -> BoolTensor<B, D> {
-        BoolTensor::new(self.value.greater_equal_scalar(other))
-    }
-
+    /// Apply element-wise lower comparison and return a boolean tensor.
+    ///
+    /// # Panics
+    ///
+    /// If the two tensors don't have the same shape.
     pub fn lower(&self, other: &Self) -> BoolTensor<B, D> {
         BoolTensor::new(self.value.lower(&other.value))
     }
 
+    /// Apply element-wise lower-equal comparison and return a boolean tensor.
+    ///
+    /// # Panics
+    ///
+    /// If the two tensors don't have the same shape.
     pub fn lower_equal(&self, other: &Self) -> BoolTensor<B, D> {
         BoolTensor::new(self.value.lower_equal(&other.value))
     }
 
+    /// Apply element-wise equal comparison with a scalar and return a boolean tensor.
+    pub fn equal_scalar(&self, other: &B::Elem) -> BoolTensor<B, D> {
+        BoolTensor::new(self.value.equal_scalar(other))
+    }
+
+    /// Apply element-wise greater comparison with a scalar and return a boolean tensor.
+    pub fn greater_scalar(&self, other: &B::Elem) -> BoolTensor<B, D> {
+        BoolTensor::new(self.value.greater_scalar(other))
+    }
+
+    /// Apply element-wise greater-equal comparison with a scalar and return a boolean tensor.
+    pub fn greater_equal_scalar(&self, other: &B::Elem) -> BoolTensor<B, D> {
+        BoolTensor::new(self.value.greater_equal_scalar(other))
+    }
+
+    /// Apply element-wise lower comparison with a scalar and return a boolean tensor.
     pub fn lower_scalar(&self, other: &B::Elem) -> BoolTensor<B, D> {
         BoolTensor::new(self.value.lower_scalar(other))
     }
 
+    /// Apply element-wise lower-equal comparison with a scalar and return a boolean tensor.
     pub fn lower_equal_scalar(&self, other: &B::Elem) -> BoolTensor<B, D> {
         BoolTensor::new(self.value.lower_equal_scalar(other))
     }
 
+    /// Create a random tensor of the given shape where each element is sampled from the given
+    /// distribution.
+    pub fn random(shape: Shape<D>, distribution: Distribution<B::Elem>) -> Self {
+        let tensor = B::random(shape, distribution, B::Device::default());
+        Self::new(tensor)
+    }
+
+    /// Create a tensor of the given shape where each element is zero.
     pub fn zeros(shape: Shape<D>) -> Self {
         let tensor = B::zeros(shape, B::Device::default());
         Self::new(tensor)
     }
 
+    /// Create a tensor of the given shape where each element is one.
     pub fn ones(shape: Shape<D>) -> Self {
         let tensor = B::ones(shape, B::Device::default());
         Self::new(tensor)
     }
 
-    pub fn from_data(data: Data<B::Elem, D>) -> Self {
-        let tensor = B::from_data(data, B::Device::default());
-        Tensor::new(tensor)
-    }
-
-    pub fn from_data_device(data: Data<B::Elem, D>, device: B::Device) -> Self {
-        let tensor = B::from_data(data, device);
-        Tensor::new(tensor)
-    }
-
+    /// Returns a tensor containing the elements selected from the given ranges.
+    ///
+    /// # Panics
+    ///
+    /// If a range exceeds the number of elements on a dimension.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use burn::backend::NdArrayBackend;
+    /// use burn::tensor::Tensor;
+    ///
+    /// let tensor = Tensor::<NdArrayBackend<f32>, 3>::ones(Shape::new([2, 3, 3]));
+    /// let tensor_indexed = tensor.index([0..1, 0..3, 1..2]);
+    /// println!("{:?}", tensor_indexed.shape());
+    /// // Shape { dims: [1, 3, 1] }
+    /// ```
     pub fn index<const D2: usize>(&self, indexes: [std::ops::Range<usize>; D2]) -> Self {
         Self::new(self.value.index(indexes))
     }
 
+    /// Returns a copy of the current tensor with the elements at the selected indexes replaced by
+    /// the given values.
+    ///
+    /// # Panics
+    ///
+    /// - If a range exceeds the number of elements on a dimension.
+    /// - If the given values don't match the given ranges.
+    /// ```rust
+    /// use burn::backend::NdArrayBackend;
+    /// use burn::tensor::Tensor;
+    ///
+    /// let tensor = Tensor::<NdArrayBackend<f32>, 3>::ones(Shape::new([2, 3, 3]));
+    /// let values = Tensor::<NdArrayBackend<f32>, 3>::zeros(Shape::new([1, 1, 1]));
+    /// let tensor_indexed = tensor.index_assign([0..1, 0..1, 0..1], &values);
+    /// println!("{:?}", tensor_indexed.shape());
+    /// // Shape { dims: [2, 3, 3] }
+    /// ```
     pub fn index_assign<const D2: usize>(
         &self,
         indexes: [std::ops::Range<usize>; D2],

From 83d08f0a2300a514ffe87ba0d5a03721e8b9fb21 Mon Sep 17 00:00:00 2001
From: nathaniel
Date: Mon, 5 Sep 2022 11:46:07 -0400
Subject: [PATCH 2/5] Complete some docs

---
 burn-tensor/src/tensor/tensor.rs | 51 ++++++++++++++++++++++++++++++++
 1 file changed, 51 insertions(+)

diff --git a/burn-tensor/src/tensor/tensor.rs b/burn-tensor/src/tensor/tensor.rs
index 24513681f8..eeedb2145d 100644
--- a/burn-tensor/src/tensor/tensor.rs
+++ b/burn-tensor/src/tensor/tensor.rs
@@ -368,27 +368,61 @@ where
         Self::new(self.value.index_assign(indexes, &values.value))
     }
 
+    /// Fill each element with the given value where the given mask is true.
     pub fn mask_fill(&self, mask: &BoolTensor<B, D>, value: B::Elem) -> Self {
         Self::new(self.value.mask_fill(&mask.value, value))
     }
 
+    /// Returns a tensor with full precision based on the selected backend.
     pub fn to_full_precision(&self) -> Tensor<B::FullPrecisionBackend, D> {
         Tensor::new(self.value.to_full_precision())
     }
 
+    /// Returns a tensor on the selected backend from a full precision tensor.
     pub fn from_full_precision(tensor: Tensor<B::FullPrecisionBackend, D>) -> Self {
         let value = B::TensorPrimitive::from_full_precision(tensor.value);
         Tensor::new(value)
     }
 
+    /// Apply the argmax function along the given dimension and return an integer tensor.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use burn::backend::NdArrayBackend;
+    /// use burn::tensor::Tensor;
+    ///
+    /// let tensor = Tensor::<NdArrayBackend<f32>, 3>::ones(Shape::new([2, 3, 3]));
+    /// let tensor = tensor.argmax(1);
+    /// println!("{:?}", tensor.shape());
+    /// // Shape { dims: [2, 1, 3] }
+    /// ```
     pub fn argmax(&self, dim: usize) -> Tensor<B::IntegerBackend, D> {
         Tensor::new(self.value.argmax(dim))
     }
 
+    /// Apply the argmin function along the given dimension and return an integer tensor.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use burn::backend::NdArrayBackend;
+    /// use burn::tensor::Tensor;
+    ///
+    /// let tensor = Tensor::<NdArrayBackend<f32>, 3>::ones(Shape::new([2, 3, 3]));
+    /// let tensor = tensor.argmin(1);
+    /// println!("{:?}", tensor.shape());
+    /// // Shape { dims: [2, 1, 3] }
+    /// ```
     pub fn argmin(&self, dim: usize) -> Tensor<B::IntegerBackend, D> {
         Tensor::new(self.value.argmin(dim))
     }
 
+    /// Concatenates all tensors into a new one along the given dimension.
+    ///
+    /// # Panics
+    ///
+    /// If the tensors don't all have the same shape.
     pub fn cat(tensors: Vec<Self>, dim: usize) -> Self {
         let tensors: Vec<B::TensorPrimitive<D>> =
             tensors.into_iter().map(|a| a.value.clone()).collect();
@@ -398,6 +432,23 @@ where
         Self::new(value)
     }
 
+    /// Unsqueeze the current tensor, creating new leading dimensions to fit the given size.
+    ///
+    /// # Panics
+    ///
+    /// If the target number of dimensions is lower than the current one.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use burn::backend::NdArrayBackend;
+    /// use burn::tensor::Tensor;
+    ///
+    /// let tensor = Tensor::<NdArrayBackend<f32>, 2>::ones(Shape::new([3, 3]));
+    /// let tensor = tensor.unsqueeze::<4>();
+    /// println!("{:?}", tensor.shape());
+    /// // Shape { dims: [1, 1, 3, 3] }
+    /// ```
     pub fn unsqueeze<const D2: usize>(&self) -> Tensor<B, D2> {
         if D2 < D {
             panic!(

From e154fb67a3a241a446322e01fc14b8803aefdcfd Mon Sep 17 00:00:00 2001
From: nathaniel
Date: Mon, 5 Sep 2022 11:50:41 -0400
Subject: [PATCH 3/5] Fix version

---
 burn-tensor/Cargo.toml | 8 ++++----
 burn/Cargo.toml        | 5 +++--
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/burn-tensor/Cargo.toml b/burn-tensor/Cargo.toml
index 9dfc23bdbd..7e9269d488 100644
--- a/burn-tensor/Cargo.toml
+++ b/burn-tensor/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "burn-tensor"
-version = "0.2.1"
+version = "0.2.2"
 authors = ["nathanielsimard "]
 
 description = """
@@ -15,15 +15,15 @@ license = "MIT/Apache-2.0"
 edition = "2021"
 
 [package.metadata.docs.rs]
-features = ["ndarray"]
+features = ["tch-doc", "ndarray"]
 all-features = false
 no-default-features = true
 
 [features]
-default = ["ndarray"]
-full = ["tch", "ndarray"]
+default = ["tch", "ndarray"]
 tch = ["dep:tch"]
 ndarray = ["dep:ndarray"]
+tch-doc = ["dep:tch", "tch/doc-only"]
 
 [dependencies]
 num-traits = "0.2"
diff --git a/burn/Cargo.toml b/burn/Cargo.toml
index 1803a7d7f7..81e18888b0 100644
--- a/burn/Cargo.toml
+++ b/burn/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "burn"
-version = "0.2.1"
+version = "0.2.2"
 authors = ["nathanielsimard "]
 description = "BURN: Burn Unstoppable Rusty Neurons"
 repository = "https://github.com/nathanielsimard/burn"
@@ -13,10 +13,11 @@ edition = "2021"
 [features]
 default = ["tch", "ndarray"]
 tch = ["burn-tensor/tch"]
+tch-doc = ["burn-tensor/tch-doc"]
 ndarray = ["burn-tensor/ndarray"]
 
 [package.metadata.docs.rs]
-features = ["ndarray"]
+features = ["tch-doc", "ndarray"]
 all-features = false
 no-default-features = true

From 00762ee2d9e6f780f508b80ead02aaecd6979629 Mon Sep 17 00:00:00 2001
From: nathaniel
Date: Fri, 9 Sep 2022 18:01:41 -0400
Subject: [PATCH 4/5] Fix doc tests

---
 .github/workflows/test-burn-dataset.yml |  7 +-
 .github/workflows/test-burn-tensor.yml  |  9 ++-
 .github/workflows/test-burn.yml         |  8 +-
 burn-tensor/Cargo.toml                  |  5 +-
 burn-tensor/src/tensor/tensor.rs        | 99 ++++++++++++++-----------
 burn/Cargo.toml                         |  4 +-
 6 files changed, 82 insertions(+), 50 deletions(-)

diff --git a/.github/workflows/test-burn-dataset.yml b/.github/workflows/test-burn-dataset.yml
index 822b1541f8..5886605b6d 100644
--- a/.github/workflows/test-burn-dataset.yml
+++ b/.github/workflows/test-burn-dataset.yml
@@ -23,7 +23,12 @@ jobs:
           cd burn-dataset
           cargo fmt --check --all
 
+      - name: check doc
+        run: |
+          cd burn-tensor
+          cargo test --doc
+
       - name: check tests
         run: |
           cd burn-dataset
-          cargo test
+          cargo test --tests
diff --git a/.github/workflows/test-burn-tensor.yml b/.github/workflows/test-burn-tensor.yml
index 7f6e556ef2..7cdebc095a 100644
--- a/.github/workflows/test-burn-tensor.yml
+++ b/.github/workflows/test-burn-tensor.yml
@@ -23,14 +23,19 @@ jobs:
           cd burn-tensor
           cargo fmt --check --all
 
+      - name: check doc
+        run: |
+          cd burn-tensor
+          cargo test --no-default-features --features doc --doc
+
       - name: check tests backend ndarray
         run: |
           cd burn-tensor
-          cargo test --no-default-features --features ndarray
+          cargo test --no-default-features --features ndarray --tests
 
       - name: check tests backend tch
         run: |
           cd burn-tensor
-          cargo test --no-default-features --features tch
+          cargo test --no-default-features --features tch --tests
diff --git a/.github/workflows/test-burn.yml b/.github/workflows/test-burn.yml
index 130ffcf563..a8a414c5c5 100644
--- a/.github/workflows/test-burn.yml
+++ b/.github/workflows/test-burn.yml
@@ -23,7 +23,13 @@ jobs:
           cd burn
           cargo fmt --check --all
+
+      - name: check doc
+        run: |
+          cd burn-tensor
+          cargo test --no-default-features --features doc --doc
+
       - name: check tests
         run: |
           cd burn
-          cargo test
+          cargo test --tests
diff --git a/burn-tensor/Cargo.toml b/burn-tensor/Cargo.toml
index 7e9269d488..a47035c023 100644
--- a/burn-tensor/Cargo.toml
+++ b/burn-tensor/Cargo.toml
@@ -13,9 +13,10 @@ keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"]
 categories = ["science"]
 license = "MIT/Apache-2.0"
 edition = "2021"
+doctest = false
 
 [package.metadata.docs.rs]
-features = ["tch-doc", "ndarray"]
+features = ["doc"]
 all-features = false
 no-default-features = true
@@ -23,7 +24,7 @@ no-default-features = true
 default = ["tch", "ndarray"]
 tch = ["dep:tch"]
 ndarray = ["dep:ndarray"]
-tch-doc = ["dep:tch", "tch/doc-only"]
+doc = ["dep:tch", "tch/doc-only", "dep:ndarray"]
 
 [dependencies]
 num-traits = "0.2"
diff --git a/burn-tensor/src/tensor/tensor.rs b/burn-tensor/src/tensor/tensor.rs
index eeedb2145d..422a565fbd 100644
--- a/burn-tensor/src/tensor/tensor.rs
+++ b/burn-tensor/src/tensor/tensor.rs
@@ -107,12 +107,14 @@ where
     /// # Example
     ///
     /// ```rust
-    /// use burn::backend::NdArrayBackend;
-    /// use burn::tensor::Tensor;
-    ///
-    /// let one_hot = Tensor::<NdArrayBackend<f32>, 1>::one_hot(2, 10);
-    /// println!("{}", one_hot.to_data());
-    /// // [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+    /// use burn_tensor::backend::Backend;
+    /// use burn_tensor::Tensor;
+    ///
+    /// fn example<B: Backend>() {
+    ///     let one_hot = Tensor::<B, 1>::one_hot(2, 10);
+    ///     println!("{}", one_hot.to_data());
+    ///     // [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+    /// }
     /// ```
     pub fn one_hot(index: usize, num_classes: usize) -> Self {
         let mut dims = [1; D];
@@ -331,13 +333,15 @@
     /// # Example
     ///
     /// ```rust
-    /// use burn::backend::NdArrayBackend;
-    /// use burn::tensor::Tensor;
-    ///
-    /// let tensor = Tensor::<NdArrayBackend<f32>, 3>::ones(Shape::new([2, 3, 3]));
-    /// let tensor_indexed = tensor.index([0..1, 0..3, 1..2]);
-    /// println!("{:?}", tensor_indexed.shape());
-    /// // Shape { dims: [1, 3, 1] }
+    /// use burn_tensor::backend::Backend;
+    /// use burn_tensor::{Tensor, Shape};
+    ///
+    /// fn example<B: Backend>() {
+    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]));
+    ///     let tensor_indexed = tensor.index([0..1, 0..3, 1..2]);
+    ///     println!("{:?}", tensor_indexed.shape());
+    ///     // Shape { dims: [1, 3, 1] }
+    /// }
     /// ```
     pub fn index<const D2: usize>(&self, indexes: [std::ops::Range<usize>; D2]) -> Self {
         Self::new(self.value.index(indexes))
@@ -350,15 +354,20 @@
     ///
     /// - If a range exceeds the number of elements on a dimension.
     /// - If the given values don't match the given ranges.
+    ///
+    /// # Example
+    ///
     /// ```rust
-    /// use burn::backend::NdArrayBackend;
-    /// use burn::tensor::Tensor;
-    ///
-    /// let tensor = Tensor::<NdArrayBackend<f32>, 3>::ones(Shape::new([2, 3, 3]));
-    /// let values = Tensor::<NdArrayBackend<f32>, 3>::zeros(Shape::new([1, 1, 1]));
-    /// let tensor_indexed = tensor.index_assign([0..1, 0..1, 0..1], &values);
-    /// println!("{:?}", tensor_indexed.shape());
-    /// // Shape { dims: [2, 3, 3] }
+    /// use burn_tensor::backend::Backend;
+    /// use burn_tensor::{Tensor, Shape};
+    ///
+    /// fn example<B: Backend>() {
+    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]));
+    ///     let values = Tensor::<B, 3>::zeros(Shape::new([1, 1, 1]));
+    ///     let tensor_indexed = tensor.index_assign([0..1, 0..1, 0..1], &values);
+    ///     println!("{:?}", tensor_indexed.shape());
+    ///     // Shape { dims: [2, 3, 3] }
+    /// }
     /// ```
     pub fn index_assign<const D2: usize>(
         &self,
@@ -389,13 +398,15 @@
     /// # Example
     ///
     /// ```rust
-    /// use burn::backend::NdArrayBackend;
-    /// use burn::tensor::Tensor;
-    ///
-    /// let tensor = Tensor::<NdArrayBackend<f32>, 3>::ones(Shape::new([2, 3, 3]));
-    /// let tensor = tensor.argmax(1);
-    /// println!("{:?}", tensor.shape());
-    /// // Shape { dims: [2, 1, 3] }
+    /// use burn_tensor::backend::Backend;
+    /// use burn_tensor::{Tensor, Shape};
+    ///
+    /// fn example<B: Backend>() {
+    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]));
+    ///     let tensor = tensor.argmax(1);
+    ///     println!("{:?}", tensor.shape());
+    ///     // Shape { dims: [2, 1, 3] }
+    /// }
     /// ```
     pub fn argmax(&self, dim: usize) -> Tensor<B::IntegerBackend, D> {
         Tensor::new(self.value.argmax(dim))
@@ -406,13 +417,15 @@
     /// # Example
     ///
     /// ```rust
-    /// use burn::backend::NdArrayBackend;
-    /// use burn::tensor::Tensor;
-    ///
-    /// let tensor = Tensor::<NdArrayBackend<f32>, 3>::ones(Shape::new([2, 3, 3]));
-    /// let tensor = tensor.argmin(1);
-    /// println!("{:?}", tensor.shape());
-    /// // Shape { dims: [2, 1, 3] }
+    /// use burn_tensor::backend::Backend;
+    /// use burn_tensor::{Tensor, Shape};
+    ///
+    /// fn example<B: Backend>() {
+    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]));
+    ///     let tensor = tensor.argmin(1);
+    ///     println!("{:?}", tensor.shape());
+    ///     // Shape { dims: [2, 1, 3] }
+    /// }
     /// ```
     pub fn argmin(&self, dim: usize) -> Tensor<B::IntegerBackend, D> {
         Tensor::new(self.value.argmin(dim))
@@ -441,13 +454,15 @@
     /// # Example
     ///
     /// ```rust
-    /// use burn::backend::NdArrayBackend;
-    /// use burn::tensor::Tensor;
-    ///
-    /// let tensor = Tensor::<NdArrayBackend<f32>, 2>::ones(Shape::new([3, 3]));
-    /// let tensor = tensor.unsqueeze::<4>();
-    /// println!("{:?}", tensor.shape());
-    /// // Shape { dims: [1, 1, 3, 3] }
+    /// use burn_tensor::backend::Backend;
+    /// use burn_tensor::{Tensor, Shape};
+    ///
+    /// fn example<B: Backend>() {
+    ///     let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]));
+    ///     let tensor = tensor.unsqueeze::<4>();
+    ///     println!("{:?}", tensor.shape());
+    ///     // Shape { dims: [1, 1, 3, 3] }
+    /// }
     /// ```
     pub fn unsqueeze<const D2: usize>(&self) -> Tensor<B, D2> {
         if D2 < D {
diff --git a/burn/Cargo.toml b/burn/Cargo.toml
index 81e18888b0..d353ac5110 100644
--- a/burn/Cargo.toml
+++ b/burn/Cargo.toml
@@ -13,11 +13,11 @@ edition = "2021"
 [features]
 default = ["tch", "ndarray"]
 tch = ["burn-tensor/tch"]
-tch-doc = ["burn-tensor/tch-doc"]
 ndarray = ["burn-tensor/ndarray"]
+doc = ["burn-tensor/doc"]
 
 [package.metadata.docs.rs]
-features = ["tch-doc", "ndarray"]
+features = ["doc"]
 all-features = false
 no-default-features = true

From 41b2d9f452e75605621bc4ed1c44125184cea05b Mon Sep 17 00:00:00 2001
From: nathaniel
Date: Fri, 9 Sep 2022 18:06:39 -0400
Subject: [PATCH 5/5] Fix ci

---
 .github/workflows/test-burn-dataset.yml | 2 +-
 .github/workflows/test-burn.yml         | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/test-burn-dataset.yml b/.github/workflows/test-burn-dataset.yml
index 5886605b6d..40a90d6bf2 100644
--- a/.github/workflows/test-burn-dataset.yml
+++ b/.github/workflows/test-burn-dataset.yml
@@ -25,7 +25,7 @@ jobs:
 
       - name: check doc
         run: |
-          cd burn-tensor
+          cd burn-dataset
           cargo test --doc
 
       - name: check tests
diff --git a/.github/workflows/test-burn.yml b/.github/workflows/test-burn.yml
index a8a414c5c5..47232eb3ee 100644
--- a/.github/workflows/test-burn.yml
+++ b/.github/workflows/test-burn.yml
@@ -26,7 +26,7 @@ jobs:
 
      - name: check doc
        run: |
-          cd burn-tensor
+          cd burn
           cargo test --no-default-features --features doc --doc
 
       - name: check tests
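For reference, the reworked doc examples above are written against a generic `B: Backend`. The short sketch below shows how such an example could be driven with a concrete backend; it is only a sketch, assuming `NdArrayBackend<f32>` is exposed from `burn_tensor::backend` when the `ndarray` feature is enabled, and it is not part of the patch series itself.

```rust
use burn_tensor::backend::{Backend, NdArrayBackend};
use burn_tensor::{Shape, Tensor};

// Written against a generic backend, in the same style as the doc examples.
fn example<B: Backend>() {
    // One-hot encode class index 2 out of 5 classes.
    let one_hot = Tensor::<B, 1>::one_hot(2, 5);
    println!("{}", one_hot.to_data()); // [0.0, 0.0, 1.0, 0.0, 0.0]

    // A 2x3 matrix of ones times its transpose: every entry of the 2x2 result is 3.
    let a = Tensor::<B, 2>::ones(Shape::new([2, 3]));
    let product = a.matmul(&a.transpose());
    println!("{:?}", product.shape()); // Shape { dims: [2, 2] }
}

fn main() {
    // Instantiate the generic example with a concrete backend (assumed import path).
    example::<NdArrayBackend<f32>>();
}
```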