Skip to content

Commit

Permalink
Fix/from data copy (#2778)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored Feb 5, 2025
1 parent 28851ee commit b9653f5
Show file tree
Hide file tree
Showing 15 changed files with 211 additions and 182 deletions.
12 changes: 12 additions & 0 deletions crates/burn-fusion/src/ops/base.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
use crate::{stream::Operation, FusionBackend};
use burn_tensor::repr::HandleContainer;
use std::marker::PhantomData;

/// Operation that intentionally performs no work when executed.
///
/// Used for tensors whose handle is registered directly at creation time
/// (e.g. `from_data`): the stream still records an `Init` operation for the
/// output, but there is nothing left to compute when it runs.
pub struct NoOp<B: FusionBackend> {
    _b: PhantomData<B>,
}

impl<B: FusionBackend> NoOp<B> {
    /// Creates a new no-op operation.
    pub fn new() -> Self {
        Self { _b: PhantomData }
    }
}

impl<B: FusionBackend> Operation<B::FusionRuntime> for NoOp<B> {
    // Nothing to execute: the output handle was registered when the
    // operation was created.
    fn execute(self: Box<Self>, _handles: &mut HandleContainer<B::Handle>) {}
}
33 changes: 11 additions & 22 deletions crates/burn-fusion/src/ops/boolean.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use burn_tensor::{
ops::{binary_ops_shape, FloatTensor, IntTensor},
repr::{FromDataOperationDescription, TensorDescription},
DType, Element, TensorData,
repr::{InitOperationDescription, TensorDescription},
DType, Element, TensorData, TensorMetadata,
};
use std::marker::PhantomData;

Expand All @@ -23,6 +23,8 @@ use burn_tensor::{
Device, Shape,
};

use super::NoOp;

impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
fn bool_empty(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
#[derive(new)]
Expand Down Expand Up @@ -58,32 +60,19 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
}

fn bool_from_data(data: burn_tensor::TensorData, device: &Device<Self>) -> BoolTensor<Self> {
#[derive(new)]
struct FromDataOps<B: FusionBackend> {
desc: FromDataOperationDescription,
device: Device<B>,
}

impl<B: FusionBackend> Operation<B::FusionRuntime> for FromDataOps<B> {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) {
let output = B::bool_from_data(self.desc.data, &self.device);
handles.register_bool_tensor::<B>(&self.desc.out.id, output);
}
}

let stream = StreamId::current();
let client = get_client::<B>(&device.clone());
let out = client.tensor_uninitialized(data.shape.clone(), DType::Bool);
let tensor = B::bool_from_data(data, device);
let shape = tensor.shape();

let desc = FromDataOperationDescription {
out: out.to_description_out(),
data,
};
let handle = B::bool_tensor_handle(tensor);
let out = client.register_tensor(handle, shape.dims, stream, DType::Bool);
let desc = out.to_description_out();

client.register(
vec![stream],
OperationDescription::BaseBool(BaseOperationDescription::FromData(desc.clone())),
FromDataOps::<B>::new(desc, device.clone()),
OperationDescription::Init(InitOperationDescription { out: desc }),
NoOp::<B>::new(),
);

out
Expand Down
32 changes: 11 additions & 21 deletions crates/burn-fusion/src/ops/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,38 +10,28 @@ use crate::{
use burn_tensor::{
ops::{binary_ops_shape, BoolTensor, FloatElem, FloatTensor, FloatTensorOps, IntTensor},
repr::*,
DType, Device, Distribution, Element, ElementConversion, Shape, TensorData,
DType, Device, Distribution, Element, ElementConversion, Shape, TensorData, TensorMetadata,
};
use std::{marker::PhantomData, ops::Range};

use super::NoOp;

impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
fn float_from_data(data: TensorData, device: &Device<Self>) -> FloatTensor<Self> {
#[derive(new)]
struct FromDataOps<B: FusionBackend> {
desc: FromDataOperationDescription,
device: Device<B>,
}

impl<B: FusionBackend> Operation<B::FusionRuntime> for FromDataOps<B> {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) {
let output = B::float_from_data(self.desc.data, &self.device);
handles.register_float_tensor::<B>(&self.desc.out.id, output);
}
}

let stream = StreamId::current();
let client = get_client::<B>(&device.clone());
let out = client.tensor_uninitialized(data.shape.clone(), B::FloatElem::dtype());
let dtype = data.dtype;
let tensor = B::float_from_data(data, device);
let shape = tensor.shape();

let desc = FromDataOperationDescription {
out: out.to_description_out(),
data,
};
let handle = B::float_tensor_handle(tensor);
let out = client.register_tensor(handle, shape.dims, stream, dtype);
let desc = out.to_description_out();

client.register(
vec![stream],
OperationDescription::BaseFloat(BaseOperationDescription::FromData(desc.clone())),
FromDataOps::<B>::new(desc, device.clone()),
OperationDescription::Init(InitOperationDescription { out: desc }),
NoOp::<B>::new(),
);

out
Expand Down
32 changes: 11 additions & 21 deletions crates/burn-fusion/src/ops/int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@ use crate::{
use burn_tensor::{
ops::{binary_ops_shape, BoolTensor, FloatTensor, IntElem, IntTensor, IntTensorOps},
repr::{self, *},
DType, Device, Distribution, Element, ElementConversion, Shape, TensorData,
DType, Device, Distribution, Element, ElementConversion, Shape, TensorData, TensorMetadata,
};
use core::ops::Range;
use std::marker::PhantomData;

use super::NoOp;

impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
fn int_empty(shape: Shape, device: &Device<Self>) -> IntTensor<Self> {
#[derive(new)]
Expand Down Expand Up @@ -48,32 +50,20 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
}

fn int_from_data(data: TensorData, device: &Device<Self>) -> IntTensor<Self> {
#[derive(new)]
struct FromDataOps<B: FusionBackend> {
desc: FromDataOperationDescription,
device: Device<B>,
}

impl<B: FusionBackend> Operation<B::FusionRuntime> for FromDataOps<B> {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) {
let output = B::int_from_data(self.desc.data, &self.device);
handles.register_int_tensor::<B>(&self.desc.out.id, output);
}
}

let stream = StreamId::current();
let client = get_client::<B>(&device.clone());
let out = client.tensor_uninitialized(data.shape.clone(), B::IntElem::dtype());
let dtype = data.dtype;
let tensor = B::int_from_data(data, device);
let shape = tensor.shape();

let desc = FromDataOperationDescription {
out: out.to_description_out(),
data,
};
let handle = B::int_tensor_handle(tensor);
let out = client.register_tensor(handle, shape.dims, stream, dtype);
let desc = out.to_description_out();

client.register(
vec![stream],
OperationDescription::BaseInt(BaseOperationDescription::FromData(desc.clone())),
FromDataOps::<B>::new(desc, device.clone()),
OperationDescription::Init(InitOperationDescription { out: desc }),
NoOp::<B>::new(),
);

out
Expand Down
3 changes: 3 additions & 0 deletions crates/burn-fusion/src/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@ mod module;
mod qtensor;
mod transaction;
mod unary;

mod base;
pub(crate) use base::*;
62 changes: 20 additions & 42 deletions crates/burn-fusion/src/ops/qtensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ use burn_tensor::{
ops::{FloatElem, FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
quantization::{QuantizationParametersPrimitive, QuantizationScheme},
repr::{
BaseOperationDescription, DequantizeOperationDescription, FloatOperationDescription,
FromDataOperationDescription, HandleContainer, OperationDescription,
QuantizationParametersDescription, QuantizeOperationDescription,
DequantizeOperationDescription, FloatOperationDescription, HandleContainer,
InitOperationDescription, OperationDescription, QuantizationParametersDescription,
QuantizeOperationDescription,
},
DType, Device, Element, Shape, TensorData,
DType, Device, Element, Shape, TensorData, TensorMetadata,
};

use crate::{
Expand All @@ -18,49 +18,27 @@ use crate::{
Fusion, FusionBackend,
};

use super::NoOp;

impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
fn q_from_data(data: TensorData, device: &Device<Self>) -> QuantizedTensor<Self> {
#[derive(new)]
struct FromDataOps<B: FusionBackend> {
desc: FromDataOperationDescription,
device: Device<B>,
}

impl<B: FusionBackend> Operation<B::FusionRuntime> for FromDataOps<B> {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) {
let output = B::q_from_data(self.desc.data, &self.device);
handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
}
}
let stream = StreamId::current();
let client = get_client::<B>(&device.clone());
let dtype = data.dtype;
let tensor = B::q_from_data(data, device);
let shape = tensor.shape();

match data.dtype {
DType::QFloat(_scheme) => {
let dtype = data.dtype;
let handle = B::quantized_tensor_handle(tensor);
let out = client.register_tensor(handle, shape.dims, stream, dtype);
let desc = out.to_description_out();

let stream = StreamId::current();
let client = get_client::<B>(&device.clone());
let out = client.tensor_uninitialized(data.shape.clone(), dtype);

let desc = FromDataOperationDescription {
out: out.to_description_out(),
data,
};

client.register(
vec![stream],
OperationDescription::BaseFloat(BaseOperationDescription::FromData(
desc.clone(),
)),
FromDataOps::<B>::new(desc, device.clone()),
);
client.register(
vec![stream],
OperationDescription::Init(InitOperationDescription { out: desc }),
NoOp::<B>::new(),
);

out
}
_ => panic!(
"Invalid dtype (expected DType::QFloat, got {:?})",
data.dtype
),
}
out
}

fn quantize(
Expand Down
20 changes: 13 additions & 7 deletions crates/burn-fusion/src/stream/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ impl OperationConverter {
burn_tensor::DType::F32 => self.scalar_f32.push(elem.elem()),
burn_tensor::DType::F16 => self.scalar_f16.push(elem.elem()),
burn_tensor::DType::BF16 => self.scalar_bf16.push(elem.elem()),
_ => todo!("Unsupported"),
_ => todo!("Unsupported float dtype ({dtype:?}) for scalar ({elem:?})"),
}

// We return 0 so that the id from a scalar operation is the same no matter its scalar
Expand Down Expand Up @@ -227,6 +227,7 @@ impl OperationConverter {

impl RelativeOps for OperationDescription {
fn to_relative(&self, converter: &mut OperationConverter) -> Self {
println!("To relative {self:?}");
match self {
OperationDescription::BaseFloat(ops) => {
OperationDescription::BaseFloat(ops.to_relative(converter))
Expand Down Expand Up @@ -261,6 +262,9 @@ impl RelativeOps for OperationDescription {
OperationDescription::Custom(ops) => {
OperationDescription::Custom(ops.to_relative(converter))
}
OperationDescription::Init(ops) => {
OperationDescription::Init(ops.to_relative(converter))
}
}
}
}
Expand Down Expand Up @@ -1210,12 +1214,14 @@ impl RelativeOps for BaseOperationDescription {
BaseOperationDescription::Empty(desc) => {
BaseOperationDescription::Empty(desc.to_relative(converter))
}
BaseOperationDescription::FromData(desc) => {
BaseOperationDescription::FromData(FromDataOperationDescription {
data: desc.data.clone(),
out: desc.out.to_relative(converter),
})
}
}
}
}

impl RelativeOps for InitOperationDescription {
    /// Converts the output tensor description to its stream-relative form;
    /// an init operation carries no other state to convert.
    fn to_relative(&self, converter: &mut OperationConverter) -> Self {
        let out = self.out.to_relative(converter);
        Self { out }
    }
}
Expand Down
12 changes: 4 additions & 8 deletions crates/burn-router/src/ops/op_bool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use burn_tensor::ops::{BoolTensor, BoolTensorOps, FloatElem, FloatTensor, IntEle
use burn_tensor::repr::{
BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription,
CatOperationDescription, ExpandOperationDescription, FlipOperationDescription,
FromDataOperationDescription, OperationDescription, PermuteOperationDescription,
InitOperationDescription, OperationDescription, PermuteOperationDescription,
RepeatDimOperationDescription, ReshapeDescription, SliceAssignOperationDescription,
SliceOperationDescription, SwapDimsDescription, UnaryOperationDescription,
};
Expand All @@ -31,16 +31,12 @@ impl<R: RunnerChannel> BoolTensorOps<Self> for BackendRouter<R> {

fn bool_from_data(data: TensorData, device: &Device<Self>) -> BoolTensor<Self> {
let client = get_client::<R>(device);
let out = client.register_empty_tensor(data.shape.clone(), DType::Bool);

let desc = FromDataOperationDescription {
data,
let out = client.register_tensor_data(data);
let desc = InitOperationDescription {
out: out.to_description_out(),
};

client.register(OperationDescription::BaseBool(
BaseOperationDescription::FromData(desc),
));
client.register(OperationDescription::Init(desc));

out
}
Expand Down
12 changes: 4 additions & 8 deletions crates/burn-router/src/ops/op_float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use burn_tensor::ops::{
use burn_tensor::repr::{
BaseOperationDescription, BinaryOperationDescription, CatOperationDescription,
ClampOperationDescription, ExpandOperationDescription, FlipOperationDescription,
FloatOperationDescription, FromDataOperationDescription, GatherOperationDescription,
FloatOperationDescription, GatherOperationDescription, InitOperationDescription,
MaskFillOperationDescription, MaskWhereOperationDescription, NumericOperationDescription,
OperationDescription, PermuteOperationDescription, RandomOperationDescription,
ReduceDimWithIndicesDescription, RepeatDimOperationDescription, ReshapeDescription,
Expand All @@ -25,16 +25,12 @@ use crate::{get_client, BackendRouter, RunnerChannel, RunnerClient};
impl<R: RunnerChannel> FloatTensorOps<Self> for BackendRouter<R> {
fn float_from_data(data: TensorData, device: &Device<Self>) -> FloatTensor<Self> {
let client = get_client::<R>(device);
let out = client.register_empty_tensor(data.shape.clone(), FloatElem::<Self>::dtype());

let desc = FromDataOperationDescription {
data,
let out = client.register_tensor_data(data);
let desc = InitOperationDescription {
out: out.to_description_out(),
};

client.register(OperationDescription::BaseFloat(
BaseOperationDescription::FromData(desc),
));
client.register(OperationDescription::Init(desc));

out
}
Expand Down
Loading

0 comments on commit b9653f5

Please sign in to comment.