Skip to content

Commit 2a4b220

Browse files
Merge pull request #8 from yuehhua/gather
Gather for CUDA support
2 parents 94ea8b3 + c5da628 commit 2a4b220

File tree

5 files changed

+157
-24
lines changed

5 files changed

+157
-24
lines changed

Manifest.toml

+42-24
Original file line numberDiff line numberDiff line change
@@ -33,22 +33,22 @@ uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
3333
version = "0.4.1"
3434

3535
[[CUDA]]
36-
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "DataStructures", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "MacroTools", "Memoize", "Printf", "Random", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "TimerOutputs"]
37-
git-tree-sha1 = "d4fa6486e94c4087f1d081d7be2d501a170bd51d"
36+
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "DataStructures", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "MacroTools", "Memoize", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"]
37+
git-tree-sha1 = "364179416eabc34c9ca32126a6bdb431680c3bad"
3838
uuid = "052768ef-5323-5732-b1bb-66c8b64840ba"
39-
version = "3.1.0"
39+
version = "3.2.1"
4040

4141
[[ChainRulesCore]]
4242
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
43-
git-tree-sha1 = "a66109c73612c63b10923ac446fddb0f0d21a593"
43+
git-tree-sha1 = "b391f22252b8754f4440de1f37ece49d8a7314bb"
4444
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
45-
version = "0.9.40"
45+
version = "0.9.44"
4646

4747
[[Compat]]
4848
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
49-
git-tree-sha1 = "ac4132ad78082518ec2037ae5770b6e796f7f956"
49+
git-tree-sha1 = "e4e2b39db08f967cc1360951f01e8a75ec441cab"
5050
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
51-
version = "3.27.0"
51+
version = "3.30.0"
5252

5353
[[CompilerSupportLibraries_jll]]
5454
deps = ["Artifacts", "Libdl"]
@@ -72,6 +72,12 @@ uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
7272
deps = ["Random", "Serialization", "Sockets"]
7373
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
7474

75+
[[DocStringExtensions]]
76+
deps = ["LibGit2", "Markdown", "Pkg", "Test"]
77+
git-tree-sha1 = "9d4f64f79012636741cf01133158a54b24924c32"
78+
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
79+
version = "0.8.4"
80+
7581
[[Downloads]]
7682
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
7783
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
@@ -82,16 +88,16 @@ uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
8288
version = "0.1.3"
8389

8490
[[GPUArrays]]
85-
deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization"]
86-
git-tree-sha1 = "9c95b2fd5c16bc7f97371e9f92f0fef77e0f5957"
91+
deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization", "Statistics"]
92+
git-tree-sha1 = "df5b8569904c5c10e84c640984cfff054b18c086"
8793
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
88-
version = "6.2.2"
94+
version = "6.4.1"
8995

