-
Notifications
You must be signed in to change notification settings - Fork 1.8k
New issue
Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? # to your account
Prototype for converting triton
to linalg
#1797
Conversation
afef67e
to
cce4d16
Compare
cce4d16
to
70e6113
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Blocking the merge for now, as this is not clear where this should live in the long term
Thanks for all this work! I think that this could be very useful for people who want to hook up Triton to existing Linalg-based compiler infra! |
Just want to say this is awesome work @nhat-nguyen to bridge into existing Linalg-based compilers like SHARK/IREE. |
Awesome work! We have also internally developed a Triton To Linalg conversion. The RFC is proposed in #1542, there are some differences in the implementation, but I believe there will be opportunities for collaboration in this process to jointly improve the entire mechanism. |
Hello, In PtrAnalysis we encountered a case where arith.select is the parent of tt.addptr and the PtrAnalysis::visitOperand failed to match. This case will possibly produce non-continuous memory accesses and thus memref/linalg fails to capture. We hope to be inspired by you. Thanks.
|
@yuanfz98 Sorry for the delayed response. Would you mind sharing the triton program that produces this error? We don't currently support non-contiguous memory loads so this could be the reason why. |
After internal discussion, we have decided that it would make more sense for this useful utility to live out-of-tree, where other out-of-tree third party backends could also use it to build their binaries |
PR redone as a plug-in: #2374 |
After reducing the number of runners some workflows are no longer working. Using new labels to select the remaining runners. Relates to triton-lang#1793.
Prototype for converting
triton
tolinalg
Introduction
I would like to share with the
triton
community this on-going work to add support for converting the triton dialect to the linalg dialect. This PR, which includes contributions from myself and my peers at Microsoft, introduces a prototype that supportslinalg
conversion for a limited number of scenarios. We hope folks will eventually be able to leverage this work in building new back-ends fortriton
.Usage
The pass is exposed via
triton-opt
and can be invoked by using the--triton-to-linalg
flag like so:Impact to the triton compiler
This PR introduces the
TritonToLinalg
pass and its related analyses. The new code is accessible only viatriton-opt
; it does not introduce any changes to the triton main compilation path.Implementation details
Even though a valid triton program can perform load and store in arbitrary memory locations, the prototype only supports lowering programs that have structured memory access patterns.
Analyses
As part of the conversion process, there are three important analyses:
Pointer analysis:
triton
program during load and store; it walks the IR and visits relevant instructions to build strided memory accesses in thememref
dialect. The analysis is still in its early stage and does not support all scenarios.Use analysis:
memref
operations representing strided memory accesses. To aid with removing these instructions safely, we performUse analysis
to mark which instructions are used only in address calculation (calledMetaUse
) or used in both address calculation and data manipulation (calledMixedUse
) operations. Those that areMixedUse
are cloned and have their users adjusted accordingly with the goal of separating out theMetaUse
ops so that they can be safely deleted.Mask analysis:
Conversion strategy
We introduce the
TritonToLinalg
pass that converts thetriton
dialect to thelinalg
dialect on tensors. This means the resulting IR is fully compatible withlinalg
tiling and fusion transformation passes. As mentioned in thePointer analysis
's description, we do however have to deal with memref instructions at the load and store boundaries and have to convert them to tensors usingbufferization.to_tensor
. Here's a simple example of what the IR looks like:after conversion:
Important details to note:
tt.load
(together with all of its related address calculation instructions such astt.addptr
andtt.splat
) are lowered to a combination ofmemref.reinterpret_cast
,memref.alloc
, andmemref.copy
. After the initialization of the local buffer, we convert the memref back to a tensor usingbufferization.to_tensor
; this op is automatically removed during bufferization.tt.store
lowers to a combination ofmemref.reinterpret_cast
and eitheraffine.store
ormemref.tensor_store
:arith
andmath
operators are converted to their correspondinglinalg.generic
version.tt.dot
becomeslinalg.matmul
.tt.reduce
becomeslinalg.reduce
; known limitation: only supportaddf
andmaxf
reduction in the reduction body for now.Testing
The prototype was tested on the following triton kernel examples:
In addition to testing on the tutorial kernels, I have also added many lit tests covering various scenarios.