Changes how vectorization is applied during fusion #2833

Merged
merged 14 commits
Feb 21, 2025
33 changes: 17 additions & 16 deletions Cargo.lock


4 changes: 2 additions & 2 deletions Cargo.toml
@@ -155,8 +155,8 @@ ahash = { version = "0.8.11", default-features = false }
portable-atomic-util = { version = "0.2.4", features = ["alloc"] }

### For the main burn branch. ###
-cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "8b025f26e5badbf1b8f3e6787fc097427cd961ec" }
-cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "8b025f26e5badbf1b8f3e6787fc097427cd961ec" }
+cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "b41cbd82d53f091e76f56cad58c277fe2481c48e" }
+cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "b41cbd82d53f091e76f56cad58c277fe2481c48e" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
1 change: 1 addition & 0 deletions crates/burn-autodiff/src/tests/nonzero.rs
@@ -29,6 +29,7 @@ mod tests {
.clone()
.unsqueeze_dim::<2>(0)
.matmul(tensor_3.unsqueeze_dim(1));
+
let grads = tensor_4.backward();

let grad_1 = tensor_1.grad(&grads).unwrap();
1 change: 0 additions & 1 deletion crates/burn-cubecl-fusion/src/elemwise/builder.rs
@@ -27,7 +27,6 @@ impl<R: Runtime> ElementWiseBuilder<R> {
FuseSettings {
broadcast: true,
output_shape_updates: true,
-mix_vectorization: true,
inplace: true,
},
),
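For orientation, here is a minimal self-contained sketch of the settings shape after this change: `mix_vectorization` is gone from `FuseSettings`, so the vectorization width is decided once per fused trace rather than mixed per builder. The struct below mirrors only the fields visible in this PR's hunks; the real `burn-cubecl-fusion` struct may carry more.

```rust
// Sketch only: mirrors the FuseSettings fields visible in this PR's hunks.
// The real struct lives in burn-cubecl-fusion and may have more fields.
#[derive(Clone, Copy, Debug)]
struct FuseSettings {
    broadcast: bool,
    output_shape_updates: bool,
    inplace: bool,
}

fn main() {
    // Values taken from the two builders touched by this PR.
    let elemwise = FuseSettings {
        broadcast: true,
        output_shape_updates: true,
        inplace: true,
    };
    let matmul = FuseSettings {
        broadcast: true,
        output_shape_updates: false,
        inplace: true,
    };
    println!("elemwise: {elemwise:?}\nmatmul:   {matmul:?}");
}
```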
9 changes: 3 additions & 6 deletions crates/burn-cubecl-fusion/src/elemwise/optimization.rs
@@ -75,17 +75,14 @@ impl<R: Runtime> TraceRunner<R> for ElemwiseRunner {
Arg::Output(index, _, _) => outputs.tensors.values.get(index as usize),
_ => panic!("Invalid value"),
};
-let (shape, vectorization) = match arg {
+let shape = match arg {
Some(val) => match &val.tensor {
-TensorArg::Handle {
-handle,
-vectorization_factor,
-} => (handle.shape, vectorization_factor),
+TensorArg::Handle { handle, .. } => handle.shape,
TensorArg::Alias { .. } => panic!("Can't be an alias, got {val:?}"),
},
None => panic!("Invalid argument"),
};
-let total_elem = shape.iter().product::<usize>() / *vectorization as usize;
+let total_elem = shape.iter().product::<usize>() / config.width as usize;
let cube_dim = CubeDim::default();
let cube_count = calculate_cube_count_elemwise(total_elem, cube_dim);

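The functional change above: the dispatch size is now derived from the fusion config's single line width (`config.width`) instead of a per-tensor `vectorization_factor`. Below is a minimal sketch of that launch-size arithmetic, with stand-ins for cubecl's `CubeDim` and `calculate_cube_count_elemwise`; the 16×16×1 cube shape is an assumption.

```rust
// Stand-in for cubecl's CubeDim; the 16x16x1 default is an assumption here.
struct CubeDim {
    x: u32,
    y: u32,
    z: u32,
}

impl CubeDim {
    fn default() -> Self {
        CubeDim { x: 16, y: 16, z: 1 }
    }
    fn num_elems(&self) -> u32 {
        self.x * self.y * self.z
    }
}

// Stand-in for cubecl's helper: one unit per (vectorized) element, rounded up.
fn calculate_cube_count_elemwise(total_elem: usize, cube_dim: &CubeDim) -> usize {
    total_elem.div_ceil(cube_dim.num_elems() as usize)
}

fn main() {
    let shape = [32usize, 64, 8];
    let width = 4usize; // one vectorization width for the whole fused kernel
    // After this PR: divide by the config-wide width, not a per-handle factor.
    let total_elem = shape.iter().product::<usize>() / width;
    let cubes = calculate_cube_count_elemwise(total_elem, &CubeDim::default());
    println!("{total_elem} vectorized elements -> {cubes} cubes");
}
```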
2 changes: 2 additions & 0 deletions crates/burn-cubecl-fusion/src/matmul/args.rs
@@ -278,3 +278,5 @@ impl Init for FusedMatmulStateExpand {
self
}
}
+
+impl CubeDebug for FusedMatmulStateExpand {}
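The empty impl added here works because `CubeDebug` can be satisfied without any method bodies. A generic illustration of the idiom follows; the trait below is a stand-in, not cubecl's actual `CubeDebug` definition.

```rust
// Stand-in trait: every method has a default body, so `impl ... {}` suffices.
trait CubeDebug {
    fn debug_name(&self) -> &'static str {
        "<unnamed>"
    }
}

struct FusedMatmulStateExpand;

// Same shape as the line added in this PR: opt in with the defaults.
impl CubeDebug for FusedMatmulStateExpand {}

fn main() {
    println!("{}", FusedMatmulStateExpand.debug_name());
}
```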
1 change: 0 additions & 1 deletion crates/burn-cubecl-fusion/src/matmul/builder.rs
@@ -33,7 +33,6 @@ impl<R: Runtime> MatmulBuilder<R> {
let settings = FuseSettings {
broadcast: true,
output_shape_updates: false,
-mix_vectorization: true,
inplace: true,
};

26 changes: 14 additions & 12 deletions crates/burn-cubecl-fusion/src/matmul/optimization.rs
@@ -9,10 +9,12 @@ use burn_fusion::stream::Context;
use burn_ir::{BinaryOpIr, TensorStatus};
use cubecl::linalg::matmul::components;
use cubecl::linalg::matmul::components::tile::accelerated::Accelerated;
+use cubecl::linalg::matmul::components::tile::TileMatmulFamily;
use cubecl::linalg::matmul::components::MatmulProblem;
-use cubecl::linalg::matmul::kernels::matmul::{
-DoubleBufferingSelector, MatmulSelector, SimpleSelector, SpecializedSelector,
-};
+use cubecl::linalg::matmul::kernels::matmul::double_buffering::DoubleBufferingAlgorithm;
+use cubecl::linalg::matmul::kernels::matmul::simple::SimpleAlgorithm;
+use cubecl::linalg::matmul::kernels::matmul::specialized::SpecializedAlgorithm;
+use cubecl::linalg::matmul::kernels::matmul::{select_kernel, Algorithm};
use cubecl::linalg::matmul::kernels::{MatmulAvailabilityError, MatmulLaunchError};
use cubecl::linalg::tensor::{matrix_layout, MatrixLayout};
use cubecl::{client::ComputeClient, prelude::*};
@@ -351,7 +353,7 @@ impl FusedMatmul {

match self.selector {
FusedMatmulSelector::Simple => {
-match matmul_launch_kernel::<R, EG, SimpleSelector<Accelerated>>(
+match matmul_launch_kernel::<R, EG, SimpleAlgorithm<Accelerated>>(
client,
FusedMatmulInputLaunch::new(inputs, config, &self.lhs, &self.rhs, &self.out),
outputs,
@@ -363,7 +365,7 @@
}
}
FusedMatmulSelector::DoubleBuffering => {
-match matmul_launch_kernel::<R, EG, DoubleBufferingSelector<Accelerated>>(
+match matmul_launch_kernel::<R, EG, DoubleBufferingAlgorithm<Accelerated>>(
client,
FusedMatmulInputLaunch::new(inputs, config, &self.lhs, &self.rhs, &self.out),
outputs,
@@ -375,7 +377,7 @@
}
}
FusedMatmulSelector::Specialized => {
-match matmul_launch_kernel::<R, EG, SpecializedSelector<Accelerated>>(
+match matmul_launch_kernel::<R, EG, SpecializedAlgorithm<Accelerated>>(
client,
FusedMatmulInputLaunch::new(inputs, config, &self.lhs, &self.rhs, &self.out),
outputs,
@@ -390,7 +392,7 @@
}
}

-fn matmul_launch_kernel<'a, R: Runtime, EG: Numeric, S: MatmulSelector>(
+fn matmul_launch_kernel<'a, R: Runtime, EG: Numeric, A: Algorithm>(
client: &ComputeClient<R::Server, R::Channel>,
input: FusedMatmulInputLaunch<'a, R>,
output: GlobalArgsLaunch<'a, R>,
@@ -400,19 +402,19 @@
if TypeId::of::<EG>() == TypeId::of::<half::f16>()
|| TypeId::of::<EG>() == TypeId::of::<flex32>()
{
-S::select_kernel::<FusedMatmulSpec<EG, half::f16, f32>, R>(
+select_kernel::<FusedMatmulSpec<EG, half::f16, f32>, R, A>(
client, input, output, problem, plane_size, false,
)
} else if TypeId::of::<EG>() == TypeId::of::<half::bf16>() {
-S::select_kernel::<FusedMatmulSpec<EG, half::bf16, f32>, R>(
+select_kernel::<FusedMatmulSpec<EG, half::bf16, f32>, R, A>(
client, input, output, problem, plane_size, false,
)
-} else if S::stage_tf32_supported() {
-S::select_kernel::<FusedMatmulSpec<EG, tf32, f32>, R>(
+} else if <A::TileMatmul as TileMatmulFamily>::requires_tensor_cores() {
+select_kernel::<FusedMatmulSpec<EG, tf32, f32>, R, A>(
client, input, output, problem, plane_size, false,
)
} else {
-S::select_kernel::<FusedMatmulSpec<EG, EG, f32>, R>(
+select_kernel::<FusedMatmulSpec<EG, EG, f32>, R, A>(
client, input, output, problem, plane_size, false,
)
}
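Taken together, the matmul changes replace the `MatmulSelector` trait (with its associated `select_kernel` method) by a free `select_kernel` function generic over an `Algorithm`, and the tf32 branch now keys off whether the algorithm's tile matmul requires tensor cores. Below is a self-contained sketch of that dispatch structure; all traits and types are simplified stand-ins for cubecl's API, and the `half`-typed branches from the real code are elided since the sketch uses only std types.

```rust
use std::any::TypeId;

// Simplified stand-ins for cubecl's TileMatmulFamily / Algorithm traits.
trait TileMatmulFamily {
    fn requires_tensor_cores() -> bool;
}

trait Algorithm {
    type TileMatmul: TileMatmulFamily;
}

struct AcceleratedTile;
impl TileMatmulFamily for AcceleratedTile {
    fn requires_tensor_cores() -> bool {
        true // tensor-core tile matmul, like `Accelerated` in the diff
    }
}

struct SimpleAlgorithm;
impl Algorithm for SimpleAlgorithm {
    type TileMatmul = AcceleratedTile;
}

// Mirrors the branch structure of `matmul_launch_kernel` after this PR:
// inspect the element type with TypeId, then let the algorithm's tile
// capability decide whether f32 inputs get a tf32 stage. The f16/bf16
// branches from the real code are omitted (they need the `half` crate).
fn stage_kind<EG: 'static, A: Algorithm>() -> &'static str {
    if TypeId::of::<EG>() == TypeId::of::<f32>()
        && <A::TileMatmul as TileMatmulFamily>::requires_tensor_cores()
    {
        "tf32 stage"
    } else {
        "stage at input precision"
    }
}

fn main() {
    println!("{}", stage_kind::<f32, SimpleAlgorithm>());
    println!("{}", stage_kind::<i32, SimpleAlgorithm>());
}
```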