This repository was archived by the owner on Nov 4, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathMLDataDevicesAMDGPUExt.jl
97 lines (81 loc) · 3.13 KB
/
MLDataDevicesAMDGPUExt.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
module MLDataDevicesAMDGPUExt
using Adapt: Adapt
using AMDGPU: AMDGPU
using MLDataDevices: MLDataDevices, Internal, AMDGPUDevice, CPUDevice, reset_gpu_device!
using Random: Random
# Runs when this extension module is loaded and delegates to MLDataDevices'
# `reset_gpu_device!` (presumably so a previously cached GPU selection is recomputed
# now that AMDGPU is available -- confirm against MLDataDevices).
function __init__()
    return reset_gpu_device!()
end
# This code used to be in `LuxAMDGPU.jl`, but we no longer need that package.
# Cached answer to "is AMDGPU usable?": holds `nothing` until `check_use_amdgpu!` has
# run, and afterwards the value returned by `AMDGPU.functional()`.
const USE_AMD_GPU = Ref{Union{Nothing, Bool}}(nothing)
# Fill the `USE_AMD_GPU` cache on first use; always returns `nothing`.
#
# The first call stores the result of `AMDGPU.functional()`. Later calls return
# immediately. When AMDGPU is functional but `AMDGPU.functional(:MIOpen)` is not, a
# warning is logged (`maxlog=1` limits it to a single record).
function check_use_amdgpu!()
    # Already probed: keep the cached answer.
    isnothing(USE_AMD_GPU[]) || return nothing
    is_functional = AMDGPU.functional()
    USE_AMD_GPU[] = is_functional
    if is_functional && !AMDGPU.functional(:MIOpen)
        @warn "MIOpen is not functional in AMDGPU.jl, some functionality will not be \
               available." maxlog=1
    end
    return nothing
end
# Always `true`, for both a device instance and the device type: this method only
# exists once this module (which imports AMDGPU) has been loaded.
function MLDataDevices.loaded(::Union{AMDGPUDevice, <:Type{AMDGPUDevice}})
    return true
end
# Report whether AMDGPU is usable, for either a device instance or the device type.
# `check_use_amdgpu!` performs the probe at most once; the cached value is returned
# and converted to `Bool` by the return-type annotation.
function MLDataDevices.functional(::Union{AMDGPUDevice, <:Type{AMDGPUDevice}})::Bool
    check_use_amdgpu!()
    cached = USE_AMD_GPU[]
    return cached
end
# No explicit device id requested: build an `AMDGPUDevice` that wraps `nothing`.
function Internal.with_device(::Type{AMDGPUDevice}, ::Nothing)
    return AMDGPUDevice(nothing)
end
# Build an `AMDGPUDevice` for entry `id` of `AMDGPU.devices()` while leaving the
# caller's current device unchanged on exit.
#
# Arguments:
#   - `id`: index into `AMDGPU.devices()`; must lie in `1:length(AMDGPU.devices())`.
#
# Throws `ArgumentError` when `id` is outside that range (previously `id < 1` surfaced
# as a `BoundsError` from the indexing below).
function Internal.with_device(::Type{AMDGPUDevice}, id::Integer)
    devices = AMDGPU.devices()
    id < 1 && throw(ArgumentError("id = $id < 1; device ids start at 1"))
    id > length(devices) &&
        throw(ArgumentError("id = $id > length(AMDGPU.devices()) = $(length(devices))"))
    old_dev = AMDGPU.device()  # remember the current device
    try
        AMDGPU.device!(devices[id])
        return AMDGPUDevice(AMDGPU.device())
    finally
        # Restore the caller's device even if switching or construction throws.
        AMDGPU.device!(old_dev)
    end
end
# Id of the wrapped device, as reported by `AMDGPU.device_id`.
function Internal.get_device_id(dev::AMDGPUDevice)
    return AMDGPU.device_id(dev.device)
end
# Default RNG
# The default RNG for an AMDGPU device is whatever `AMDGPU.rocrand_rng()` returns.
function MLDataDevices.default_device_rng(::AMDGPUDevice)
    return AMDGPU.rocrand_rng()
