diff --git a/crates/burn-fusion/src/ops/base.rs b/crates/burn-fusion/src/ops/base.rs new file mode 100644 index 0000000000..2a4182b3d5 --- /dev/null +++ b/crates/burn-fusion/src/ops/base.rs @@ -0,0 +1,12 @@ +use crate::{stream::Operation, FusionBackend}; +use burn_tensor::repr::HandleContainer; +use std::marker::PhantomData; + +#[derive(new)] +pub struct NoOp { + _b: PhantomData, +} + +impl Operation for NoOp { + fn execute(self: Box, _handles: &mut HandleContainer) {} +} diff --git a/crates/burn-fusion/src/ops/boolean.rs b/crates/burn-fusion/src/ops/boolean.rs index 658907bf3e..2008ece60a 100644 --- a/crates/burn-fusion/src/ops/boolean.rs +++ b/crates/burn-fusion/src/ops/boolean.rs @@ -1,7 +1,7 @@ use burn_tensor::{ ops::{binary_ops_shape, FloatTensor, IntTensor}, - repr::{FromDataOperationDescription, TensorDescription}, - DType, Element, TensorData, + repr::{InitOperationDescription, TensorDescription}, + DType, Element, TensorData, TensorMetadata, }; use std::marker::PhantomData; @@ -23,6 +23,8 @@ use burn_tensor::{ Device, Shape, }; +use super::NoOp; + impl BoolTensorOps for Fusion { fn bool_empty(shape: Shape, device: &Device) -> BoolTensor { #[derive(new)] @@ -58,32 +60,19 @@ impl BoolTensorOps for Fusion { } fn bool_from_data(data: burn_tensor::TensorData, device: &Device) -> BoolTensor { - #[derive(new)] - struct FromDataOps { - desc: FromDataOperationDescription, - device: Device, - } - - impl Operation for FromDataOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let output = B::bool_from_data(self.desc.data, &self.device); - handles.register_bool_tensor::(&self.desc.out.id, output); - } - } - let stream = StreamId::current(); let client = get_client::(&device.clone()); - let out = client.tensor_uninitialized(data.shape.clone(), DType::Bool); + let tensor = B::bool_from_data(data, device); + let shape = tensor.shape(); - let desc = FromDataOperationDescription { - out: out.to_description_out(), - data, - }; + let handle = B::bool_tensor_handle(tensor); + let out = client.register_tensor(handle, shape.dims, stream, DType::Bool); + let desc = out.to_description_out(); client.register( vec![stream], - OperationDescription::BaseBool(BaseOperationDescription::FromData(desc.clone())), - FromDataOps::::new(desc, device.clone()), + OperationDescription::Init(InitOperationDescription { out: desc }), + NoOp::::new(), ); out diff --git a/crates/burn-fusion/src/ops/float.rs b/crates/burn-fusion/src/ops/float.rs index 1ba2717bfb..5a7826e567 100644 --- a/crates/burn-fusion/src/ops/float.rs +++ b/crates/burn-fusion/src/ops/float.rs @@ -10,38 +10,28 @@ use crate::{ use burn_tensor::{ ops::{binary_ops_shape, BoolTensor, FloatElem, FloatTensor, FloatTensorOps, IntTensor}, repr::*, - DType, Device, Distribution, Element, ElementConversion, Shape, TensorData, + DType, Device, Distribution, Element, ElementConversion, Shape, TensorData, TensorMetadata, }; use std::{marker::PhantomData, ops::Range}; +use super::NoOp; + impl FloatTensorOps for Fusion { fn float_from_data(data: TensorData, device: &Device) -> FloatTensor { - #[derive(new)] - struct FromDataOps { - desc: FromDataOperationDescription, - device: Device, - } - - impl Operation for FromDataOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let output = B::float_from_data(self.desc.data, &self.device); - handles.register_float_tensor::(&self.desc.out.id, output); - } - } - let stream = StreamId::current(); let client = get_client::(&device.clone()); - let out = client.tensor_uninitialized(data.shape.clone(), B::FloatElem::dtype()); + let dtype = data.dtype; + let tensor = B::float_from_data(data, device); + let shape = tensor.shape(); - let desc = FromDataOperationDescription { - out: out.to_description_out(), - data, - }; + let handle = B::float_tensor_handle(tensor); + let out = client.register_tensor(handle, shape.dims, stream, dtype); + let desc = out.to_description_out(); client.register( vec![stream], - OperationDescription::BaseFloat(BaseOperationDescription::FromData(desc.clone())), - FromDataOps::::new(desc, device.clone()), + OperationDescription::Init(InitOperationDescription { out: desc }), + NoOp::::new(), ); out diff --git a/crates/burn-fusion/src/ops/int.rs b/crates/burn-fusion/src/ops/int.rs index bf88bbd25b..27bb83329f 100644 --- a/crates/burn-fusion/src/ops/int.rs +++ b/crates/burn-fusion/src/ops/int.rs @@ -8,11 +8,13 @@ use crate::{ use burn_tensor::{ ops::{binary_ops_shape, BoolTensor, FloatTensor, IntElem, IntTensor, IntTensorOps}, repr::{self, *}, - DType, Device, Distribution, Element, ElementConversion, Shape, TensorData, + DType, Device, Distribution, Element, ElementConversion, Shape, TensorData, TensorMetadata, }; use core::ops::Range; use std::marker::PhantomData; +use super::NoOp; + impl IntTensorOps for Fusion { fn int_empty(shape: Shape, device: &Device) -> IntTensor { #[derive(new)] @@ -48,32 +50,20 @@ impl IntTensorOps for Fusion { } fn int_from_data(data: TensorData, device: &Device) -> IntTensor { - #[derive(new)] - struct FromDataOps { - desc: FromDataOperationDescription, - device: Device, - } - - impl Operation for FromDataOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let output = B::int_from_data(self.desc.data, &self.device); - handles.register_int_tensor::(&self.desc.out.id, output); - } - } - let stream = StreamId::current(); let client = get_client::(&device.clone()); - let out = client.tensor_uninitialized(data.shape.clone(), B::IntElem::dtype()); + let dtype = data.dtype; + let tensor = B::int_from_data(data, device); + let shape = tensor.shape(); - let desc = FromDataOperationDescription { - out: out.to_description_out(), - data, - }; + let handle = B::int_tensor_handle(tensor); + let out = client.register_tensor(handle, shape.dims, stream, dtype); + let desc = out.to_description_out(); client.register( vec![stream], - OperationDescription::BaseInt(BaseOperationDescription::FromData(desc.clone())), - FromDataOps::::new(desc, device.clone()), + OperationDescription::Init(InitOperationDescription { out: desc }), + NoOp::::new(), ); out diff --git a/crates/burn-fusion/src/ops/mod.rs b/crates/burn-fusion/src/ops/mod.rs index 074677aed7..1cf3346e60 100644 --- a/crates/burn-fusion/src/ops/mod.rs +++ b/crates/burn-fusion/src/ops/mod.rs @@ -7,3 +7,6 @@ mod module; mod qtensor; mod transaction; mod unary; + +mod base; +pub(crate) use base::*; diff --git a/crates/burn-fusion/src/ops/qtensor.rs b/crates/burn-fusion/src/ops/qtensor.rs index 1449a485af..af084351e7 100644 --- a/crates/burn-fusion/src/ops/qtensor.rs +++ b/crates/burn-fusion/src/ops/qtensor.rs @@ -4,11 +4,11 @@ use burn_tensor::{ ops::{FloatElem, FloatTensor, IntTensor, QTensorOps, QuantizedTensor}, quantization::{QuantizationParametersPrimitive, QuantizationScheme}, repr::{ - BaseOperationDescription, DequantizeOperationDescription, FloatOperationDescription, - FromDataOperationDescription, HandleContainer, OperationDescription, - QuantizationParametersDescription, QuantizeOperationDescription, + DequantizeOperationDescription, FloatOperationDescription, HandleContainer, + InitOperationDescription, OperationDescription, QuantizationParametersDescription, + QuantizeOperationDescription, }, - DType, Device, Element, Shape, TensorData, + DType, Device, Element, Shape, TensorData, TensorMetadata, }; use crate::{ @@ -18,49 +18,27 @@ use crate::{ Fusion, FusionBackend, }; +use super::NoOp; + impl QTensorOps for Fusion { fn q_from_data(data: TensorData, device: &Device) -> QuantizedTensor { - #[derive(new)] - struct FromDataOps { - desc: FromDataOperationDescription, - device: Device, - } - - impl Operation for FromDataOps { - fn execute(self: Box, handles: &mut HandleContainer) { - let output = B::q_from_data(self.desc.data, &self.device); - handles.register_quantized_tensor::(&self.desc.out.id, output); - } - } + let stream = StreamId::current(); + let client = get_client::(&device.clone()); + let dtype = data.dtype; + let tensor = B::q_from_data(data, device); + let shape = tensor.shape(); - match data.dtype { - DType::QFloat(_scheme) => { - let dtype = data.dtype; + let handle = B::quantized_tensor_handle(tensor); + let out = client.register_tensor(handle, shape.dims, stream, dtype); + let desc = out.to_description_out(); - let stream = StreamId::current(); - let client = get_client::(&device.clone()); - let out = client.tensor_uninitialized(data.shape.clone(), dtype); - - let desc = FromDataOperationDescription { - out: out.to_description_out(), - data, - }; - - client.register( - vec![stream], - OperationDescription::BaseFloat(BaseOperationDescription::FromData( - desc.clone(), - )), - FromDataOps::::new(desc, device.clone()), - ); + client.register( + vec![stream], + OperationDescription::Init(InitOperationDescription { out: desc }), + NoOp::::new(), + ); - out - } - _ => panic!( - "Invalid dtype (expected DType::QFloat, got {:?})", - data.dtype - ), - } + out } fn quantize( diff --git a/crates/burn-fusion/src/stream/context.rs b/crates/burn-fusion/src/stream/context.rs index ed1a1902f8..22e51abe06 100644 --- a/crates/burn-fusion/src/stream/context.rs +++ b/crates/burn-fusion/src/stream/context.rs @@ -199,7 +199,7 @@ impl OperationConverter { burn_tensor::DType::F32 => self.scalar_f32.push(elem.elem()), burn_tensor::DType::F16 => self.scalar_f16.push(elem.elem()), burn_tensor::DType::BF16 => self.scalar_bf16.push(elem.elem()), - _ => todo!("Unsupported"), + _ => todo!("Unsupported float dtype ({dtype:?}) for scalar ({elem:?})"), } // We return 0 so that the id from a scalar operation is the same no matter its scalar @@ -227,6 +227,7 @@ impl OperationConverter { impl RelativeOps for OperationDescription { fn to_relative(&self, converter: &mut OperationConverter) -> Self { + println!("To relative {self:?}"); match self { OperationDescription::BaseFloat(ops) => { OperationDescription::BaseFloat(ops.to_relative(converter)) @@ -261,6 +262,9 @@ impl RelativeOps for OperationDescription { OperationDescription::Custom(ops) => { OperationDescription::Custom(ops.to_relative(converter)) } + OperationDescription::Init(ops) => { + OperationDescription::Init(ops.to_relative(converter)) + } } } } @@ -1210,12 +1214,14 @@ impl RelativeOps for BaseOperationDescription { BaseOperationDescription::Empty(desc) => { BaseOperationDescription::Empty(desc.to_relative(converter)) } - BaseOperationDescription::FromData(desc) => { - BaseOperationDescription::FromData(FromDataOperationDescription { - data: desc.data.clone(), - out: desc.out.to_relative(converter), - }) - } + } + } +} + +impl RelativeOps for InitOperationDescription { + fn to_relative(&self, converter: &mut OperationConverter) -> Self { + Self { + out: self.out.to_relative(converter), } } } diff --git a/crates/burn-router/src/ops/op_bool.rs b/crates/burn-router/src/ops/op_bool.rs index 5d01ddd7d3..8e21275213 100644 --- a/crates/burn-router/src/ops/op_bool.rs +++ b/crates/burn-router/src/ops/op_bool.rs @@ -4,7 +4,7 @@ use burn_tensor::ops::{BoolTensor, BoolTensorOps, FloatElem, FloatTensor, IntEle use burn_tensor::repr::{ BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription, CatOperationDescription, ExpandOperationDescription, FlipOperationDescription, - FromDataOperationDescription, OperationDescription, PermuteOperationDescription, + InitOperationDescription, OperationDescription, PermuteOperationDescription, RepeatDimOperationDescription, ReshapeDescription, SliceAssignOperationDescription, SliceOperationDescription, SwapDimsDescription, UnaryOperationDescription, }; @@ -31,16 +31,12 @@ impl BoolTensorOps for BackendRouter { fn bool_from_data(data: TensorData, device: &Device) -> BoolTensor { let client = get_client::(device); - let out = client.register_empty_tensor(data.shape.clone(), DType::Bool); - - let desc = FromDataOperationDescription { - data, + let out = client.register_tensor_data(data); + let desc = InitOperationDescription { out: out.to_description_out(), }; - client.register(OperationDescription::BaseBool( - BaseOperationDescription::FromData(desc), - )); + client.register(OperationDescription::Init(desc)); out } diff --git a/crates/burn-router/src/ops/op_float.rs b/crates/burn-router/src/ops/op_float.rs index 1cf211701c..e235a7a9b7 100644 --- a/crates/burn-router/src/ops/op_float.rs +++ b/crates/burn-router/src/ops/op_float.rs @@ -8,7 +8,7 @@ use burn_tensor::ops::{ use burn_tensor::repr::{ BaseOperationDescription, BinaryOperationDescription, CatOperationDescription, ClampOperationDescription, ExpandOperationDescription, FlipOperationDescription, - FloatOperationDescription, FromDataOperationDescription, GatherOperationDescription, + FloatOperationDescription, GatherOperationDescription, InitOperationDescription, MaskFillOperationDescription, MaskWhereOperationDescription, NumericOperationDescription, OperationDescription, PermuteOperationDescription, RandomOperationDescription, ReduceDimWithIndicesDescription, RepeatDimOperationDescription, ReshapeDescription, @@ -25,16 +25,12 @@ use crate::{get_client, BackendRouter, RunnerChannel, RunnerClient}; impl FloatTensorOps for BackendRouter { fn float_from_data(data: TensorData, device: &Device) -> FloatTensor { let client = get_client::(device); - let out = client.register_empty_tensor(data.shape.clone(), FloatElem::::dtype()); - - let desc = FromDataOperationDescription { - data, + let out = client.register_tensor_data(data); + let desc = InitOperationDescription { out: out.to_description_out(), }; - client.register(OperationDescription::BaseFloat( - BaseOperationDescription::FromData(desc), - )); + client.register(OperationDescription::Init(desc)); out } diff --git a/crates/burn-router/src/ops/op_int.rs b/crates/burn-router/src/ops/op_int.rs index 997bf5b9e6..83771dadd7 100644 --- a/crates/burn-router/src/ops/op_int.rs +++ b/crates/burn-router/src/ops/op_int.rs @@ -8,7 +8,7 @@ use burn_tensor::ops::{ use burn_tensor::repr::{ BaseOperationDescription, BinaryOperationDescription, CatOperationDescription, ClampOperationDescription, ExpandOperationDescription, FlipOperationDescription, - FromDataOperationDescription, GatherOperationDescription, IntOperationDescription, + GatherOperationDescription, InitOperationDescription, IntOperationDescription, MaskFillOperationDescription, MaskWhereOperationDescription, NumericOperationDescription, OperationDescription, PermuteOperationDescription, RandomOperationDescription, ReduceDimWithIndicesDescription, RepeatDimOperationDescription, ReshapeDescription, @@ -45,16 +45,12 @@ impl IntTensorOps for BackendRouter { fn int_from_data(data: TensorData, device: &Device) -> IntTensor { let client = get_client::(device); - let out = client.register_empty_tensor(data.shape.clone(), IntElem::::dtype()); - - let desc = FromDataOperationDescription { - data, + let out = client.register_tensor_data(data); + let desc = InitOperationDescription { out: out.to_description_out(), }; - client.register(OperationDescription::BaseInt( - BaseOperationDescription::FromData(desc), - )); + client.register(OperationDescription::Init(desc)); out } diff --git a/crates/burn-router/src/runner.rs b/crates/burn-router/src/runner.rs index 7443be94f9..882d05ff2a 100644 --- a/crates/burn-router/src/runner.rs +++ b/crates/burn-router/src/runner.rs @@ -169,7 +169,7 @@ impl RunnerClient for Runner { ctx.free_orphans(); let handles = &mut ctx.handles; - match &op { + match op { // For every op: get the input(s), execute the operation and register the output(s) OperationDescription::BaseFloat(op) => match op { BaseOperationDescription::ToDevice(_) => unreachable!(), @@ -245,10 +245,6 @@ impl RunnerClient for Runner { let output = B::float_empty(shape, &self.device); handles.register_float_tensor::(&desc.id, output); } - BaseOperationDescription::FromData(desc) => { - let output = B::float_from_data(desc.data.clone(), &self.device); - handles.register_float_tensor::(&desc.out.id, output); - } }, OperationDescription::BaseInt(op) => match op { BaseOperationDescription::ToDevice(_) => unreachable!(), @@ -320,10 +316,6 @@ impl RunnerClient for Runner { let output = B::int_empty(shape, &self.device); handles.register_int_tensor::(&desc.id, output); } - BaseOperationDescription::FromData(desc) => { - let output = B::int_from_data(desc.data.clone(), &self.device); - handles.register_int_tensor::(&desc.out.id, output); - } }, OperationDescription::BaseBool(op) => match op { BaseOperationDescription::ToDevice(_) => unreachable!(), @@ -399,10 +391,6 @@ impl RunnerClient for Runner { let output = B::bool_empty(shape, &self.device); handles.register_bool_tensor::(&desc.id, output); } - BaseOperationDescription::FromData(desc) => { - let output = B::bool_from_data(desc.data.clone(), &self.device); - handles.register_bool_tensor::(&desc.out.id, output); - } }, OperationDescription::NumericFloat(_dtype, op) => match op { NumericOperationDescription::Add(desc) => { @@ -1219,6 +1207,9 @@ impl RunnerClient for Runner { OperationDescription::Custom(_) => { panic!("Can't execute custom operation here") } + OperationDescription::Init(_) => { + // Nothing to do. + } } } diff --git a/crates/burn-tensor/src/repr/operation.rs b/crates/burn-tensor/src/repr/operation.rs index e4b0f3ccaf..fdc9f32cc9 100644 --- a/crates/burn-tensor/src/repr/operation.rs +++ b/crates/burn-tensor/src/repr/operation.rs @@ -6,7 +6,6 @@ use alloc::borrow::ToOwned; use alloc::boxed::Box; use alloc::{string::String, vec, vec::Vec}; -use crate::TensorData; use crate::{ ops::{ ConvOptions, ConvTransposeOptions, DeformConvOptions, InterpolateMode, InterpolateOptions, @@ -81,6 +80,8 @@ pub enum OperationDescription { Float(DType, FloatOperationDescription), /// Module operation. Module(ModuleOperationDescription), + /// Initialize operation. + Init(InitOperationDescription), /// A custom operation. Custom(CustomOpDescription), } @@ -198,12 +199,6 @@ pub enum ModuleOperationDescription { /// Basic operations that can be done on any tensor type. #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] pub enum BaseOperationDescription { - /// Operation corresponding to: - /// - /// Float => [from_data](crate::ops::FloatTensorOps::float_from_data). - /// Int => [from_data](crate::ops::IntTensorOps::int_from_data). - /// Bool => [from_data](crate::ops::BoolTensorOps::bool_from_data). - FromData(FromDataOperationDescription), /// Operation corresponding to: /// /// Float => [to device](crate::ops::FloatTensorOps::float_to_device). @@ -638,10 +633,12 @@ pub struct RandomOperationDescription { } #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -#[allow(missing_docs)] -pub struct FromDataOperationDescription { +/// Declares a tensor has been initialized. +/// +/// It is necessary to register for proper orphan detection and avoid memory leak. +pub struct InitOperationDescription { + /// The initialized tensor. pub out: TensorDescription, - pub data: TensorData, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] @@ -1381,6 +1378,7 @@ impl OperationDescription { OperationDescription::Int(ops) => ops.nodes(), OperationDescription::Float(_dtype, ops) => ops.nodes(), OperationDescription::Module(ops) => ops.nodes(), + OperationDescription::Init(ops) => ops.nodes(), OperationDescription::Custom(ops) => ops.nodes(), } } @@ -1422,7 +1420,6 @@ impl BaseOperationDescription { BaseOperationDescription::Cat(desc) => desc.tensors.iter().collect(), BaseOperationDescription::Cast(desc) => vec![&desc.input, &desc.out], BaseOperationDescription::Empty(desc) => vec![desc], - BaseOperationDescription::FromData(desc) => vec![&desc.out], } } } @@ -1769,12 +1766,18 @@ impl ModuleOperationDescription { } } -impl core::hash::Hash for FromDataOperationDescription { +impl core::hash::Hash for InitOperationDescription { fn hash(&self, state: &mut H) { self.out.hash(state); } } +impl InitOperationDescription { + fn nodes(&self) -> Vec<&TensorDescription> { + vec![&self.out] + } +} + impl core::hash::Hash for RandomOperationDescription { fn hash(&self, state: &mut H) { self.out.hash(state); diff --git a/crates/burn-tensor/src/tensor/api/base.rs b/crates/burn-tensor/src/tensor/api/base.rs index 4bbc522f49..aa62952a05 100644 --- a/crates/burn-tensor/src/tensor/api/base.rs +++ b/crates/burn-tensor/src/tensor/api/base.rs @@ -960,6 +960,19 @@ where Self::new(K::from_data(data, device)) } + /// Create a tensor from the given data on the given device enforcing the given data type. + pub fn from_data_dtype(data: T, device: &B::Device, dtype: DType) -> Self + where + T: Into, + { + let data = data.into(); + check!(TensorCheck::creation_ops::( + "From Data", + data.shape.as_slice() + )); + Self::new(K::from_data_dtype(data, device, dtype)) + } + /// Repeat the tensor along the given dimension. /// /// @@ -2155,6 +2168,17 @@ pub trait BasicOps: TensorKind { /// For creating a tensor from data, users should prefer the [Tensor::from_data](Tensor::from_data) function, /// which is more high-level and designed for public use. fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive; + /// Creates a tensor from the given data enforcing the given data type. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For creating a tensor from data, users should prefer the [Tensor::from_data_dtype](Tensor::from_data_dtype) + /// function, which is more high-level and designed for public use. + fn from_data_dtype(data: TensorData, device: &B::Device, dtype: DType) -> Self::Primitive; /// Repeat the tensor along the given dimension. /// @@ -2501,7 +2525,16 @@ impl BasicOps for Float { fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive { match data.dtype { DType::QFloat(_strategy) => TensorPrimitive::QFloat(B::q_from_data(data, device)), - _ => TensorPrimitive::Float(B::float_from_data(data, device)), + _ => TensorPrimitive::Float(B::float_from_data(data.convert::(), device)), + } + } + + fn from_data_dtype(data: TensorData, device: &B::Device, dtype: DType) -> Self::Primitive { + match dtype { + DType::QFloat(_strategy) => { + TensorPrimitive::QFloat(B::q_from_data(data.convert_dtype(dtype), device)) + } + _ => TensorPrimitive::Float(B::float_from_data(data.convert_dtype(dtype), device)), } } @@ -2674,7 +2707,11 @@ impl BasicOps for Int { } fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive { - B::int_from_data(data, device) + B::int_from_data(data.convert::(), device) + } + + fn from_data_dtype(data: TensorData, device: &B::Device, dtype: DType) -> Self::Primitive { + B::int_from_data(data.convert_dtype(dtype), device) } fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive { @@ -2786,7 +2823,11 @@ impl BasicOps for Bool { } fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive { - B::bool_from_data(data, device) + B::bool_from_data(data.convert::(), device) + } + + fn from_data_dtype(data: TensorData, device: &B::Device, dtype: DType) -> Self::Primitive { + B::bool_from_data(data.convert_dtype(dtype), device) } fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive { diff --git a/crates/burn-tensor/src/tensor/data.rs b/crates/burn-tensor/src/tensor/data.rs index bd144e397f..413d230a06 100644 --- a/crates/burn-tensor/src/tensor/data.rs +++ b/crates/burn-tensor/src/tensor/data.rs @@ -315,28 +315,66 @@ impl TensorData { /// Converts the data to a different element type. pub fn convert(self) -> Self { - if E::dtype() == self.dtype { + self.convert_dtype(E::dtype()) + } + + /// Converts the data to a different element type. + pub fn convert_dtype(self, dtype: DType) -> Self { + if dtype == self.dtype { self - } else if core::mem::size_of::() == self.dtype.size() + } else if dtype.size() == self.dtype.size() && !matches!(self.dtype, DType::Bool | DType::QFloat(_)) { match self.dtype { - DType::F64 => self.convert_inplace::(), - DType::F32 => self.convert_inplace::(), - DType::F16 => self.convert_inplace::(), - DType::BF16 => self.convert_inplace::(), - DType::I64 => self.convert_inplace::(), - DType::I32 => self.convert_inplace::(), - DType::I16 => self.convert_inplace::(), - DType::I8 => self.convert_inplace::(), - DType::U64 => self.convert_inplace::(), - DType::U32 => self.convert_inplace::(), - DType::U16 => self.convert_inplace::(), - DType::U8 => self.convert_inplace::(), + DType::F64 => self.convert_inplace_dtype::(dtype), + DType::F32 => self.convert_inplace_dtype::(dtype), + DType::F16 => self.convert_inplace_dtype::(dtype), + DType::BF16 => self.convert_inplace_dtype::(dtype), + DType::I64 => self.convert_inplace_dtype::(dtype), + DType::I32 => self.convert_inplace_dtype::(dtype), + DType::I16 => self.convert_inplace_dtype::(dtype), + DType::I8 => self.convert_inplace_dtype::(dtype), + DType::U64 => self.convert_inplace_dtype::(dtype), + DType::U32 => self.convert_inplace_dtype::(dtype), + DType::U16 => self.convert_inplace_dtype::(dtype), + DType::U8 => self.convert_inplace_dtype::(dtype), DType::Bool | DType::QFloat(_) => unreachable!(), } } else { - TensorData::new(self.iter::().collect(), self.shape) + match dtype { + DType::F64 => TensorData::new(self.iter::().collect(), self.shape), + DType::F32 => TensorData::new(self.iter::().collect(), self.shape), + DType::F16 => TensorData::new(self.iter::().collect(), self.shape), + DType::BF16 => TensorData::new(self.iter::().collect(), self.shape), + DType::I64 => TensorData::new(self.iter::().collect(), self.shape), + DType::I32 => TensorData::new(self.iter::().collect(), self.shape), + DType::I16 => TensorData::new(self.iter::().collect(), self.shape), + DType::I8 => TensorData::new(self.iter::().collect(), self.shape), + DType::U64 => TensorData::new(self.iter::().collect(), self.shape), + DType::U32 => TensorData::new(self.iter::().collect(), self.shape), + DType::U16 => TensorData::new(self.iter::().collect(), self.shape), + DType::U8 => TensorData::new(self.iter::().collect(), self.shape), + DType::Bool => TensorData::new(self.iter::().collect(), self.shape), + DType::QFloat(_) => unreachable!(), + } + } + } + + fn convert_inplace_dtype(self, dtype: DType) -> Self { + match dtype { + DType::F64 => self.convert_inplace::(), + DType::F32 => self.convert_inplace::(), + DType::F16 => self.convert_inplace::(), + DType::BF16 => self.convert_inplace::(), + DType::I64 => self.convert_inplace::(), + DType::I32 => self.convert_inplace::(), + DType::I16 => self.convert_inplace::(), + DType::I8 => self.convert_inplace::(), + DType::U64 => self.convert_inplace::(), + DType::U32 => self.convert_inplace::(), + DType::U16 => self.convert_inplace::(), + DType::U8 => self.convert_inplace::(), + DType::Bool | DType::QFloat(_) => unreachable!(), } } diff --git a/crates/burn-tensor/src/tests/quantization/ops/quantize.rs b/crates/burn-tensor/src/tests/quantization/ops/quantize.rs index 13d6cf51df..5bc6c0cf5f 100644 --- a/crates/burn-tensor/src/tests/quantization/ops/quantize.rs +++ b/crates/burn-tensor/src/tests/quantization/ops/quantize.rs @@ -97,7 +97,7 @@ mod tests { let x = x_q.dequantize(); // Precision 2 for dequantization errors - x.to_data() + x.into_data() .assert_approx_eq(&TensorData::from([-1.8, -1.0, 0.0, 0.5]), 2); } @@ -117,6 +117,6 @@ mod tests { QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.05882353, 42)), ); - x_q.to_data().assert_eq(&expected, false); + x_q.into_data().assert_eq(&expected, false); } }