Skip to content

Add MIOpen #320

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
merged 25 commits into from
Nov 14, 2022
Merged

Add MIOpen #320

merged 25 commits into from
Nov 14, 2022

Conversation

pxl-th
Copy link
Member

@pxl-th pxl-th commented Nov 2, 2022

This PR adds initial support for MIOpen.
Mainly it will target only convolutions for now.

  • Add auto-generated MIOpen wrapper.
  • Add tensor/convolution descriptors.
  • Implement high-level convolution functions.
  • Implement benchmark caching.
  • Add MIOpen tests.

Example of convolution API:

x = AMDGPU.ones(Float32, 3, 3, 1, 1)
w = AMDGPU.ones(Float32, 3, 3, 1, 1)
y = MIOpen.convolution(x, w; padding=(0, 0), stride=(1, 1), dilation=(1, 1), groups=1)

TODO

  • Complete convolution wrappers.
  • Add more tests.
  • Update docs.
  • Investigate the source of errors when following env variables are not set: MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_FWD=0, MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_BWD=0, MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_WRW=0.
    Potential fix is to rebuild HIP (hipamd) with PCH on.
    UPD: Should be fixed with this PR: Disable PCH completely for MIOpen JuliaPackaging/Yggdrasil#5797
  • Update CI.

Code to generate MIOpen wrapper using Clang.jl

using Clang.Generators
using MIOpen_jll

include_dir = normpath(MIOpen_jll.artifact_dir, "include")
miopen_dir = joinpath(include_dir, "miopen")
options = load_options("generator.toml")

args = get_default_args()
push!(args, "-I$include_dir")

headers = [
    joinpath(miopen_dir, header)
    for header in readdir(miopen_dir)
    if endswith(header, ".h")
]

ctx = create_context(headers, args, options)
build!(ctx)
[general]
module_name = "libMIOpen"
library_name = "libMIOpen_path"
output_file_path = "./libMIOpen.jl"
jll_pkg_name = "MIOpen_jll"
export_symbol_prefixes = []

- DEBUG env variables no longer seems to be necessary... :/
- Update test
@pxl-th
Copy link
Member Author

pxl-th commented Nov 3, 2022

At this moment it supports:

  • 2D & 3D regular and grouped (depthwise) convolutions
  • supported types: Float16 & Float32

For other types I get errors that look like either misconfigured MIOpen_jll or bug in MIOpen itself, not sure :/
I suspect both, maybe ROCm 5.3 - 5.4 will have more things working...

@pxl-th
Copy link
Member Author

pxl-th commented Nov 3, 2022

@jpsamaroo not sure what to do about CI.
It fails because MIOpen is not present in Manifest.toml (for buildkite) and because it was built only for Julia 1.9 (for GitHub CI).

To fix buildkite CI we can delete Manifest.toml.
And for GitHub CI we can upgrade it to 1.9...
Although that leaves out Julia 1.7...

@pxl-th pxl-th marked this pull request as ready for review November 3, 2022 20:39
@pxl-th pxl-th requested a review from jpsamaroo November 4, 2022 09:47
Copy link
Member

@jpsamaroo jpsamaroo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome work! My major requests are:

  • Implementing locking around global data access
  • Removing get_ prefixes

- Use NNlib for convolution comparisons.
- Add LockedObject that manages resource's lock state.
- Add MIOpen compat.
- Remove 'get_' prefixes.
@ToucheSir
Copy link

ToucheSir commented Nov 8, 2022

In case you haven't seen it already, the CUDA.CUDNN readme may be useful for discussions on function naming. That and the test code in the subpackage.

- Move LockedObject from MIOpen to AMDGPU so that other stuff can use
  it.
@pxl-th
Copy link
Member Author

pxl-th commented Nov 9, 2022

In case you haven't seen it already, the CUDA.CUDNN readme may be useful for discussions on function naming. That and the test code in the subpackage.

No I haven't seen that, thanks.
For the high-level interface I was actually aiming at being more similar to NNlib's API.
So that it is easy to use it even without NNlib.

@pxl-th
Copy link
Member Author

pxl-th commented Nov 9, 2022

Regarding the error in CI:

From worker 4:	MIOpen(HIP): Error [Do] 'amd_comgr_do_action(kind, handle, in.GetHandle(), out.GetHandle())' AMD_COMGR_ACTION_COMPILE_SOURCE_TO_BC: ERROR (1)
--
  | From worker 4:	MIOpen(HIP): Error [BuildHip] comgr status = ERROR (1)
  | From worker 4:	MIOpen(HIP): Warning [BuildHip] <built-in>:1:10: fatal error: '__clang_hip_runtime_wrapper.h' file not found
  | From worker 4:	#include "__clang_hip_runtime_wrapper.h"
  | From worker 4:	         ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  | From worker 4:	1 error generated when compiling for gfx90a.

I dont have it on gfx1031.
My guess is that it could be a misconfigured MIOpen, possible solution to try: cupy/cupy#5592 (comment)
Or a bug in MIOpen itself...
I plan to update ROCm stack to either 5.3.x or 5.4 in the nearest future as it contains several improtant improvements, which might help with things.

Copy link
Member

@jpsamaroo jpsamaroo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome!

@pxl-th pxl-th merged commit 13b4c1b into master Nov 14, 2022
@pxl-th pxl-th deleted the pxl-th/add-miopen-wrapper branch November 14, 2022 14:47
# for free to join this conversation on GitHub. Already have an account? # to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants