
Allow separate sub-DAG for load and compute warp groups with warp-specialized circular buffering. #3941

Open
rdspring1 opened this issue Feb 21, 2025 · 0 comments

Why? To support persistent and ping-pong matmul kernels.

Context:
The circular-buffer for-loop is the first serial for-loop to the left of the computeAt position, and circular buffering is applied to the load cacheAfter TensorViews. Persistent scheduling for matmul creates multiple serial for-loops: a grid-stride for-loop over output tiles and a cta-k for-loop.

Scheduling Proposal:

  • Merge the output-tile and cta-k iterDomains.
  • Apply circular buffering to the load cacheAfter TensorViews.
  • The wgmma consumer would keep separate output-tile and cta-k iterDomains.

Why? The load cacheAfter TensorViews will then have a single serial iterDomain for circular buffering, which matches the current circular buffering implementation.
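The effect of the merge can be sketched with plain index arithmetic (a minimal Python illustration, not nvFuser code; the loop extents and stage count are made-up values):

```python
# Toy extents for illustration only.
NUM_TILES = 4   # grid-stride output tiles
CTA_K = 3       # cta-k iterations per output tile
STAGES = 2      # circular-buffer depth

# After merging output-tile and cta-k, the load TensorView has one serial
# loop, so each iteration maps to a circular-buffer stage by a simple modulo.
merged_stages = [i % STAGES for i in range(NUM_TILES * CTA_K)]

# The compute side's nested (output-tile, cta-k) loops recover the same
# stage sequence, because the merged index is tile * CTA_K + k.
nested_stages = [(tile * CTA_K + k) % STAGES
                 for tile in range(NUM_TILES)
                 for k in range(CTA_K)]

assert merged_stages == nested_stages
print(merged_stages[:6])  # → [0, 1, 0, 1, 0, 1]
```

This is also why the restriction below (the compute for-loops must be derivable from the load for-loop) is enough for the two warp groups to agree on stage indices.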

Problem: The output-tile and cta-k iterDomains cannot be merged for the compute warp groups, because the epilogue storing the matmul results to global memory does not have a cta-k iterDomain. Therefore, the wgmma consumer cannot be inlined with the cacheAfter TMA load, breaking the current circular buffering implementation.

Does the consumer of circular-buffered inputs need to be inlined?
This restriction seems unnecessary for warp-specialized circular buffering.

Lowering Proposal: For warp-specialized circular buffering, track separate for-loops for the load and compute warp groups.
Restriction: The compute for-loops must be derivable from the load for-loop.

Pseudo-code:

mbarrier init
if (tma-load) {
  decrease_register_limit(40);
  for (output-tile) {
    for (cta-k) {
      if (elect-sync) {
        mbarrier wait for empty stage
        mbarrier arriveExpectTx for tma load
        tma load operand A and B cta tiles for stage
      }
    }
  }
} else { // compute warp-group
  increase_register_limit(232);
  mbarrier arrive to signal all stages are empty
  for (output-tile) {
    for (cta-k) {
      mbarrier wait for full stage
      for (warp-k) {
        wgmma_fence;
        wgmma_64m_256m_16k;
      }
      wgmma_commit;
      wgmma_wait;
      mbarrier arrive to signal current stage is empty
    }
    wgmma_wait;
    convert fp32 results to bf16
    stmatrix from registers to shared memory
    block_sync();
    tma store from shared to global memory
    tma_store_commit;
    tma_store_wait;
  }
}
destroy mbarrier
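The mbarrier handshake above can be simulated on the host to check that the two warp groups stay in lock-step even though they run different loop nests: a single merged serial loop on the load side, and nested output-tile/cta-k loops on the compute side. This is a Python sketch using semaphores in place of mbarriers; all sizes, and the use of addition as a stand-in for wgmma, are illustrative assumptions:

```python
import threading

STAGES = 2
NUM_TILES = 2
CTA_K = 3

empty = [threading.Semaphore(1) for _ in range(STAGES)]  # "stage is empty" mbarriers
full = [threading.Semaphore(0) for _ in range(STAGES)]   # "stage is full" mbarriers
buffers = [None] * STAGES                                # circular-buffer stages
results = []

def load_warp_group():
    # Single merged serial loop over (output-tile, cta-k).
    for i in range(NUM_TILES * CTA_K):
        s = i % STAGES
        empty[s].acquire()   # mbarrier wait for empty stage
        buffers[s] = i       # stand-in for the TMA load of operand tiles
        full[s].release()    # stand-in for arriveExpectTx / stage full

def compute_warp_group():
    # Separate nested loops: output-tile outer, cta-k inner.
    for tile in range(NUM_TILES):
        acc = 0
        for k in range(CTA_K):
            s = (tile * CTA_K + k) % STAGES  # derived from the load loop index
            full[s].acquire()    # mbarrier wait for full stage
            acc += buffers[s]    # stand-in for the wgmma main loop
            empty[s].release()   # signal current stage is empty
        results.append(acc)      # stand-in for the epilogue store per tile

t0 = threading.Thread(target=load_warp_group)
t1 = threading.Thread(target=compute_warp_group)
t0.start(); t1.start(); t0.join(); t1.join()
assert results == [0 + 1 + 2, 3 + 4 + 5]
```

The assertion passes only if every compute iteration reads the stage the load side most recently filled, which is the property the derivation restriction is meant to guarantee.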