diff --git a/Cargo.toml b/Cargo.toml index 38f3e26..c01faac 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,10 +6,10 @@ resolver = "2" clrt = { git = "https://github.com/InfiniTensor/clrt", rev = "9b6289d" } search-cl-tools = { git = "https://github.com/InfiniTensor/clrt", rev = "9b6289d" } -infini-rt = { git = "https://github.com/InfiniTensor/infini-toolkit", rev = "f40bcb5" } -infini-op = { git = "https://github.com/InfiniTensor/infini-toolkit", rev = "f40bcb5" } -infini-ccl = { git = "https://github.com/InfiniTensor/infini-toolkit", rev = "f40bcb5" } -search-infini-tools = { git = "https://github.com/InfiniTensor/infini-toolkit", rev = "f40bcb5" } +infini-rt = { git = "https://github.com/InfiniTensor/infini-toolkit", rev = "29fd321" } +infini-op = { git = "https://github.com/InfiniTensor/infini-toolkit", rev = "29fd321" } +infini-ccl = { git = "https://github.com/InfiniTensor/infini-toolkit", rev = "29fd321" } +search-infini-tools = { git = "https://github.com/InfiniTensor/infini-toolkit", rev = "29fd321" } cuda = { git = "https://github.com/YdrMaster/cuda-driver.git", rev = "b064bfb" } cublas = { git = "https://github.com/YdrMaster/cuda-driver.git", rev = "b064bfb" } diff --git a/operators/src/all_reduce/infini.rs b/operators/src/all_reduce/infini.rs index 87784dd..163b45f 100644 --- a/operators/src/all_reduce/infini.rs +++ b/operators/src/all_reduce/infini.rs @@ -1,7 +1,8 @@ use super::{args::Meta, AllReduce, Args, ReduceOp}; use crate::{ infini::{Device, InfiniNode}, - rearrange, ByteOf, LaunchError, QueueAlloc, SchemeError, + rearrange::{self, infini::Operator as Rearrange}, + ByteOf, LaunchError, QueueAlloc, SchemeError, }; use digit_layout::types as ty; use infini_ccl::bindings::InfiniDataType_t; @@ -10,8 +11,9 @@ use std::{ sync::Arc, }; -pub struct Operator { - comm: Arc, +pub enum Operator { + Rearrange(Rearrange), + Comm(Arc), } impl AllReduce for Operator {} @@ -22,8 +24,9 @@ impl crate::Operator for Operator { type Args = Args; fn new(node: &Self::TopoNode) -> Self { - Self { - comm: node.comm.clone(), + match node.comm.as_ref() { + Some(comm) => Self::Comm(comm.clone()), + None => Self::Rearrange(Rearrange::new(&node.device)), } } @@ -38,33 +41,40 @@ impl crate::Operator for Operator { fn launch( &self, args: &Self::Args, - _workspace: &mut [ByteOf], + workspace: &mut [ByteOf], queue_alloc: &QA, ) -> Result<(), LaunchError> where QA: QueueAlloc, { - let Meta { dt, size } = args.meta()?; - let &Args { - pair: rearrange::Args { - dst_base, src_base, .. - }, - op, - .. - } = args; + match self { + Self::Rearrange(rearrange) => rearrange.launch(&args.pair, workspace, queue_alloc), + Self::Comm(comm) => { + let Meta { dt, size } = args.meta()?; + let &Args { + pair: + rearrange::Args { + dst_base, src_base, .. + }, + op, + .. + } = args; - assert_eq!(op, ReduceOp::Sum); - let len = dt.nbytes() * size; - self.comm.allreduce_sum( - unsafe { from_raw_parts_mut(dst_base, len) }, - unsafe { from_raw_parts(src_base, len) }, - match dt { - ty::F16 => InfiniDataType_t::INFINI_F16, - ty::F32 => InfiniDataType_t::INFINI_F32, - _ => todo!(), - }, - queue_alloc.queue(), - ); - Ok(()) + assert_eq!(op, ReduceOp::Sum); + let len = dt.nbytes() * size; + + comm.allreduce_sum( + unsafe { from_raw_parts_mut(dst_base, len) }, + unsafe { from_raw_parts(src_base, len) }, + match dt { + ty::F16 => InfiniDataType_t::INFINI_F16, + ty::F32 => InfiniDataType_t::INFINI_F32, + _ => todo!(), + }, + queue_alloc.queue(), + ); + Ok(()) + } + } } } diff --git a/operators/src/handle/infini/ccl.rs b/operators/src/handle/infini/ccl.rs index 1e13b7b..9a89ac3 100644 --- a/operators/src/handle/infini/ccl.rs +++ b/operators/src/handle/infini/ccl.rs @@ -7,62 +7,49 @@ pub struct InfiniNode { rank: usize, group_size: usize, pub(crate) device: Device, - pub(crate) comm: Arc, + pub(crate) comm: Option>, } impl InfiniNode { pub fn cpu(n: usize) -> Vec { - let device = Device::cpu(); - let indices = (0..n as c_uint).collect::>(); - Comm::init_all(DeviceType::DEVICE_CPU, &indices) - .into_iter() - .enumerate() - .map(|(id, comm)| Self { - rank: id, - group_size: n, - device: device.clone(), - comm: Arc::new(comm), - }) - .collect() + let indices = (0..n as _).collect::>(); + Self::new(&indices, DeviceType::DEVICE_CPU) } pub fn nv_gpu(indices: &[c_uint]) -> Vec { - Comm::init_all(DeviceType::DEVICE_NVIDIA, indices) - .into_iter() - .enumerate() - .map(|(id, comm)| Self { - rank: id, - group_size: indices.len(), - device: Device::nv_gpu(id), - comm: Arc::new(comm), - }) - .collect() + Self::new(indices, DeviceType::DEVICE_NVIDIA) } pub fn cambricon_mlu(indices: &[c_uint]) -> Vec { - Comm::init_all(DeviceType::DEVICE_CAMBRICON, indices) - .into_iter() - .enumerate() - .map(|(id, comm)| Self { - rank: id, - group_size: indices.len(), - device: Device::cambricon_mlu(id), - comm: Arc::new(comm), - }) - .collect() + Self::new(indices, DeviceType::DEVICE_CAMBRICON) } pub fn ascend_npu(indices: &[c_uint]) -> Vec { - Comm::init_all(DeviceType::DEVICE_ASCEND, indices) - .into_iter() - .enumerate() - .map(|(id, comm)| Self { - rank: id, - group_size: indices.len(), - device: Device::ascend_npu(id), - comm: Arc::new(comm), - }) - .collect() + Self::new(indices, DeviceType::DEVICE_ASCEND) + } + + fn new(indices: &[c_uint], ty: DeviceType) -> Vec { + let confused = unsafe { std::mem::transmute(ty) }; + if let &[id] = indices { + vec![Self { + rank: 0, + group_size: 1, + device: Device::new(confused, id as _), + comm: None, + }] + } else { + Comm::init_all(ty, indices) + .into_iter() + .zip(indices) + .enumerate() + .map(|(idx, (comm, &id))| Self { + rank: idx, + group_size: indices.len(), + device: Device::new(confused, id as _), + comm: Some(Arc::new(comm)), + }) + .collect() + } } } diff --git a/operators/src/handle/infini/mod.rs b/operators/src/handle/infini/mod.rs index 5aac8d2..622f031 100644 --- a/operators/src/handle/infini/mod.rs +++ b/operators/src/handle/infini/mod.rs @@ -1,8 +1,5 @@ use crate::{Alloc, Hardware, QueueAlloc, QueueOf}; -use infini_rt::{ - DevBlob, DevByte, DeviceType, Stream, DEVICE_ASCEND, DEVICE_CAMBRICON, DEVICE_CPU, - DEVICE_NVIDIA, -}; +use infini_rt::{DevBlob, DevByte, DeviceType, Stream}; use std::{ops::Deref, sync::Arc}; mod ccl; @@ -17,34 +14,34 @@ pub struct Device { impl Device { #[inline] pub fn cpu() -> Self { - Self::new(DEVICE_CPU, 0) + Self::new(infini_rt::DEVICE_CPU, 0) } #[inline] pub fn nv_gpu(id: usize) -> Self { - Self::new(DEVICE_NVIDIA, id) + Self::new(infini_rt::DEVICE_NVIDIA, id) } #[inline] pub fn cambricon_mlu(id: usize) -> Self { - Self::new(DEVICE_CAMBRICON, id) + Self::new(infini_rt::DEVICE_CAMBRICON, id) } #[inline] pub fn ascend_npu(id: usize) -> Self { - Self::new(DEVICE_ASCEND, id) + Self::new(infini_rt::DEVICE_ASCEND, id) } fn new(ty: infini_rt::DeviceType, id: usize) -> Self { - use infini_op::bindings::Device::*; + use infini_op::bindings::Device as Ty; Self { device: infini_rt::Device { ty, id: id as _ }, handle: Arc::new(infini_op::Handle::new( match ty { - DEVICE_CPU => DevCpu, - DEVICE_NVIDIA => DevNvGpu, - DEVICE_CAMBRICON => DevCambriconMlu, - DEVICE_ASCEND => DevAscendNpu, + infini_rt::DEVICE_CPU => Ty::DevCpu, + infini_rt::DEVICE_NVIDIA => Ty::DevNvGpu, + infini_rt::DEVICE_CAMBRICON => Ty::DevCambriconMlu, + infini_rt::DEVICE_ASCEND => Ty::DevAscendNpu, _ => unreachable!("unknown device type"), }, id as _,