end
# Query Device from Array
# If `x` is its own `parent`, wrap the device reported by `AMDGPU.device(x)`;
# otherwise defer to `Internal.get_device` on the parent array.
function Internal.get_device(x::AMDGPU.AnyROCArray)
    inner = parent(x)
    if inner === x
        return AMDGPUDevice(AMDGPU.device(x))
    else
        return Internal.get_device(inner)
    end
end
# A rocRAND RNG is attributed to whichever device `AMDGPU.device()` reports at the
# time of the query.
function Internal.get_device(::AMDGPU.rocRAND.RNG)
    return AMDGPUDevice(AMDGPU.device())
end
# Both `AMDGPU.AnyROCArray` values and rocRAND RNGs map to the `AMDGPUDevice` type.
function Internal.get_device_type(::AMDGPU.AnyROCArray)
    return AMDGPUDevice
end
function Internal.get_device_type(::AMDGPU.rocRAND.RNG)
    return AMDGPUDevice
end
# Set Device
# Make the given HIP device current; the result of `AMDGPU.device!` is passed through.
MLDataDevices.set_device!(::Type{AMDGPUDevice}, dev::AMDGPU.HIPDevice) =
    AMDGPU.device!(dev)
# Select the device at position `id` in `AMDGPU.devices()` and make it current via the
# `HIPDevice` method of `set_device!`. An out-of-range `id` fails at the indexing step.
function MLDataDevices.set_device!(::Type{AMDGPUDevice}, id::Integer)
    selected = AMDGPU.devices()[id]
    return MLDataDevices.set_device!(AMDGPUDevice, selected)
end
# Choose a device from `rank` when no explicit id is given: `rank + 1` is wrapped into
# `1:length(AMDGPU.devices())` with `mod1`, so rank 0 maps to device 1 and ranks beyond
# the device count cycle back around.
function MLDataDevices.set_device!(::Type{AMDGPUDevice}, ::Nothing, rank::Integer)
    n_devices = length(AMDGPU.devices())
    return MLDataDevices.set_device!(AMDGPUDevice, mod1(rank + 1, n_devices))
end
# unsafe_free!
# Hand `x` to `AMDGPU.unsafe_free!`; this method itself always returns `nothing`.
function Internal.unsafe_free_internal!(::Type{AMDGPUDevice}, x::AbstractArray)
    AMDGPU.unsafe_free!(x)
    return nothing
end
# Device Transfer
# Target device left unspecified (`AMDGPUDevice{Nothing}`): convert with `AMDGPU.roc`
# without switching devices.
function Adapt.adapt_storage(::AMDGPUDevice{Nothing}, x::AbstractArray)
    return AMDGPU.roc(x)
end
# Move `x` to the specific device wrapped by `to`.
#
#   - `x` already on a device with the same `AMDGPU.device_id` as `to.device`:
#     returned as-is (no copy).
#   - `x` on a different AMDGPU device: `copy(x)` is made while `to.device` is current.
#   - `x` not on an AMDGPU device: converted with `AMDGPU.roc` while `to.device` is
#     current.
#
# The device that was current on entry is restored before returning, including when
# the transfer throws, so a failed transfer no longer leaves the caller on `to.device`.
function Adapt.adapt_storage(to::AMDGPUDevice, x::AbstractArray)
    old_dev = AMDGPU.device()  # remember the current device
    dev = MLDataDevices.get_device(x)
    on_amdgpu = dev isa AMDGPUDevice
    if on_amdgpu && AMDGPU.device_id(dev.device) == AMDGPU.device_id(to.device)
        return x
    end
    try
        AMDGPU.device!(to.device)
        return on_amdgpu ? copy(x) : AMDGPU.roc(x)
    finally
        # Always switch back, even if the allocation/copy above failed.
        AMDGPU.device!(old_dev)
    end
end
# Moving a rocRAND RNG to the CPU yields `Random.default_rng()`; `rng` itself is not
# consulted.
function Adapt.adapt_storage(::CPUDevice, rng::AMDGPU.rocRAND.RNG)
    return Random.default_rng()
end
end