Skip to content

Commit c3e1331

Browse files
Merge pull request #13 from yuehhua/scatter
Support scatter for CUDA gradient
2 parents ba9a1c0 + 60585e2 commit c3e1331

File tree

6 files changed

+167
-100
lines changed

6 files changed

+167
-100
lines changed

Manifest.toml

+23-41
Original file line numberDiff line numberDiff line change
@@ -33,22 +33,22 @@ uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
3333
version = "0.4.1"
3434

3535
[[CUDA]]
36-
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "DataStructures", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "MacroTools", "Memoize", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"]
37-
git-tree-sha1 = "364179416eabc34c9ca32126a6bdb431680c3bad"
36+
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "DataStructures", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"]
37+
git-tree-sha1 = "82b2811f5888465d96b38c7bb12d8fb9c25838e1"
3838
uuid = "052768ef-5323-5732-b1bb-66c8b64840ba"
39-
version = "3.2.1"
39+
version = "3.3.1"
4040

4141
[[ChainRulesCore]]
4242
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
43-
git-tree-sha1 = "04dd5ce9f9d7b9b14559b00a7eb5be7528f56b82"
43+
git-tree-sha1 = "be770c08881f7bb928dfd86d1ba83798f76cf62a"
4444
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
45-
version = "0.10.2"
45+
version = "0.10.9"
4646

4747
[[Compat]]
4848
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
49-
git-tree-sha1 = "e4e2b39db08f967cc1360951f01e8a75ec441cab"
49+
git-tree-sha1 = "dc7dedc2c2aa9faf59a55c622760a25cbefbe941"
5050
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
51-
version = "3.30.0"
51+
version = "3.31.0"
5252

5353
[[CompilerSupportLibraries_jll]]
5454
deps = ["Artifacts", "Libdl"]
@@ -73,10 +73,10 @@ deps = ["Random", "Serialization", "Sockets"]
7373
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
7474

7575
[[DocStringExtensions]]
76-
deps = ["LibGit2", "Markdown", "Pkg", "Test"]
77-
git-tree-sha1 = "9d4f64f79012636741cf01133158a54b24924c32"
76+
deps = ["LibGit2"]
77+
git-tree-sha1 = "a32185f5428d3986f47c2ab78b1f216d5e6cc96f"
7878
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
79-
version = "0.8.4"
79+
version = "0.8.5"
8080

8181
[[Downloads]]
8282
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
@@ -89,15 +89,15 @@ version = "0.1.3"
8989

9090
[[GPUArrays]]
9191
deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization", "Statistics"]
92-
git-tree-sha1 = "df5b8569904c5c10e84c640984cfff054b18c086"
92+
git-tree-sha1 = "ececbf05f8904c92814bdbd0aafd5540b0bf2e9a"
9393
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
94-
version = "6.4.1"
94+
version = "7.0.1"
9595