9096
[[GPUCompiler]]
9197
deps = ["DataStructures", "ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "Serialization", "TimerOutputs", "UUIDs"]
92-
git-tree-sha1 = "6eadd2321dc3ac0fc9d530ab01c2caa7fe5d74c6"
98+
git-tree-sha1 = "42d635f6d87af125b86288df3819f805fb4d851a"
9399
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
94-
version = "0.11.4"
100+
version = "0.11.5"
95101

96102
[[InteractiveUtils]]
97103
deps = ["Markdown"]
@@ -105,9 +111,9 @@ version = "1.3.0"
105111

106112
[[LLVM]]
107113
deps = ["CEnum", "Libdl", "Printf", "Unicode"]
108-
git-tree-sha1 = "b616937c31337576360cb9fb872ec7633af7b194"
114+
git-tree-sha1 = "b499c68a45249b0385585c62f4a9b62b5db8e691"
109115
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
110-
version = "3.6.0"
116+
version = "3.7.1"
111117

112118
[[LazyArtifacts]]
113119
deps = ["Artifacts", "Pkg"]
@@ -136,6 +142,12 @@ uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
136142
deps = ["Libdl"]
137143
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
138144

145+
[[LogExpFunctions]]
146+
deps = ["DocStringExtensions", "LinearAlgebra"]
147+
git-tree-sha1 = "1ba664552f1ef15325e68dc4c05c3ef8c2d5d885"
148+
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
149+
version = "0.2.4"
150+
139151
[[Logging]]
140152
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
141153

@@ -181,19 +193,19 @@ uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
181193
version = "0.5.4+0"
182194

183195
[[OrderedCollections]]
184-
git-tree-sha1 = "4fa2ba51070ec13fcc7517db714445b4ab986bdf"
196+
git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c"
185197
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
186-
version = "1.4.0"
198+
version = "1.4.1"
187199

188200
[[Pkg]]
189201
deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
190202
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
191203

192204
[[Preferences]]
193205
deps = ["TOML"]
194-
git-tree-sha1 = "ea79e4c9077208cd3bc5d29631a26bc0cff78902"
206+
git-tree-sha1 = "00cfd92944ca9c760982747e9a1d0d5d86ab1e5a"
195207
uuid = "21216c6a-2e73-6563-6e65-726566657250"
196-
version = "1.2.1"
208+
version = "1.2.2"
197209

198210
[[Printf]]
199211
deps = ["Unicode"]
@@ -207,6 +219,12 @@ uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
207219
deps = ["Serialization"]
208220
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
209221

222+
[[Random123]]
223+
deps = ["Libdl", "Random", "RandomNumbers"]
224+
git-tree-sha1 = "7c6710c8198fd4444b5eb6a3840b7d47bd3593c5"
225+
uuid = "74087812-796a-5b5d-8853-05524746bad3"
226+
version = "1.3.1"
227+
210228
[[RandomNumbers]]
211229
deps = ["Random", "Requires"]
212230
git-tree-sha1 = "441e6fc35597524ada7f85e13df1f4e10137d16f"
@@ -248,10 +266,10 @@ deps = ["LinearAlgebra", "Random"]
248266
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
249267

250268
[[SpecialFunctions]]
251-
deps = ["ChainRulesCore", "OpenSpecFun_jll"]
252-
git-tree-sha1 = "5919936c0e92cff40e57d0ddf0ceb667d42e5902"
269+
deps = ["ChainRulesCore", "LogExpFunctions", "OpenSpecFun_jll"]
270+
git-tree-sha1 = "c467f25b6ec4167ea3a9a4351c66c2e1cba5da33"
253271
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
254-
version = "1.3.0"
272+
version = "1.4.1"
255273

256274
[[Statistics]]
257275
deps = ["LinearAlgebra", "SparseArrays"]
@@ -270,10 +288,10 @@ deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
270288
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
271289

272290
[[TimerOutputs]]
273-
deps = ["Printf"]
274-
git-tree-sha1 = "32cdbe6cd2d214c25a0b88f985c9e0092877c236"
291+
deps = ["ExprTools", "Printf"]
292+
git-tree-sha1 = "bf8aacc899a1bd16522d0350e1e2310510d77236"
275293
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
276-
version = "0.5.8"
294+
version = "0.5.9"
277295

278296
[[UUIDs]]
279297
deps = ["Random", "SHA"]

src/NNlibCUDA.jl

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ include("upsample.jl")
1010
include("activations.jl")
1111
include("batchedmul.jl")
1212
include("scatter.jl")
13+
include("gather.jl")
1314
include("cudnn/cudnn.jl")
1415
include("cudnn/conv.jl")
1516
include("cudnn/pooling.jl")

src/gather.jl

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
"""
    gather_check_dims(X, Y, idx)

Validate the shapes of `X`, `Y` and `idx` for a gather whose indices have an
element type `Tidx <: IntOrIntTuple`, and return the number of leading
dimensions `dims` that `X` and `Y` must share.

`dims` comes from the rank-only method
`gather_check_dims(Nx, Ny, M, Nidx)`, with `M = NNlib.typelength(Tidx)`
(presumably the number of components in one index; confirm against NNlib).

# Throws
- `ArgumentError` if the first `dims` sizes of `X` and `Y` differ, or if the
  remaining sizes of `Y` do not equal `size(idx)`.
"""
function gather_check_dims(X::AbstractArray{Tx,Nx},
                           Y::AbstractArray{Ty,Ny},
                           idx::AbstractArray{Tidx,Nidx}) where
                           {Tx,Ty,Tidx<:IntOrIntTuple,Nx,Ny,Nidx}
    dims = gather_check_dims(Nx, Ny, NNlib.typelength(Tidx), Nidx)
    x_size, y_size = size(X), size(Y)
    # The leading `dims` dimensions must match between the two arrays.
    if x_size[1:dims] != y_size[1:dims]
        throw(ArgumentError("Incompatible input shapes."))
    end
    # The trailing dimensions of `Y` must have exactly the shape of `idx`.
    if y_size[dims+1:end] != size(idx)
        throw(ArgumentError("Incompatible input shapes."))
    end
    return dims
end
11+
12+
"""
    gather_check_dims(X, Y, idx::AbstractArray{CartesianIndex{M}})

Validate the shapes of `X`, `Y` and `idx` for a gather whose indices are
`CartesianIndex{M}` values, and return the number of leading dimensions
`dims` that `X` and `Y` must share.

`dims` comes from the rank-only method `gather_check_dims(Nx, Ny, M, Nidx)`.

# Throws
- `ArgumentError` if the first `dims` sizes of `X` and `Y` differ, or if the
  remaining sizes of `Y` do not equal `size(idx)`.
"""
function gather_check_dims(X::AbstractArray{Tx,Nx},
                           Y::AbstractArray{Ty,Ny},
                           idx::AbstractArray{CartesianIndex{M},Nidx}) where
                           {Tx,Ty,Nx,Ny,M,Nidx}
    dims = gather_check_dims(Nx, Ny, M, Nidx)

    # The leading `dims` dimensions must match between the two arrays.
    leading_ok = size(X)[1:dims] == size(Y)[1:dims]
    leading_ok || throw(ArgumentError("Incompatible input shapes."))

    # The trailing dimensions of `Y` must have exactly the shape of `idx`.
    trailing_ok = size(Y)[dims+1:end] == size(idx)
    trailing_ok || throw(ArgumentError("Incompatible input shapes."))

    return dims
end
21+
22+
"""
    gather_check_dims(Nx, Ny, M, Nidx)

Rank-only consistency check shared by the array methods of
`gather_check_dims`. Return `dims = Nx - M`.

# Arguments
- `Nx`: number of dimensions of the first array (the gather source in `gather!`).
- `Ny`: number of dimensions of the second array (the gather destination in `gather!`).
- `M`: length of a single index.
- `Nidx`: number of dimensions of the index array.

# Throws
- `ArgumentError` if `Nx - M != Ny - Nidx`, or if `Nx - M` is negative.
"""
function gather_check_dims(Nx, Ny, M, Nidx)
    dims = Nx - M
    # Validate with an explicit exception rather than `@assert`: assertions may
    # be disabled, and the other shape checks already throw `ArgumentError`.
    if dims != Ny - Nidx
        throw(ArgumentError("Incompatible input shapes of (src, dst, idx) = " *
                            "($Nx, $Ny, $Nidx)."))
    end
    dims < 0 && throw(ArgumentError("dims must be non-negative but got dims=$dims."))
    return dims
end
28+
29+
# Device kernel for `gather!`: each thread with a linear id `tid <= max_idx`
# writes one element `dst[tid]`.
#
# `tid - 1` is split by `divrem` with `max_dims_idx` into a quotient, which
# selects the entry of `idx` (by linear index), and a remainder, which selects
# a position inside `CartesianIndices(dims_size)`. The source element is read
# at that leading position followed by the splatted index entry.
function gather_kernel!(dst, src, idx, max_idx, max_dims_idx, dims_size)
    tid = (blockIdx().x - 1) * blockDim().x + threadIdx().x

    # Threads past the end of the work range have nothing to do.
    tid > max_idx && return nothing

    @inbounds begin
        idx_offset, lead_offset = divrem(tid - 1, max_dims_idx)
        lead = CartesianIndices(dims_size)[lead_offset + 1]
        dst[tid] = src[lead, idx[idx_offset + 1]...]
    end
    return nothing
end
39+
40+
"""
    NNlib.gather!(dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray)

CUDA implementation of `NNlib.gather!`: fill `dst` with the elements of `src`
selected by `idx` and return `dst`.

Shapes are validated with `gather_check_dims(src, dst, idx)`, which throws an
`ArgumentError` on incompatible inputs. One GPU thread is launched per element
to write (`prod(size(src)[1:dims]) * length(idx)` in total). When that count is
zero, `dst` is returned without launching a kernel.
"""
function NNlib.gather!(dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray)
    dims = gather_check_dims(src, dst, idx)
    dims_size = size(src)[1:dims]
    max_dims_idx = prod(dims_size)
    max_idx = max_dims_idx * length(idx)

    # Nothing to gather. Returning here also avoids `threads == 0` below, which
    # would make `cld(max_idx, threads)` throw a `DivideError`.
    max_idx == 0 && return dst

    args = dst, src, idx, max_idx, max_dims_idx, dims_size

    # Compile without launching so the launch configuration can be queried.
    kernel = @cuda launch=false gather_kernel!(args...)
    config = launch_configuration(kernel.fun; max_threads=256)
    threads = min(max_idx, config.threads)
    blocks = cld(max_idx, threads)
    kernel(args...; threads=threads, blocks=blocks)
    return dst
end

test/gather.jl

+60
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# GPU tests for `NNlib.gather` / `NNlib.gather!` with integer index arrays.
# Each section checks the type, shape and values of the result.
@testset "gather" begin
    T = Float32
    CT = CuArray{Float32}

    ## 1d src, 2d index of ints -> 2d output
    src = CT([3, 4, 5, 6, 7])
    index = cu([1 2 3 4;
                4 2 1 3;
                3 5 5 3])
    output = CT([3 4 5 6;
                 6 4 3 5;
                 5 7 7 5])

    y = NNlib.gather(src, index)
    @test y isa CuArray{Float32,2}
    @test size(y) == size(index)
    @test y == output
    gputest(src -> NNlib.gather(src, index), src, checkgrad=false)
    @test NNlib.gather!(CUDA.zeros(T, size(index)...), src, index) == output
    @test_throws ArgumentError NNlib.gather!(zeros(T, 3, 5), src, index)
    # Same shape mismatch with a device destination, so that the shape checks
    # of the `AnyCuArray` method are exercised as well.
    @test_throws ArgumentError NNlib.gather!(CUDA.zeros(T, 3, 5), src, index)

    ## 1d src, 3d index of ints -> 3d output
    src = CT([3, 4, 5, 6, 7])
    index = cu([1 2 3 4;
                4 2 1 3;
                3 5 5 3][:,:,1:1])
    output = CT([3 4 5 6;
                 6 4 3 5;
                 5 7 7 5][:,:,1:1])

    y = NNlib.gather(src, index)
    @test y isa CuArray{Float32,3}
    @test size(y) == size(index)
    @test y == output
    gputest(src -> NNlib.gather(src, index), src, checkgrad=false)


    ## 2d src, 2d index of ints -> 3d output
    src = CT([3 5 7
              4 6 8])
    index = cu([1 2 3;
                2 2 1;
                3 1 3])

    # Expected result, built on the host: slice `[:, :, j]` holds the columns
    # of `src` selected by column `j` of `index`.
    output = zeros(T, 2, 3, 3)

    output[:,:,1] = [3 5 7
                     4 6 8]

    output[:,:,2] = [5 5 3
                     6 6 4]

    output[:,:,3] = [7 3 7
                     8 4 8]

    y = NNlib.gather(src, index)
    M = NNlib.typelength(eltype(index))
    Nsrc = ndims(src)
    @test y isa CuArray{Float32,3}
    @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)
    # `output` lives on the host, so bring `y` back before comparing.
    @test collect(y) == output
    gputest(src -> NNlib.gather(src, index), src, checkgrad=false)
end

test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,5 @@ if CUDA.has_cuda()
1818
include("softmax.jl")
1919
include("batchnorm.jl")
2020
include("scatter.jl")
21+
include("gather.jl")
2122
end

0 commit comments

Comments
 (0)