
Prototype for converting triton to linalg #1797

Status: Closed (wants to merge 1 commit)

Conversation

@nhat-nguyen (Collaborator)

Prototype for converting triton to linalg

Introduction

I would like to share with the Triton community this ongoing work to add support for converting the triton dialect to the linalg dialect. This PR, which includes contributions from myself and my peers at Microsoft, introduces a prototype that supports linalg conversion for a limited number of scenarios. We hope folks will eventually be able to leverage this work in building new backends for triton.

Usage

The pass is exposed via triton-opt and can be invoked by using the --triton-to-linalg flag like so:

triton-opt --triton-to-linalg %file
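
For a concrete example, consider a minimal kernel in the triton dialect that copies 64 contiguous f32 elements (the file name kernel.mlir and the kernel itself are illustrative only, not taken from this PR):

// kernel.mlir
tt.func @copy_kernel(%src : !tt.ptr<f32>, %dst : !tt.ptr<f32>) {
  %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
  %1 = tt.splat %src : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
  %2 = tt.addptr %1, %0 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
  %3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
  %4 = tt.splat %dst : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
  %5 = tt.addptr %4, %0 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
  tt.store %5, %3 : tensor<64xf32>
  tt.return
}

Running triton-opt --triton-to-linalg kernel.mlir prints the converted linalg-on-tensors IR to stdout; a full before/after example is shown under "Conversion strategy" below.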

Impact on the triton compiler

This PR introduces the TritonToLinalg pass and its related analyses. The new code is accessible only via triton-opt; it does not introduce any changes to the triton main compilation path.

Implementation details

Even though a valid triton program can perform loads and stores at arbitrary memory locations, the prototype only supports lowering programs that have structured memory access patterns.

Analyses

As part of the conversion process, there are three important analyses:

  1. Pointer analysis:

    • This analysis is responsible for extracting structured memory access patterns from a triton program at loads and stores; it walks the IR and visits the relevant instructions to build strided memory accesses in the memref dialect (a sketch follows this list). The analysis is still at an early stage and does not support all scenarios.
  2. Use analysis:

    • After "Pointer analysis", instructions that are part of memory address calculation will no longer be necessary in a triton program because their semantics have now been captured by memref operations representing strided memory accesses. To aid with removing these instructions safely, we perform Use analysis to mark which instructions are used only in address calculation (called MetaUse) or used in both address calculation and data manipulation (called MixedUse) operations. Those that are MixedUse are cloned and have their users adjusted accordingly with the goal of separating out the MetaUse ops so that they can be safely deleted.
  3. Mask analysis:

    • This analysis is responsible for handling masked loads and stores.
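
To make the Pointer analysis more concrete, here is a hand-written sketch (the SSA values, the surrounding context, and the exact rewritten form are assumptions for illustration, not output copied from the pass) of a load with a non-unit stride:

%0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
%1 = arith.constant dense<64> : tensor<32xi32>
%2 = arith.muli %0, %1 : tensor<32xi32>
%3 = tt.splat %base : (!tt.ptr<f32>) -> tensor<32x!tt.ptr<f32>>
%4 = tt.addptr %3, %2 : tensor<32x!tt.ptr<f32>>, tensor<32xi32>
%5 = tt.load %4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32xf32>

Walking the defining ops of the tt.addptr offset, the analysis accumulates offset 0, size 32, and stride 64, so the load can be rewritten roughly as:

%view = memref.reinterpret_cast %base_memref to offset: [0], sizes: [32], strides: [64] :
    memref<*xf32> to memref<32xf32, strided<[64]>>
%alloc = memref.alloc() : memref<32xf32>
memref.copy %view, %alloc : memref<32xf32, strided<[64]>> to memref<32xf32>
%tensor = bufferization.to_tensor %alloc restrict writable : memref<32xf32>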

Conversion strategy

We introduce the TritonToLinalg pass that converts the triton dialect to the linalg dialect on tensors. This means the resulting IR is fully compatible with linalg tiling and fusion transformation passes. As mentioned in the Pointer analysis description, we do, however, have to deal with memref instructions at the load and store boundaries and convert them to tensors using bufferization.to_tensor. Here's a simple example of what the IR looks like:

tt.func @kernel(%afloat : !tt.ptr<bf16>, %res : !tt.ptr<bf16>) {
  %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
  %1 = tt.splat %afloat : (!tt.ptr<bf16>) -> tensor<128x!tt.ptr<bf16>>
  %2 = tt.addptr %1, %0 : tensor<128x!tt.ptr<bf16>>, tensor<128xi32>
  %afm = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xbf16>
  %3 = "tt.reduce"(%afm) ({
  ^bb0(%arg5: bf16, %arg6: bf16):
    %21 = arith.addf %arg5, %arg6 : bf16
    tt.reduce.return %21 : bf16
  }) {axis = 0 : i32} : (tensor<128xbf16>) -> bf16
  tt.store %res, %3 : bf16
  tt.return
}

after conversion:

func.func @kernel(%arg0: memref<*xbf16>, %arg1: memref<*xbf16>, %arg2: i32, %arg3: i32, %arg4: i32) {
    %cst = arith.constant 0.000000e+00 : f32
    %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [0], sizes: [128], strides: [1] :
        memref<*xbf16> to memref<128xbf16, strided<[1]>>
    %alloc = memref.alloc() : memref<128xbf16>
    memref.copy %reinterpret_cast, %alloc : memref<128xbf16, strided<[1]>> to memref<128xbf16>
    %0 = bufferization.to_tensor %alloc restrict writable : memref<128xbf16>
    %1 = bufferization.alloc_tensor() : tensor<f32>
    %inserted = tensor.insert %cst into %1[] : tensor<f32>
    %reduced = linalg.reduce ins(%0 : tensor<128xbf16>) outs(%inserted : tensor<f32>) dimensions = [0]
      (%in: bf16, %init: f32) {
        %3 = arith.extf %in : bf16 to f32
        %4 = arith.addf %3, %init : f32
        linalg.yield %4 : f32
      }
    %extracted = tensor.extract %reduced[] : tensor<f32>
    %2 = arith.truncf %extracted : f32 to bf16
    %reinterpret_cast_0 = memref.reinterpret_cast %arg1 to offset: [0], sizes: [1], strides: [1] :
        memref<*xbf16> to memref<1xbf16, strided<[1]>>
    affine.store %2, %reinterpret_cast_0[0] : memref<1xbf16, strided<[1]>>
    return
}

Important details to note:

  • tt.load (together with all of its related address calculation instructions such as tt.addptr and tt.splat) is lowered to a combination of memref.reinterpret_cast, memref.alloc, and memref.copy. After the initialization of the local buffer, we convert the memref back to a tensor using bufferization.to_tensor; this op is automatically removed during bufferization.

  • tt.store lowers to a combination of memref.reinterpret_cast and either affine.store or memref.tensor_store:

%reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [...] memref<*xf32> to memref<1024xf32>
%extracted_slice = tensor.extract_slice %15[0] [%21] [1] : tensor<1024xf32> to tensor<?xf32>
%subview = memref.subview %reinterpret_cast[0] [%21] [1] : memref<1024xf32> to memref<?xf32>
memref.tensor_store %extracted_slice, %subview : memref<?xf32>
  • element-wise arith and math operators are converted to their corresponding linalg.generic versions (a sketch follows this list).
  • tt.dot becomes linalg.matmul.
  • tt.reduce becomes linalg.reduce; known limitation: only addf and maxf reductions are currently supported in the reduction body.
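
As a rough, hand-written sketch of the element-wise lowering (the exact IR the pass emits may differ), an arith.addf over two 128-element tensors maps to a linalg.generic with identity indexing maps and a parallel iterator:

%empty = tensor.empty() : tensor<128xf32>
%sum = linalg.generic
    {indexing_maps = [affine_map<(d0) -> (d0)>,
                      affine_map<(d0) -> (d0)>,
                      affine_map<(d0) -> (d0)>],
     iterator_types = ["parallel"]}
    ins(%lhs, %rhs : tensor<128xf32>, tensor<128xf32>)
    outs(%empty : tensor<128xf32>) {
  ^bb0(%a: f32, %b: f32, %out: f32):
    %r = arith.addf %a, %b : f32
    linalg.yield %r : f32
} -> tensor<128xf32>

Because the result stays on tensors, these ops compose directly with the linalg tiling and fusion passes mentioned above.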

Testing

The prototype was tested on the following triton kernel examples:

  1. vector addition
  2. fused softmax
  3. matrix multiplication
  4. layer normalization
  5. fused attention

In addition to testing on the tutorial kernels, I have also added many lit tests covering various scenarios.

@ptillet (Collaborator) left a comment

Blocking the merge for now, as it is not clear where this should live in the long term.

@ptillet (Collaborator) commented Jun 16, 2023

Thanks for all this work! I think that this could be very useful for people who want to hook up Triton to existing Linalg-based compiler infra!

@powderluv commented:

Just want to say this is awesome work, @nhat-nguyen, bridging into existing Linalg-based compilers like SHARK/IREE.

@sethbrin commented Jun 21, 2023

Awesome work!

We have also internally developed a Triton-to-Linalg conversion; the RFC is proposed in #1542. There are some differences in the implementation, but I believe there will be opportunities for collaboration to jointly improve the entire mechanism.

@yuanfz98 commented:

Hello,

In PtrAnalysis we encountered a case where an arith.select defines an operand of tt.addptr, and PtrAnalysis::visitOperand fails to match it. This case can produce non-contiguous memory accesses that memref/linalg cannot capture. Any insight from you would be appreciated. Thanks.

    %18:2 = scf.for %arg8 = %c0_i32 to %c8192_i32 step %c2048_i32 iter_args(%arg9 = %cst_7, %arg10 = %cst_6) -> (tensor<1x2048xf32>, tensor<1x2048xi64>)  : i32 {
      %27 = tt.splat %arg8 : (i32) -> tensor<1x2048xi32>
      %28 = arith.addi %27, %7 : tensor<1x2048xi32>
      %29 = arith.cmpi slt, %28, %cst_5 : tensor<1x2048xi32>
      %30 = arith.addi %28, %9 : tensor<1x2048xi32>
      %31 = tt.addptr %10, %30 : tensor<1x2048x!tt.ptr<i64>>, tensor<1x2048xi32>
      %32 = arith.andi %29, %11 : tensor<1x2048xi1>
      %33 = tt.load %31, %32, %cst_6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1x2048xi64>
      %34 = tt.addptr %12, %30 : tensor<1x2048x!tt.ptr<f32>>, tensor<1x2048xi32>
      %35 = tt.load %34, %32, %cst_7 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1x2048xf32>
      %36 = tt.addptr %13, %30 : tensor<1x2048x!tt.ptr<f32>>, tensor<1x2048xi32>
      %37 = tt.load %36, %32, %cst_7 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1x2048xf32>
      %38 = arith.cmpi ne, %33, %cst_3 : tensor<1x2048xi64>
      %39 = arith.select %38, %33, %cst_6 : tensor<1x2048xi1>, tensor<1x2048xi64>
      %40 = arith.cmpi sge, %39, %cst_6 : tensor<1x2048xi64>
      %41 = arith.cmpi slt, %39, %cst_0 : tensor<1x2048xi64>
      %42 = arith.andi %40, %41 : tensor<1x2048xi1>
      tt.assert %42, "index out of bounds: 0 <= tmp4 < 65", "<frozen importlib._bootstrap_external>", "_call_with_frames_removed", 883 : tensor<1x2048xi1>
      %43 = arith.muli %28, %cst_2 : tensor<1x2048xi32>
      %44 = arith.extsi %43 : tensor<1x2048xi32> to tensor<1x2048xi64>
      %45 = arith.addi %39, %44 : tensor<1x2048xi64>
      %46 = arith.addi %45, %16 : tensor<1x2048xi64>
      %47 = tt.addptr %17, %46 : tensor<1x2048x!tt.ptr<bf16>>, tensor<1x2048xi64>
...
%204 = "arith.select"(%200, %95, %19) {MetaUse} : (tensor<1x2048xi1>, tensor<1x2048xi64>, tensor<1x2048xi64>) -> tensor<1x2048xi64>
encountered addptr operand produced by an unsupported operation
UNREACHABLE executed at /wkspc/hongjing/triton/lib/Analysis/PtrAnalysis.cpp:377!

@nhat-nguyen (Collaborator, Author) commented:

@yuanfz98 Sorry for the delayed response. Would you mind sharing the triton program that produces this error? We don't currently support non-contiguous memory loads so this could be the reason why.

@ptillet (Collaborator) commented Sep 21, 2023

After internal discussion, we have decided that it would make more sense for this useful utility to live out-of-tree, where other out-of-tree third-party backends could also use it to build their binaries.

@ptillet closed this Sep 21, 2023
@manbearian (Collaborator) commented:

PR redone as a plug-in: #2374
