
Multidimensional device mesh #3821

Merged · 7 commits merged into NVIDIA:main on Feb 8, 2025

Conversation

@cowanmeg (Collaborator) commented Feb 4, 2025

  1. Extends DeviceMesh to support multidimensional meshes. The representation is a flat vector of devices plus a vector of shapes. Since we only plan to support DIDx, DIDy, and DIDz, this naturally restricts meshes to at most 3D.
  2. The key functionality is getSlice(DeviceIdxType, ParallelType), which returns the set of devices that a tensor is sharded over and communicates with, given the (device) parallel type, which translates into a dimension of the mesh (see the sketch after this list).
  3. Trivially modifies communication lowering to analyze a slice instead of the entire mesh for collectives that use the same mesh (AllReduce, ReduceScatter, AllGather). Collectives without this property assert that they operate over a 1D mesh.
  4. Adds ParallelTypes DIDy and DIDz. These are not used anywhere except to index into the 2D and 3D device mesh tests.
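
For illustration, here is a minimal self-contained sketch of such a mesh, assuming a row-major flat layout. The class below is hypothetical, not nvFuser's actual DeviceMesh; getSlice takes a plain axis index in place of a ParallelType:

```cpp
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

using DeviceIdxType = int64_t;

class DeviceMesh {
 public:
  DeviceMesh(std::vector<DeviceIdxType> devices, std::vector<int64_t> shape)
      : devices_(std::move(devices)), shape_(std::move(shape)) {}

  // Returns the devices that share every mesh coordinate with `device`
  // except along `axis`, i.e. the 1D slice that `device` communicates
  // with when a tensor is sharded over that axis.
  std::vector<DeviceIdxType> getSlice(DeviceIdxType device, int64_t axis) const {
    // Find the flat index of `device` in the mesh.
    int64_t flat = -1;
    for (size_t i = 0; i < devices_.size(); ++i) {
      if (devices_[i] == device) {
        flat = static_cast<int64_t>(i);
        break;
      }
    }
    assert(flat >= 0 && "device not in mesh");

    // Stride of `axis` in the row-major flat layout.
    int64_t stride = 1;
    for (size_t d = axis + 1; d < shape_.size(); ++d) {
      stride *= shape_[d];
    }

    // Flat index of the first element of this slice.
    int64_t offset = flat - (flat / stride % shape_[axis]) * stride;

    std::vector<DeviceIdxType> slice;
    for (int64_t i = 0; i < shape_[axis]; ++i) {
      slice.push_back(devices_[offset + i * stride]);
    }
    return slice;
  }

 private:
  std::vector<DeviceIdxType> devices_; // flat, row-major
  std::vector<int64_t> shape_;         // up to 3D
};

int main() {
  // 2x3 mesh: axis 0 has extent 2, axis 1 has extent 3.
  DeviceMesh mesh({0, 1, 2, 3, 4, 5}, {2, 3});
  for (DeviceIdxType d : mesh.getSlice(4, /*axis=*/1)) {
    std::cout << d << ' '; // prints the row containing device 4: 3 4 5
  }
  std::cout << '\n';
}
```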

@cowanmeg marked this pull request as draft February 4, 2025 22:04
@cowanmeg changed the title Multi-dimension device mesh Multidimensional device mesh Feb 4, 2025
@cowanmeg marked this pull request as ready for review February 6, 2025 20:42
@cowanmeg (Collaborator, Author) commented Feb 6, 2025

!build

@cowanmeg requested a review from wujingyue February 6, 2025 20:54

static std::unique_ptr<hir::HostIrContainer> lower(
    std::unique_ptr<Fusion> fusion,
-   int64_t my_device_index);
+   DeviceIdxType my_device_index);
Collaborator:

Question: will lowering have to take the current GPU ID? I understand 2D meshes lead to different teams, so the host IR as-is has to be different. An alternative would be for each Communication to take teams (a vector of vectors that includes all GPUs) rather than team (the GPUs in the team that the current GPU is in). Only at runtime would each device look up which team in teams it's in. This way, all ranks have the same host IR (except for pipeline parallelism), so we can distribute the lowering across all ranks for compilation speed.

cc @samnordmann
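
For illustration, a minimal self-contained sketch of this teams-lookup alternative (the Team alias and findMyTeam helper are hypothetical, not nvFuser's API): every rank carries the same list of teams, and each device scans for its own at runtime.

```cpp
#include <cstdint>
#include <vector>

using DeviceIdxType = int64_t;
using Team = std::vector<DeviceIdxType>;

// Runtime lookup: return the team containing `my_device`. With teams
// confined to a single node (tens of devices), a linear scan is cheap.
const Team* findMyTeam(const std::vector<Team>& teams, DeviceIdxType my_device) {
  for (const Team& team : teams) {
    for (DeviceIdxType d : team) {
      if (d == my_device) {
        return &team;
      }
    }
  }
  return nullptr; // my_device participates in no team
}
```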

Collaborator (Author):

Lowering does not have to take the GPU IDs; it's more a matter of where to shift the overhead of the lookup and what parallelism strategies we expect to support.
If we expect teams to be dynamic throughout training (basically a dynamic mesh for each run), then obviously we only know the team at runtime. However, if we expect one training job to keep the mesh static, then we can move team creation to compile time and turn any overhead into a one-time cost.

The vector of teams is interesting: it saves lowering time by pushing that cost into runtime. The runtime lookup is probably not an issue, since the meshes will only span a single node, so O(10s) of devices.

Collaborator:

> Question: will lowering have to take the current GPU ID? I understand 2D meshes lead to different teams so host IR as is has to be different. An alternative would be for each Communication to take teams

Yes, sounds good. An alternative would be that the Communication only takes one team, but Team inherits from Val* and is bound to a concrete value through expr_evaluator_. That Team could be defined through an Expr* that indicates taking the mesh's slice over some axis at my_device_id. This way, the Communication IR is symmetric across all ranks, and it covers both the static and dynamic cases.
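
A simplified, self-contained illustration of this symbolic-Team idea (the names below are hypothetical stand-ins, not nvFuser's actual Val/Expr machinery): the Communication stores a deferred "slice the mesh at my_device_id" computation that every rank evaluates locally, so the IR is identical across ranks.

```cpp
#include <cstdint>
#include <functional>
#include <vector>

using DeviceIdxType = int64_t;
using Team = std::vector<DeviceIdxType>;

// Stand-in for a symbolic Team value: a deferred computation from the
// current device ID to a concrete team, e.g. a mesh slice over some axis.
struct SymbolicTeam {
  std::function<Team(DeviceIdxType)> eval;
};

struct Communication {
  SymbolicTeam team; // identical on all ranks; evaluated per rank
};

// At runtime, each rank binds its own device ID to materialize its team,
// playing the role that expr_evaluator_ plays in the comment above.
Team materialize(const Communication& comm, DeviceIdxType my_device_id) {
  return comm.team.eval(my_device_id);
}
```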

@cowanmeg (Collaborator, Author) commented Feb 7, 2025

!test

@cowanmeg (Collaborator, Author) commented Feb 7, 2025

!build

@wujingyue (Collaborator):

cc @xwang233 there's apparently a bug in the PR agent tool

@wujingyue (Collaborator):

!test

@xwang233 (Collaborator) commented Feb 8, 2025

> cc @xwang233 there's apparently a bug in the PR agent tool

For safety reasons, GitHub Actions triggered from forked repos (not nvidia/fuser but another_user/fuser) cannot see action secrets, which is where the LLM API keys are stored. Thus, no reviews can be generated.

I would recommend @cowanmeg directly create branches on this repo for future PRs. 😉

@wujingyue (Collaborator):

> I would recommend @cowanmeg directly create branches on this repo for future PRs.

Got it! @cowanmeg, did you run into permission issues when pushing branches to nvFuser? I don't remember what it takes.

@cowanmeg (Collaborator, Author) commented Feb 8, 2025

Hmm, the option to squash and merge is coming up. I could also push this directly to a branch on Fuser and open a new PR so that the PR agent tools can run. We can use this PR for the unit tests at least.

@wujingyue (Collaborator):

I didn't mean to ask you to wait for the PR agent -- it's certainly optional. I was probably unclear, sorry.

I meant to ask you to check whether you are able to create new branches in NVIDIA/Fuser. If yes, I'd recommend doing that for your future PRs so the PR agent can kick in and you can stack your PRs. If not, @xwang233 and I will try to figure out why, because you are apparently a "collaborator" of NVIDIA/Fuser already.

@wujingyue merged commit 19c46bf into NVIDIA:main Feb 8, 2025
40 of 43 checks passed
@cowanmeg (Collaborator, Author) commented Feb 8, 2025

I can make branches on Fuser! I am just used to working on a fork, so I defaulted to that! Thanks!

wujingyue added a commit that referenced this pull request Feb 8, 2025