Skip to content

Commit 94ea8b3

Browse files
Merge pull request #1 from yuehhua/scatter
Add scatter for CUDA support
2 parents 84bdd6a + 1bd64c8 commit 94ea8b3

7 files changed

+204
-3
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Manifest.toml

Manifest.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -167,9 +167,9 @@ uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
167167

168168
[[NNlib]]
169169
deps = ["Adapt", "ChainRulesCore", "Compat", "LinearAlgebra", "Pkg", "Requires", "Statistics"]
170-
git-tree-sha1 = "723c0d5252bf95808f934b2384519dd325869f40"
170+
git-tree-sha1 = "80b8360670f445d88b3475e88b33bbcc92f7866e"
171171
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
172-
version = "0.7.18"
172+
version = "0.7.19"
173173

174174
[[NetworkOptions]]
175175
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1111

1212
[compat]
1313
CUDA = "3.1"
14-
NNlib = "0.7"
14+
NNlib = "0.7.19"
1515
julia = "1.6"
1616

1717
[extras]

src/NNlibCUDA.jl

+3
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,12 @@ using NNlib
44
using CUDA
55
using Random, Statistics
66

7+
const IntOrIntTuple = Union{Integer, NTuple{N,<:Integer} where N}
8+
79
include("upsample.jl")
810
include("activations.jl")
911
include("batchedmul.jl")
12+
include("scatter.jl")
1013
include("cudnn/cudnn.jl")
1114
include("cudnn/conv.jl")
1215
include("cudnn/pooling.jl")

src/scatter.jl

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# supported op: +, -, *, /, max, min, &, |, mean
2+
3+
# GPU kernel for the case where `idx` addresses every dimension of `dst`.
# Each thread takes one linear position of `idx`, and combines the matching
# `src` element into `dst` through `@atomic`. The entries of `idx` are splatted
# into the index expression, so both integer and tuple entries are accepted.
function scatter_kernel!(op, dst, src, idx)
    tid = (blockIdx().x - 1) * blockDim().x + threadIdx().x

    # Surplus threads in the last block have no element to process.
    tid > length(idx) && return nothing

    # Keep the `a[i] = op(a[i], v)` shape: that is the form `@atomic` rewrites.
    @inbounds @atomic dst[idx[tid]...] = op(dst[idx[tid]...], src[tid])
    return nothing
end
11+
12+
# GPU kernel for the case where `dst` keeps leading dimensions of size
# `dims_size` in front of the dimensions addressed by `idx`.
#   max_idx      -- total number of elements to process (threads beyond it exit)
#   max_dims_idx -- number of positions inside one leading `dims_size` block
#   dims_size    -- sizes of the leading dimensions, used to build the
#                   Cartesian position inside that block
function scatter_kernel!(op, dst, src, idx, max_idx, max_dims_idx, dims_size)
    tid = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    tid > max_idx && return nothing

    @inbounds begin
        # Split the zero-based thread offset: the quotient selects the entry of
        # `idx`, the remainder selects the position in the leading block.
        slot, offset = divrem(tid - 1, max_dims_idx)
        lead = Tuple(CartesianIndices(dims_size)[offset + 1])
        tail = idx[slot + 1]
        # Keep the `a[i] = op(a[i], v)` shape: that is the form `@atomic` rewrites.
        @atomic dst[lead..., tail...] = op(dst[lead..., tail...], src[tid])
    end
    return nothing
end
22+
23+
# Scatter `src` into `dst` on the GPU, combining colliding elements with `op`.
#
# `NNlib._check_dims` validates the three arrays and returns `dims`, which is
# used here as the number of leading dimensions of `dst` that are not addressed
# by `idx`. With `dims == 0` the 4-argument kernel is launched with one thread
# per entry of `idx`; otherwise the 7-argument kernel is launched with one
# thread per (leading position, `idx` entry) pair.
#
# Returns `dst`, which is updated in place.
function NNlib.scatter!(op, dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray)
    dims = NNlib._check_dims(dst, src, idx)
    args = if dims == 0
        max_idx = length(idx)
        op, dst, src, idx
    else
        dims_size = size(dst)[1:dims]
        max_dims_idx = prod(dims_size)
        max_idx = max_dims_idx * length(idx)
        op, dst, src, idx, max_idx, max_dims_idx, dims_size
    end

    # Nothing to scatter (empty `idx` or a zero-sized leading block). Return
    # before computing the launch geometry: `threads` would be 0 below and
    # `cld(max_idx, 0)` throws a `DivideError`.
    max_idx == 0 && return dst

    # Compile without launching so the occupancy API can suggest a block size;
    # `max_threads=256` is passed as the upper bound for that suggestion.
    kernel = @cuda launch=false scatter_kernel!(args...)
    config = launch_configuration(kernel.fun; max_threads=256)
    threads = min(max_idx, config.threads)
    blocks = cld(max_idx, threads)
    kernel(args...; threads=threads, blocks=blocks)
    return dst