9696
[[GPUCompiler]]
97-
deps = ["DataStructures", "ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "Serialization", "TimerOutputs", "UUIDs"]
98-
git-tree-sha1 = "42d635f6d87af125b86288df3819f805fb4d851a"
97+
deps = ["DataStructures", "ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "TimerOutputs", "UUIDs"]
98+
git-tree-sha1 = "222c6cdb888ec24795936d6829aa978691def60e"
9999
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
100-
version = "0.11.5"
100+
version = "0.12.3"
101101

102102
[[InteractiveUtils]]
103103
deps = ["Markdown"]
@@ -111,9 +111,9 @@ version = "1.3.0"
111111

112112
[[LLVM]]
113113
deps = ["CEnum", "Libdl", "Printf", "Unicode"]
114-
git-tree-sha1 = "b499c68a45249b0385585c62f4a9b62b5db8e691"
114+
git-tree-sha1 = "f57ac3fd2045b50d3db081663837ac5b4096947e"
115115
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
116-
version = "3.7.1"
116+
version = "3.9.0"
117117

118118
[[LazyArtifacts]]
119119
deps = ["Artifacts", "Pkg"]
@@ -151,12 +151,6 @@ version = "0.2.4"
151151
[[Logging]]
152152
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
153153

154-
[[MacroTools]]
155-
deps = ["Markdown", "Random"]
156-
git-tree-sha1 = "6a8a2a625ab0dea913aba95c11370589e0239ff0"
157-
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
158-
version = "0.5.6"
159-
160154
[[Markdown]]
161155
deps = ["Base64"]
162156
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
@@ -165,12 +159,6 @@ uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
165159
deps = ["Artifacts", "Libdl"]
166160
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
167161

168-
[[Memoize]]
169-
deps = ["MacroTools"]
170-
git-tree-sha1 = "2b1dfcba103de714d31c033b5dacc2e4a12c7caa"
171-
uuid = "c03570c3-d221-55d1-a50c-7939bbd78826"
172-
version = "0.4.4"
173-
174162
[[Mmap]]
175163
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
176164

@@ -179,9 +167,9 @@ uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
179167

180168
[[NNlib]]
181169
deps = ["Adapt", "ChainRulesCore", "Compat", "LinearAlgebra", "Pkg", "Requires", "Statistics"]
182-
git-tree-sha1 = "0bf1fbb9dc557f2af9fb7e1337366d69de0dc78c"
170+
git-tree-sha1 = "7e6f31cfa39b1ff1c541cc8580b14b0ff4ba22d0"
183171
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
184-
version = "0.7.21"
172+
version = "0.7.23"
185173

186174
[[NetworkOptions]]
187175
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
@@ -221,9 +209,9 @@ uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
221209

222210
[[Random123]]
223211
deps = ["Libdl", "Random", "RandomNumbers"]
224-
git-tree-sha1 = "7c6710c8198fd4444b5eb6a3840b7d47bd3593c5"
212+
git-tree-sha1 = "0e8b146557ad1c6deb1367655e052276690e71a3"
225213
uuid = "74087812-796a-5b5d-8853-05524746bad3"
226-
version = "1.3.1"
214+
version = "1.4.2"
227215

228216
[[RandomNumbers]]
229217
deps = ["Random", "Requires"]
@@ -245,12 +233,6 @@ version = "1.1.3"
245233
[[SHA]]
246234
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
247235

248-
[[Scratch]]
249-
deps = ["Dates"]
250-
git-tree-sha1 = "ad4b278adb62d185bbcb6864dc24959ab0627bf6"
251-
uuid = "6c6a2e73-6563-6170-7368-637461726353"
252-
version = "1.0.3"
253-
254236
[[Serialization]]
255237
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
256238

@@ -289,9 +271,9 @@ uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
289271

290272
[[TimerOutputs]]
291273
deps = ["ExprTools", "Printf"]
292-
git-tree-sha1 = "bf8aacc899a1bd16522d0350e1e2310510d77236"
274+
git-tree-sha1 = "9f494bc54b4c31404a9eff449235836615929de1"
293275
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
294-
version = "0.5.9"
276+
version = "0.5.10"
295277

296278
[[UUIDs]]
297279
deps = ["Random", "SHA"]

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1111

1212
[compat]
1313
CUDA = "3.3.1"
14-
NNlib = "0.7.21"
14+
NNlib = "0.7.23"
1515
julia = "1.6"
1616

1717
[extras]

src/NNlibCUDA.jl

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ include("activations.jl")
1111
include("batchedmul.jl")
1212
include("scatter.jl")
1313
include("gather.jl")
14+
include("utils.jl")
1415
include("cudnn/cudnn.jl")
1516
include("cudnn/conv.jl")
1617
include("cudnn/pooling.jl")

src/scatter.jl

+72
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,75 @@ function NNlib.scatter!(op::typeof(mean), dst::AnyCuArray, src::AnyCuArray, idx:
4646
dst .+= NNlib.safe_div.(dst_, Ns)
4747
return dst
4848
end
49+
50+
51+
## Gradients
52+
53+
"""
    ∇scatter_src_kernel!(op, Δsrc, src, idx, rev_idx, max_idx, T)

GPU kernel that updates `Δsrc` in place, one element per thread, for the case
where positions of `idx` map one-to-one onto positions of `src`.

For the `n`-th position `p` of `idx` (with `n <= max_idx`), the kernel looks up
`rev_idx[idx[p]...]`, multiplies together every `src[k]` for `k` in that entry,
divides the product by `src[p]`, and stores `op(Δsrc[p], product)` in `Δsrc[p]`.

# Arguments
- `op`: binary function combining the current `Δsrc` value with the product.
- `Δsrc`: array updated in place.
- `src`: source values.
- `idx`: index array; its entries are splatted to index `rev_idx`.
- `rev_idx`: array whose elements are iterables of positions into `src`.
- `max_idx`: number of positions to process; threads beyond it do nothing.
- `T`: type used for the multiplicative identity `one(T)`.
"""
function ∇scatter_src_kernel!(op, Δsrc, src, idx, rev_idx, max_idx, T)
    tid = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    # surplus threads of the last block have no element to work on
    tid > max_idx && return nothing

    @inbounds begin
        pos = CartesianIndices(idx)[tid]
        # aggregating indices: every position aggregated together with `pos`,
        # including `pos` itself
        group = rev_idx[idx[pos]...]
        # multiply all values of the group ...
        acc = one(T)
        for k in group
            acc *= src[k]
        end
        # ... then divide out the element's own value
        acc /= src[pos]
        # apply `op` on `Δsrc[pos]` and the product
        Δsrc[pos] = op(Δsrc[pos], acc)
    end
    return nothing
end
71+
72+
"""
    ∇scatter_src_kernel!(op, Δsrc, src, idx, rev_idx, pre_cart_idx, max_dims_idx,
                         max_idx, T)

GPU kernel that updates `Δsrc` in place, one element per thread, for the case
where `src` has leading dimensions in front of the dimensions covered by `idx`.

The linear thread index is split with `fldmod1(index, max_dims_idx)` into a
position `i` in `idx` and a position `j` in `pre_cart_idx`. For every `k` in
`rev_idx[idx[i]...]`, the value of `src` at `(pre_cart_idx[j]..., k...)` is
multiplied into a product. The product is then divided by `src[index]` and
`op(Δsrc[index], product)` is stored in `Δsrc[index]`.

# Arguments
- `op`: binary function combining the current `Δsrc` value with the product.
- `Δsrc`: array updated in place (indexed linearly).
- `src`: source values.
- `idx`: index array; its entries are splatted to index `rev_idx`.
- `rev_idx`: array whose elements are iterables of Cartesian positions.
- `pre_cart_idx`: Cartesian indices over the leading dimensions of `src`.
- `max_dims_idx`: divisor used to split the linear thread index.
- `max_idx`: number of elements to process; threads beyond it do nothing.
- `T`: type used for the multiplicative identity `one(T)`.
"""
function ∇scatter_src_kernel!(op, Δsrc, src, idx, rev_idx, pre_cart_idx, max_dims_idx, max_idx, T)
    tid = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    # surplus threads of the last block have no element to work on
    tid > max_idx && return nothing

    @inbounds begin
        i, j = fldmod1(tid, max_dims_idx)
        cart_i = CartesianIndices(idx)[i]
        cart_j = pre_cart_idx[j]
        # aggregating indices: every position aggregated together with this
        # one, including itself
        group = rev_idx[idx[cart_i]...]
        # multiply all values of the group within the current leading slice ...
        acc = one(T)
        for k in group
            lin = Base._to_linear_index(src, Tuple(cart_j)..., Tuple(k)...)
            acc *= src[lin]
        end
        # ... then divide out the element's own value
        acc /= src[tid]
        # apply `op` on `Δsrc[tid]` and the product
        Δsrc[tid] = op(Δsrc[tid], acc)
    end
    return nothing
end
93+
94+
"""
    NNlib.∇scatter_src(op::Union{typeof(*),typeof(/)}, Δ, dst, src::AnyCuArray, idx::AnyCuArray)

CUDA method of `NNlib.∇scatter_src` for `op` being `*` or `/`.

Starts from `NNlib.modify_src(op, NNlib.gather(Δ, idx), src)` and then runs
`∇scatter_src_kernel!` over it, using the reverse index table produced by
`NNlib.reverse_indices(idx)`. The kernel variant is selected by
`dims = ndims(src) - ndims(idx)`: the plain variant when `dims == 0`, and the
variant taking leading-dimension indices otherwise.

Returns the resulting `Δsrc` array. `dst` is accepted for interface
compatibility and is not read by this method.
"""
function NNlib.∇scatter_src(op::Union{typeof(*),typeof(/)}, Δ, dst,
                            src::AnyCuArray{Tsrc,Nsrc},
                            idx::AnyCuArray{Tidx,Nidx}) where {Tsrc,Tidx,Nsrc,Nidx}
    dims = Nsrc - Nidx
    Δsrc = NNlib.modify_src(op, NNlib.gather(Δ, idx), src)

    if dims == 0
        max_idx = length(idx)
    else
        pre_cart_idx = CartesianIndices(axes(src)[1:dims])
        max_dims_idx = length(pre_cart_idx)
        max_idx = max_dims_idx * length(idx)
    end

    # Nothing to process: a launch would compute `threads == 0` and
    # `cld(max_idx, 0)` throws a `DivideError`.
    max_idx == 0 && return Δsrc

    # `cudaconvert` yields raw device-side handles that do not keep the
    # underlying buffers alive. Keep the owning host-side `CuArray`s in a
    # separate binding so they stay reachable until the launch is enqueued.
    host_rev_idx = NNlib.reverse_indices(idx)
    rev_idx = CuArray(map(CUDA.cudaconvert, host_rev_idx))

    args = if dims == 0
        (op, Δsrc, src, idx, rev_idx, max_idx, Tsrc)
    else
        (op, Δsrc, src, idx, rev_idx, pre_cart_idx, max_dims_idx, max_idx, Tsrc)
    end

    GC.@preserve host_rev_idx begin
        kernel = @cuda launch=false ∇scatter_src_kernel!(args...)
        config = launch_configuration(kernel.fun; max_threads=256)
        threads = min(max_idx, config.threads)
        blocks = cld(max_idx, threads)
        kernel(args...; threads=threads, blocks=blocks)
    end

    CUDA.unsafe_free!(rev_idx)
    return Δsrc
end

src/utils.jl

+56
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
"""
    NNlib.count_indices(idx::AnyCuArray)

Return, for every position of `idx`, the length of the
`NNlib.reverse_indices(idx)` entry selected by the index stored at that
position. The per-entry lengths are moved to the GPU with `cu` and looked up
with `NNlib.gather`.
"""
function NNlib.count_indices(idx::AnyCuArray)
    # length of each reverse-index entry
    per_dst = map(length, NNlib.reverse_indices(idx))
    # look the lengths up through `idx` itself
    return NNlib.gather(cu(per_dst), idx)
end
6+
7+
"""
    divide_kernel!(xs, ys, max_idx)

GPU kernel performing `xs[i] = xs[i] / ys[i]` for every linear index
`i <= max_idx`, one element per thread.
"""
function divide_kernel!(xs, ys, max_idx)
    tid = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    # surplus threads of the last block have no element to work on
    tid > max_idx && return nothing

    @inbounds xs[tid] /= ys[tid]
    return nothing
end
15+
16+
"""
    divide_kernel!(xs, counts, max_idx, max_dims_idx, dims_size)

GPU kernel that divides every element of `xs` (up to linear index `max_idx`) by
the count belonging to its trailing slice, one element per thread.

The thread with linear index `index` handles `xs[index]` and divides it by
`counts[div(index - 1, max_dims_idx) + 1]`, so each run of `max_dims_idx`
consecutive elements shares one count.

# Arguments
- `xs`: array updated in place.
- `counts`: divisors, one per run of `max_dims_idx` elements.
- `max_idx`: number of elements to process; threads beyond it do nothing.
- `max_dims_idx`: number of elements sharing one count. Assumed to equal
  `prod(dims_size)`, as passed by `NNlib.divide_by_counts!`.
- `dims_size`: sizes of the leading dimensions of `xs`. Kept for interface
  compatibility; the element handled by a thread is addressed by its linear
  index, which is the same position as `(leading index..., slice number)` when
  `dims_size` matches the leading sizes of `xs`.
"""
function divide_kernel!(xs, counts, max_idx, max_dims_idx, dims_size)
    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x

    @inbounds if index <= max_idx
        # zero-based number of the trailing slice this element belongs to
        j = div(index - 1, max_dims_idx)
        # Every thread owns a distinct element, so a plain read-modify-write
        # is enough; no atomic update is needed.
        xs[index] = xs[index] / counts[j + 1]
    end
    return nothing
end
26+
27+
"""
    NNlib.divide_by_counts!(xs::AnyCuArray{T}, idx::AnyCuArray, dims)

Divide `xs` in place by the counts obtained from `NNlib.count_indices(idx)`
(converted to element type `T`) and return `xs`.

When `dims == 0` the division is element-wise over the first `length(idx)`
elements. Otherwise `dims` is the number of leading dimensions of `xs`, and
each run of `prod(size(xs)[1:dims])` consecutive elements is divided by one
count.

Empty `xs` or `idx` is returned unchanged without launching a kernel.
"""
function NNlib.divide_by_counts!(xs::AnyCuArray{T}, idx::AnyCuArray, dims) where {T}
    # With nothing to divide, the launch configuration below would compute
    # `threads == 0` and `cld(max_idx, 0)` throws a `DivideError`.
    (isempty(xs) || isempty(idx)) && return xs

    counts = CuArray{T}(NNlib.count_indices(idx))
    args = if dims == 0
        max_idx = length(idx)
        xs, counts, max_idx
    else
        dims_size = size(xs)[1:dims]
        max_dims_idx = prod(dims_size)
        max_idx = prod(size(xs))
        xs, counts, max_idx, max_dims_idx, dims_size
    end

    kernel = @cuda launch=false divide_kernel!(args...)
    config = launch_configuration(kernel.fun; max_threads=256)
    threads = min(max_idx, config.threads)
    blocks = cld(max_idx, threads)
    kernel(args...; threads=threads, blocks=blocks)
    return xs
end
46+
47+
"""
    NNlib.reverse_indices(idx::AnyCuArray{<:Any,N}) where N

Build the reverse index table of `idx` for GPU use.

Allocates a host array of size `NNlib.maximum_dims(idx)` whose elements are
empty `Vector{CartesianIndex{N}}`s, lets `NNlib.reverse_indices!` fill it from
`idx`, and returns a host array of the same shape with every vector converted
by `cu`.
"""
function NNlib.reverse_indices(idx::AnyCuArray{<:Any,N}) where N
    Ind = CartesianIndex{N}
    buckets = Array{Vector{Ind}}(undef, NNlib.maximum_dims(idx)...)
    # each slot needs its own vector, so create them one at a time
    for slot in eachindex(buckets)
        buckets[slot] = Ind[]
    end
    NNlib.reverse_indices!(buckets, idx)
    return map(cu, buckets)
end

0 commit comments

Comments
 (0)