Multidimensional mesh #3937

Open · wants to merge 10 commits into main

Conversation

cowanmeg (Collaborator)

Redo of #3821


github-actions bot commented Feb 21, 2025

Review updated until commit 3524dd3

Description

  • Added support for multidimensional device meshes

  • Enhanced HostIrLower::lower to accept my_device_idx

  • Introduced DeviceMesh shape and multidimensional slicing (see the usage sketch after this list)

  • Updated tests to cover multidimensional mesh scenarios
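
To make the new mesh API concrete before the detailed walkthrough, here is a short usage sketch. It relies only on the DeviceMesh constructor, getIndices, and getSlice signatures quoted later in this review; the device ids and shape are illustrative, not taken from the diff.

// Hypothetical 2x3 mesh over devices 0..5, laid out row-major.
DeviceMesh mesh({0, 1, 2, 3, 4, 5}, /*shape=*/{2, 3});

// Device 4 sits at row 1, column 1 of the mesh.
std::vector<int64_t> indices = mesh.getIndices(/*device=*/4); // {1, 1}

// DIDx slices along the innermost axis: the row containing device 4.
std::vector<DeviceIdxType> row = mesh.getSlice(4, ParallelType::DIDx); // {3, 4, 5}

// DIDy slices along the next axis: the column containing device 4.
std::vector<DeviceIdxType> col = mesh.getSlice(4, ParallelType::DIDy); // {1, 4}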


Changes walkthrough 📝

Relevant files
Enhancement (8 files)

  executor.cpp: Pass `my_device_idx` to `HostIrLower::lower` (+3/-1)
  lower.cpp: Update communication lowering for multidimensional meshes (+51/-19)
  device_mesh.cpp: Add shape and multidimensional slicing to `DeviceMesh` (+118/-2)
  utils.cpp: Update `requestedNumberOfDevices` to use `maxDeviceId` (+1/-3)
  type.cpp: Add `DIDy` and `DIDz` to `ParallelType` (+6/-1)
  lower.h: Update `HostIrLower::lower` signature (+2/-2)
  device_mesh.h: Update `DeviceMesh` to support multidimensional meshes (+50/-9)
  type.h: Add `DIDy` and `DIDz` to `ParallelType` (+6/-2)

Tests (1 file)

  test_sharding.cpp: Add tests for multidimensional `DeviceMesh` (+27/-0)

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Performance Concern

The introduction of DeviceIdxType my_device_idx in multiple functions could have performance implications, especially if my_device_idx is not efficiently cached or reused. Ensure that the performance impact is minimal and that the benefits of supporting multidimensional meshes outweigh the potential overhead.

/*
Adds zero or multiple Gather communications to the vector 'comms'

Note that since the root of a Gather collective is a destination, we possibly
need multiple Gathers if the tensor is replicated in the receiver mesh.
*/
void lowerToGather(
    TensorView* input_tv,
    TensorView* output_tv,
    std::vector<Expr*>& comms) {
  // we create as many 'Gathers' as there are devices in the receiver mesh
  const DeviceMesh& sender_mesh = input_tv->getDeviceMesh();
  NVF_ERROR(
      sender_mesh.rank() == 1,
      "Currently only lower Gather on a 1D mesh. Given ",
      sender_mesh);
  for (auto root : output_tv->getDeviceMesh().vector()) {
    Team team = sender_mesh.vector();
    if (!sender_mesh.has(root)) {
      team.push_back(root);
    }
    comms.push_back(IrBuilder::create<Communication>(
        CommunicationType::Gather, output_tv, input_tv, team, root));
  }
}
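
// Illustration (assumed example, not from the diff): with a sender mesh of
// {0, 1} and a receiver mesh of {1, 2}, two Gathers are created: one rooted
// at device 1 with team {0, 1}, and one rooted at device 2 with team
// {0, 1, 2}, since the root is appended to the team when it is not already
// part of the sender mesh.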

// Add one or zero Allgather communication to the vector 'comms'
void lowerToAllgather(
    TensorView* input_tv,
    TensorView* output_tv,
    std::vector<Expr*>& comms,
    DeviceIdxType my_device_idx) {
  Team team =
      input_tv->getDeviceMesh().getSlice(my_device_idx, ParallelType::DIDx);
  comms.push_back(IrBuilder::create<Communication>(
      CommunicationType::Allgather, output_tv, input_tv, team));
}
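
// Illustration (assumed example, not from the diff): assuming a row-major
// mesh of shape {2, 3} over devices {0, 1, 2, 3, 4, 5}, a call with
// my_device_idx = 4 slices along DIDx and produces the team {3, 4, 5}, so
// the Allgather runs over the row of the mesh that contains the calling
// device rather than over the whole mesh.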

// Adds one or zero Broadcast communication to the vector 'comms'
void lowerToBroadcast(
    TensorView* input_tv,
    TensorView* output_tv,
    DeviceIdxType root,
    std::vector<Expr*>& comms) {
  const DeviceMesh& mesh = output_tv->getDeviceMesh();
  NVF_ERROR(
      mesh.rank() == 1, "Broadcast only supported a 1D mesh. Given ", mesh);
  Team team = mesh.vector();
  if (!mesh.has(root)) {
    team.push_back(root);
  }
  comms.push_back(IrBuilder::create<Communication>(
      CommunicationType::Broadcast, output_tv, input_tv, team, root));
}

Error Handling

The error handling in the DeviceMesh constructor and methods like getSlice and getIndices is robust, but ensure that all edge cases are covered, especially with multidimensional meshes. Consider adding more test cases to validate these methods.

DeviceMesh::DeviceMesh(
    std::vector<DeviceIdxType> devices,
    std::vector<int64_t> shape) {
  setDevices(std::move(devices));
  if (shape.empty()) {
    shape = {(int64_t)vector_.size()};
  } else {
    int64_t num_devices =
        std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>());
    NVF_ERROR(
        (int64_t)vector_.size() == num_devices,
        "Specified a list of device with ",
        vector_.size(),
        " elements ",
        " but shape contains ",
        num_devices);
  }
  shape_ = std::move(shape);
}

DeviceMesh::DeviceMesh(std::initializer_list<DeviceIdxType> devices) {
  setDevices(std::vector<DeviceIdxType>(devices));
  shape_ = {(int64_t)vector_.size()};
}

void DeviceMesh::setDevices(std::vector<DeviceIdxType> devices) {
  vector_ = std::move(devices);

  std::unordered_set<DeviceIdxType> unique_devices(
      vector_.begin(), vector_.end());
  NVF_ERROR(
      unique_devices.size() == vector_.size(),
      "Device mesh has duplicates: ",
      vector_);
}

/*static*/ DeviceMesh DeviceMesh::createForNumDevices(
    const int64_t num_devices) {
  std::vector<DeviceIdxType> devices(num_devices);
  std::iota(devices.begin(), devices.end(), 0);
  return DeviceMesh(devices);
}

/*static*/ DeviceMesh DeviceMesh::createForShape(
    const std::vector<int64_t>& shape) {
  int64_t num_devices =
      std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>());
  std::vector<DeviceIdxType> devices(num_devices);
  std::iota(devices.begin(), devices.end(), 0);
  return DeviceMesh(devices, shape);
}
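
// Illustration (assumed example, not from the diff):
// DeviceMesh::createForShape({2, 3}) builds a mesh over devices 0..5 with
// shape {2, 3}, equivalent to DeviceMesh({0, 1, 2, 3, 4, 5}, {2, 3}).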

std::ostream& operator<<(std::ostream& out, const DeviceMesh& mesh) {
  out << "DeviceMesh";
  int64_t ndevices = std::accumulate(
      mesh.shape().begin(), mesh.shape().end(), 1, std::multiplies<>());
  int64_t ndims = mesh.rank();
  std::vector<int64_t> strides = mesh.shape();
  for (auto i = ndims - 2; i >= 0; --i) {
    strides[i] *= strides[i + 1];
  }

  for (auto i = 0; i < ndevices; i++) {
    for (auto axis = 0; axis < ndims; axis++) {
      if (i % strides[axis] == 0) {
        out << "{";
      }
    }
    out << mesh.vector().at(i);
    if ((i + 1) % strides[ndims - 1] != 0) {
      out << " ";
    }
    for (auto axis = 0; axis < ndims; axis++) {
      if ((i + 1) % strides[axis] == 0) {
        out << "}";
      }
    }
  }

  return out;
}
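
// Illustration (assumed example, not from the diff): for a mesh of shape
// {2, 3} over devices 0..5, the operator above prints
// "DeviceMesh{{0 1 2}{3 4 5}}", opening and closing one level of braces per
// mesh axis.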

int64_t DeviceMesh::size(const ParallelType parallel_type) const {
  NVF_ERROR(
      parallel_type == ParallelType::DIDx,
      "We support only 1-D sharding for now.");
  return size();
}

std::vector<int64_t> DeviceMesh::getIndices(const DeviceIdxType device) const {
  auto global_idx = idxOf(device);
  if (global_idx == -1) {
    return {};
  }
  std::vector<int64_t> indices(shape_.size());
  int64_t accumulated_size = 1;
  for (int64_t i = (int64_t)shape_.size() - 1; i >= 0; i--) {
    indices[i] = (global_idx / accumulated_size) % shape_[i];
    accumulated_size *= shape_[i];
  }
  return indices;
}
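
// Illustration (assumed example, not from the diff): getIndices decomposes
// the device's position in vector_ row-major. For a mesh of shape {2, 3}
// over devices 0..5, getIndices(4) returns {1, 1} (row 1, column 1); a
// device that is not in the mesh yields an empty vector.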

DeviceIdxType DeviceMesh::maxDeviceId() const {
  return *std::max_element(vector_.begin(), vector_.end());
}

namespace {
int64_t ptypeToAxis(ParallelType ptype, int64_t ndims) {
  NVF_ERROR(
      isParallelTypeDeviceDim(ptype),
      "Attempting to index into DeviceMesh with a non-device parallel type",
      ptype);
  int64_t offset =
      static_cast<int64_t>(ptype) - static_cast<int64_t>(ParallelType::DIDx);

  NVF_ERROR(
      offset < ndims,
      "DeviceMesh has ",
      ndims,
      " dimensions, but requesting slice for ",
      ptype);
  return ndims - 1 - offset;
}
} // namespace
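
// Illustration (assumed example, not from the diff): for a rank-2 mesh,
// ParallelType::DIDx maps to axis 1 (the innermost mesh dimension) and
// ParallelType::DIDy maps to axis 0.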

std::vector<DeviceIdxType> DeviceMesh::getSlice(
    DeviceIdxType deviceId,
    ParallelType ptype) const {
  int64_t axis = ptypeToAxis(ptype, rank());
  auto indices = getIndices(deviceId);
  NVF_ERROR(
      !indices.empty(), "Device ", deviceId, " is not in DeviceMesh ", vector_);

  int64_t offset = 0;
  int64_t stride = 1;
  int64_t accumulated_size = 1;
  for (auto i = rank() - 1; i >= 0; i--) {
    if (i > axis) {
      stride *= shape_[i];
    }
    if (i != axis) {
      offset += indices[i] * accumulated_size;
    }
    accumulated_size *= shape_[i];
  }

  std::vector<DeviceIdxType> devices(shape_[axis]);
  for (auto i : c10::irange(devices.size())) {
    devices.at(i) = vector_.at(i * stride + offset);
  }
  return devices;
}
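
Following the reviewer's suggestion to add more test cases, below is a hedged sketch of additional GTest-style checks against the DeviceMesh code quoted above. The fixture name MultiDeviceTest and the exact expectations are illustrative assumptions, not part of this PR's diff.

// Sketch only: exercises the indexing and slicing paths shown above.
TEST_F(MultiDeviceTest, MultidimensionalDeviceMeshEdgeCases) {
  DeviceMesh mesh({0, 1, 2, 3, 4, 5}, {2, 3});

  // Row-major index decomposition.
  EXPECT_EQ(mesh.getIndices(4), (std::vector<int64_t>{1, 1}));
  // A device outside the mesh yields an empty index vector.
  EXPECT_TRUE(mesh.getIndices(7).empty());

  // Slices along DIDx (innermost axis) and DIDy (next axis).
  EXPECT_EQ(
      mesh.getSlice(4, ParallelType::DIDx),
      (std::vector<DeviceIdxType>{3, 4, 5}));
  EXPECT_EQ(
      mesh.getSlice(4, ParallelType::DIDy),
      (std::vector<DeviceIdxType>{1, 4}));

  // Requesting a slice for a device not in the mesh should error.
  EXPECT_ANY_THROW(mesh.getSlice(7, ParallelType::DIDx));
  // A shape whose product does not match the device count should error.
  EXPECT_ANY_THROW(DeviceMesh({0, 1, 2}, {2, 2}));
}
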
Documentation

The changes to the HostIrLower class introduce new parameters and methods. Ensure that the updated class documentation clearly explains the purpose and usage of these changes, especially the new DeviceIdxType my_device_index parameter.

class HostIrLower {
 public:
  // The flag `ignore_inner_resharding` is useful because the preseg passes
  // `InsertReshardingsPass` and `ReorderShardedAxisPass` want different
  // behaviors
  static bool canLower(Expr* expr, bool ignore_inner_resharding = false);

  // Lower a sharded Expr into a series of Communication.
  static std::vector<Expr*> lower(Expr* c, DeviceIdxType my_device_index);

  static std::unique_ptr<hir::HostIrContainer> lower(
      std::unique_ptr<Fusion> fusion,
      DeviceIdxType my_device_index);

@cowanmeg (Collaborator, Author)

!build

@cowanmeg (Collaborator, Author)

!test

@cowanmeg (Collaborator, Author)

!build

@cowanmeg (Collaborator, Author)

!test