end
42+
43+
# Scatter-mean for GPU arrays. `op` is only used for dispatch.
# Two `+` scatters into zeroed copies of `dst` collect, per destination slot,
# how many source elements land there and what they sum to; the element-wise
# `NNlib.safe_div` of the two is then added onto `dst`.
# NOTE(review): `safe_div` presumably guards slots whose count is zero --
# confirm against its definition in NNlib.
function NNlib.scatter!(op::typeof(mean), dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray)
    counts = NNlib.scatter!(+, zero(dst), one.(src), idx)
    totals = NNlib.scatter!(+, zero(dst), src, idx)
    @. dst += NNlib.safe_div(totals, counts)
    return dst
end

test/runtests.jl

+2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using NNlib
33
using Zygote
44
using NNlibCUDA
55
using ForwardDiff: Dual
6+
using Statistics: mean
67
using CUDA
78
CUDA.allowscalar(false)
89

@@ -16,4 +17,5 @@ if CUDA.has_cuda()
1617
include("pooling.jl")
1718
include("softmax.jl")
1819
include("batchnorm.jl")
20+
include("scatter.jl")
1921
end

test/scatter.jl

+147
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
# Destination arrays, keyed by the `dims` value the tests loop over
# (0 => a 5-element vector, 1 => a 2x5 matrix). The in-place tests pass a
# `copy` of the selected entry to `NNlib.scatter!`.
dsts = Dict(
2+
0 => cu([3, 4, 5, 6, 7]),
3+
1 => cu([3 3 4 4 5;
4+
5 5 6 6 7]),
5+
)
6+
# Source arrays, keyed by `(dims, mutated)`. The tests use the `mutated == true`
# entries (all ones) with the in-place `NNlib.scatter!` and the
# `mutated == false` entries with the allocating `NNlib.scatter`.
srcs = Dict(
7+
(0, true) => cu(ones(Int, 3, 4)),
8+
(0, false) => cu(ones(Int, 3) * collect(1:4)'),
9+
(1, true) => cu(ones(Int, 2, 3, 4)),
10+
(1, false) => cu([1, 2] .* reshape(ones(Int, 3) * collect(1:4)', 1,3,4)),
11+
)
12+
# Index arrays exercised by every test: the same 3x4 pattern of destination
# positions written once with plain integers and once with 1-tuples.
idxs = [
13+
cu([1 2 3 4;
14+
4 2 1 3;
15+
3 5 5 3]), # integer index
16+
cu([(1,) (2,) (3,) (4,);
17+
(4,) (2,) (1,) (3,);
18+
(3,) (5,) (5,) (3,)]), # tuple index
19+
]
20+
# Expected results, keyed by `(op, dims, mutated)`.
# `mutated == true` entries are what the tests expect from `NNlib.scatter!`
# applied to a copy of `dsts[dims]`; `mutated == false` entries are what they
# expect from `NNlib.scatter`. The in-place "/" test scatters the source
# multiplied by 2, so the `(/, _, true)` entries correspond to that doubled source.
res = Dict(
21+
(+, 0, true) => cu([5, 6, 9, 8, 9]),
22+
(+, 1, true) => cu([5 5 8 6 7;
23+
7 7 10 8 9]),
24+
(+, 0, false) => cu([4, 4, 12, 5, 5]),
25+
(+, 1, false) => cu([4 4 12 5 5;
26+
8 8 24 10 10]),
27+
(-, 0, true) => cu([1, 2, 1, 4, 5]),
28+
(-, 1, true) => cu([1 1 0 2 3;
29+
3 3 2 4 5]),
30+
(-, 0, false) => cu([-4, -4, -12, -5, -5]),
31+
(-, 1, false) => cu([-4 -4 -12 -5 -5;
32+
-8 -8 -24 -10 -10]),
33+
(max, 0, true) => cu([3, 4, 5, 6, 7]),
34+
(max, 1, true) => cu([3 3 4 4 5;
35+
5 5 6 6 7]),
36+
(max, 0, false) => cu([3, 2, 4, 4, 3]),
37+
(max, 1, false) => cu([3 2 4 4 3;
38+
6 4 8 8 6]),
39+
(min, 0, true) => cu([1, 1, 1, 1, 1]),
40+
(min, 1, true) => cu([1 1 1 1 1;
41+
1 1 1 1 1]),
42+
(min, 0, false) => cu([1, 2, 1, 1, 2]),
43+
(min, 1, false) => cu([1 2 1 1 2;
44+
2 4 2 2 4]),
45+
(*, 0, true) => cu([3, 4, 5, 6, 7]),
46+
(*, 1, true) => cu([3 3 4 4 5;
47+
5 5 6 6 7]),
48+
(*, 0, false) => cu([3, 4, 48, 4, 6]),
49+
(*, 1, false) => cu([3 4 48 4 6;
50+
12 16 768 16 24]),
51+
(/, 0, true) => cu([0.75, 1., 0.3125, 1.5, 1.75]),
52+
(/, 1, true) => cu([0.75 0.75 0.25 1. 1.25;
53+
1.25 1.25 0.375 1.5 1.75]),
54+
(/, 0, false) => cu([1//3, 1//4, 1//48, 1//4, 1//6]),
55+
(/, 1, false) => cu([1//3 1//4 1//48 1//4 1//6;
56+
1//12 1//16 1//768 1//16 1//24]),
57+
(mean, 0, true) => cu([4., 5., 6., 7., 8.]),
58+
(mean, 1, true) => cu([4. 4. 5. 5. 6.;
59+
6. 6. 7. 7. 8.]),
60+
(mean, 0, false) => cu([2, 2, 3, 2.5, 2.5]),
61+
(mean, 1, false) => cu([2. 2. 3. 2.5 2.5;
62+
4. 4. 6. 5. 5.]),
63+
)
64+
65+
# Array types the +, -, max and min tests are run for; each fixture is
# converted with `T(...)` before use.
types = [CuArray{Int32}, CuArray{Int64}, CuArray{Float32}, CuArray{Float64}]
66+
67+
68+
# For every array type and operator, check both entry points against `res`:
# the in-place `NNlib.scatter!` on a copy of `dsts[dims]` (fixtures keyed with
# `true`) and the allocating `NNlib.scatter` (fixtures keyed with `false`).
@testset "scatter" begin
    # Operators checked for every element type in `types`.
    for T in types
        @testset "$(T)" begin
            for (label, op) in (("+", +), ("-", -), ("max", max), ("min", min))
                @testset "$(label)" begin
                    for idx in idxs, dims in (0, 1)
                        @test NNlib.scatter!(
                            op, T(copy(dsts[dims])), T(srcs[(dims, true)]), idx
                        ) == T(res[(op, dims, true)])

                        @test NNlib.scatter(op, T(srcs[(dims, false)]), idx) ==
                            T(res[(op, dims, false)])
                    end
                end
            end
        end
    end

    # Operators checked for floating-point element types only.
    for T in [CuArray{Float32}, CuArray{Float64}]
        @testset "$(T)" begin
            for (label, op) in (("*", *), ("/", /), ("mean", mean))
                # The in-place "/" case scatters a doubled source; every other
                # case uses the fixture unchanged.
                prep = op === (/) ? (x -> x .* 2) : identity
                @testset "$(label)" begin
                    for idx in idxs, dims in (0, 1)
                        @test NNlib.scatter!(
                            op, T(copy(dsts[dims])), T(prep(srcs[(dims, true)])), idx
                        ) == T(res[(op, dims, true)])

                        @test NNlib.scatter(op, T(srcs[(dims, false)]), idx) ==
                            T(res[(op, dims, false)])
                    end
                end
            end
        end
    end
end

0 commit comments

Comments
 (0)