
Add matmul API #629

Draft · thomasfaingnaert wants to merge 35 commits into master

Conversation

thomasfaingnaert
Member

This PR contains an initial implementation of (my proposal for) an API to instantiate flexible matrix multiplication kernels.
It is divided into two large parts:

  • A Tiling API that aims to make recursively subdividing matrices (or tensors in general) easier (src/device/tiling.jl); a small sketch of the idea follows this list
  • The API for matrix multiplication itself, which uses the tiling API (src/device/matmul_kernels*)
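
To make the tiling idea concrete, here is a minimal, self-contained sketch of recursive subdivision in logical coordinates. The names Tile and subdivide mirror the spirit of src/device/tiling.jl but are illustrative, not necessarily the exact API:

```julia
# Illustrative sketch of recursive tiling in logical coordinates;
# `Tile` and `subdivide` are hypothetical names, not necessarily the
# exact API in src/device/tiling.jl.
struct Tile
    size::NamedTuple    # logical size, e.g. (M = 128, N = 128)
    offset::NamedTuple  # logical offset of this tile within its parent
end

Tile(; size...) = Tile(values(size), map(_ -> 0, values(size)))

# Partition `parent` into subtiles of logical size `sub`, and return the
# subtile with linear index `idx` (e.g. a warp or thread index, 1-based).
function subdivide(parent::Tile, sub::NamedTuple, idx::Integer)
    count  = map(÷, parent.size, sub)                    # subtiles per dimension
    coord  = Tuple(CartesianIndices(values(count))[idx]) # linear -> N-dimensional
    offset = NamedTuple{keys(sub)}(map((o, s, c) -> o + s * (c - 1),
                                       Tuple(parent.offset), Tuple(sub), coord))
    return Tile(sub, offset)
end

# Example: distribute a 128×128 CTA tile over 8 warps, 32×64 each.
cta_tile  = Tile(M = 128, N = 128)
warp_tile = subdivide(cta_tile, (M = 32, N = 64), 3)   # tile for warp 3
# warp_tile.offset == (M = 64, N = 0)
```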

The matmul API itself consists of several components, which allow the user to customise the behaviour of the GEMM:

  • config.jl: This file defines the Config type that allows the user to customise the parameters of the matmul. A helper function get_config makes creating this Config easy, and additionally includes some heuristics to set default values for parameters the user does not specify. Note that the tile sizes are specified in a "logical" coordinate space, i.e. their precise meaning is up to the user. A usage sketch follows this list.
  • layout.jl: Layouts determine how logical coordinates are converted to physical offsets in memory. Each matrix (A, B, C, D) can have a different layout in both global and shared memory (a minimal layout sketch also follows this list).
  • transform.jl: Transforms are applied after every load and before every store. They are essentially functors that are baked into the memory stream from global to shared memory and from shared memory to registers (and vice versa for D, obviously). The most obvious use case here is elementwise transforms, be it scaling or activation functions in neural nets (see the transform sketch below).
  • operator.jl: Operators define the computation performed in the inner loop of the GEMM, and how it is performed.
  • epilogue.jl: Epilogues define what happens at the last step of the GEMM. At that point, each CTA has a tile of the resultant matrix in shared memory. The default epilogue just stores this tile to global memory, but other epilogues may perform more complex operations, such as reductions across thread blocks.
  • kernel.jl: The implementation of the matrix multiplication kernel itself. It uses the abstractions described above, and the Tiling API.
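
As a rough end-to-end sketch of what instantiating a mixed-precision WMMA GEMM through this API could look like. The module paths and keyword names (MatMul, Operator.WMMAOp, Layout.AlignedColMajor, gemm_shape, ...) are illustrative assumptions, not necessarily the identifiers in this PR; consult config.jl and friends for the real ones:

```julia
# Hypothetical usage sketch; assumes the matmul API's modules (here
# called MatMul, Layout and Operator) are in scope.
using CuArrays, CUDAnative

M = N = K = 2048
a = CuArray(rand(Float16, M, K))
b = CuArray(rand(Float16, K, N))
c = CuArray(zeros(Float32, M, N))
d = similar(c)

# get_config fills in any parameters the user does not specify (tile
# sizes, number of warps, ...) using built-in heuristics.
conf = MatMul.get_config(
    gemm_shape = (M = M, N = N, K = K),
    operator   = Operator.WMMAOp{16, 16, 16},          # 16×16×16 WMMA fragments
    global_a_layout = Layout.AlignedColMajor{Float16},
    global_c_layout = Layout.AlignedColMajor{Float32},
)

MatMul.matmul(a, b, c, d, conf)                        # computes D = A * B + C
```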
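For layouts, the essential contract is a mapping from logical coordinates to physical offsets. A minimal, hypothetical example (these types are illustrations, not the ones in layout.jl):

```julia
# Hypothetical layout: map a logical (i, j) coordinate to a linear
# offset for a column-major matrix with leading dimension `ld`.
struct ColMajorLayout
    ld::Int
end

linear_offset(l::ColMajorLayout, i, j) = (j - 1) * l.ld + i

# A padded variant, as one might use in shared memory to avoid bank
# conflicts: identical logical coordinates, different physical offsets.
struct PaddedColMajorLayout
    ld::Int
    pad::Int
end

linear_offset(l::PaddedColMajorLayout, i, j) = (j - 1) * (l.ld + l.pad) + i
```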
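Similarly, since transforms are plain functors baked into the loads and stores, an elementwise scaling transform could look like the following (the struct and the keyword name in the comment are hypothetical):

```julia
# Hypothetical elementwise transform: scales every element that flows
# through it, e.g. on the way from shared memory to registers.
struct ElementwiseScale{T}
    alpha::T
end

# Broadcasting over `x` so the transform works on scalars as well as on
# small per-thread fragments.
(t::ElementwiseScale)(x) = t.alpha .* x

# e.g. passed via the config (keyword name illustrative):
#   get_config(...; transform_shared_to_regs_a = ElementwiseScale(Float16(2)))
```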

At the moment, only the components needed for a mixed-precision GEMM using WMMA are implemented (about one or two components per abstraction).
For M = N = K = 2048, the Julia implementation takes about 536 µs, compared to cuBLAS's 440 µs (turing_s1688gemm_fp16_128x256_ldg8_nn), i.e. about 82% of cuBLAS's performance.

As a final note, I have mainly been testing this on Julia v1.5.0-DEV-324 (LLVM 9.0.1).
While the matmul still works on Julia 1.4.1 (LLVM 8.0.1), I've noticed a performance regression, which seems to be mainly caused by @unroll-annotated for loops not actually being unrolled.

@thomasfaingnaert marked this pull request as draft on April 25, 2020, 17:20