Skip to content

Commit d227214

Browse files
committed
scatter cuda
try to unify the scatter API add const MAX_THREADS and correct dynamic call atomic functions
1 parent 6905ca7 commit d227214

File tree

4 files changed

+199
-0
lines changed

4 files changed

+199
-0
lines changed

lib/NNlibCUDA/src/NNlibCUDA.jl

+4
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,12 @@ using NNlib
44
using CUDA
55
using Random, Statistics
66

7+
const IntOrIntTuple = Union{Integer, NTuple{N,<:Integer} where N}
8+
const MAX_THREADS = 1024
9+
710
include("upsample.jl")
811
include("batchedmul.jl")
12+
include("scatter.jl")
913
include("cudnn/cudnn.jl")
1014
include("cudnn/conv.jl")
1115
include("cudnn/pooling.jl")

lib/NNlibCUDA/src/scatter.jl

+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
ATM_OPS = Dict((+) => CUDA.atomic_add!, (-) => CUDA.atomic_sub!, (max) => CUDA.atomic_max!, (min) => CUDA.atomic_min!,
2+
(*) => CUDA.atomic_mul!, (/) => CUDA.atomic_div!, (&) => CUDA.atomic_and!, (|) => CUDA.atomic_or!)
3+
4+
for (op, atm_op) in ATM_OPS
5+
@eval function scatter_kernel!(op::typeof($(op)), dst, src, idx)
6+
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
7+
8+
@inbounds if index <= length(idx)
9+
i = Base._to_linear_index(dst, idx[index]...)
10+
$(atm_op)(pointer(dst, i), src[index])
11+
end
12+
return nothing
13+
end
14+
15+
@eval function scatter_kernel!(op::typeof($(op)), dst, src, idx, dims::Val{N}, max_idx, max_dims_idx, dims_size) where {N}
16+
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
17+
18+
@inbounds if index <= max_idx
19+
j, k = divrem(index-1, max_dims_idx)
20+
dims_i = CartesianIndices(dims_size)[k+1]
21+
i = Base._to_linear_index(dst, Tuple(dims_i)..., idx[j+1]...)
22+
$(atm_op)(pointer(dst, i), src[index])
23+
end
24+
return nothing
25+
end
26+
27+
@eval function NNlib.scatter!(op::typeof($(op)), dst::CuArray{Tdst}, src::CuArray{Tsrc}, idx::CuArray{<:IntOrIntTuple}, dims::Val{N}) where {Tdst,Tsrc,N}
28+
if N == 0
29+
max_idx = length(idx)
30+
threads = min(MAX_THREADS, max_idx)
31+
blocks = ceil(Int, max_idx / threads)
32+
@cuda blocks=blocks threads=threads scatter_kernel!(op, dst, src, idx)
33+
return dst
34+
else
35+
dims_size = size(dst)[1:N]
36+
max_dims_idx = prod(dims_size)
37+
max_idx = max_dims_idx * length(idx)
38+
threads = min(MAX_THREADS, max_idx)
39+
blocks = ceil(Int, max_idx / threads)
40+
@cuda blocks=blocks threads=threads scatter_kernel!(op, dst, src, idx, dims, max_idx, max_dims_idx, dims_size)
41+
return dst
42+
end
43+
end
44+
end

lib/NNlibCUDA/test/runtests.jl

+2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using NNlib
33
using Zygote
44
using NNlibCUDA
55
using ForwardDiff: Dual
6+
using Statistics: mean
67
using CUDA
78
CUDA.allowscalar(false)
89

@@ -16,4 +17,5 @@ if CUDA.has_cuda()
1617
include("pooling.jl")
1718
include("softmax.jl")
1819
include("batchnorm.jl")
20+
include("scatter.jl")
1921
end

lib/NNlibCUDA/test/scatter.jl

+149
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
dsts = Dict(
2+
0 => cu([3, 4, 5, 6, 7]),
3+
1 => cu([3 3 4 4 5;
4+
5 5 6 6 7]),
5+
)
6+
srcs = Dict(
7+
(0, true) => cu(ones(Int, 3, 4)),
8+
(0, false) => cu(ones(Int, 3) * collect(1:4)'),
9+
(1, true) => cu(ones(Int, 2, 3, 4)),
10+
(1, false) => cu([1, 2] .* reshape(ones(Int, 3) * collect(1:4)', 1,3,4)),
11+
)
12+
idxs = [
13+
cu([1 2 3 4;
14+
4 2 1 3;
15+
3 5 5 3]), # integer index
16+
cu([(1,) (2,) (3,) (4,);
17+
(4,) (2,) (1,) (3,);
18+
(3,) (5,) (5,) (3,)]), # tuple index
19+
]
20+
res = Dict(
21+
(+, 0, true) => cu([5, 6, 9, 8, 9]),
22+
(+, 1, true) => cu([5 5 8 6 7;
23+
7 7 10 8 9]),
24+
(+, 0, false) => cu([4, 4, 12, 5, 5]),
25+
(+, 1, false) => cu([4 4 12 5 5;
26+
8 8 24 10 10]),
27+
(-, 0, true) => cu([1, 2, 1, 4, 5]),
28+
(-, 1, true) => cu([1 1 0 2 3;
29+
3 3 2 4 5]),
30+
(-, 0, false) => cu([-4, -4, -12, -5, -5]),
31+
(-, 1, false) => cu([-4 -4 -12 -5 -5;
32+
-8 -8 -24 -10 -10]),
33+
(max, 0, true) => cu([3, 4, 5, 6, 7]),
34+
(max, 1, true) => cu([3 3 4 4 5;
35+
5 5 6 6 7]),
36+
(max, 0, false) => cu([3, 2, 4, 4, 3]),
37+
(max, 1, false) => cu([3 2 4 4 3;
38+
6 4 8 8 6]),
39+
(min, 0, true) => cu([1, 1, 1, 1, 1]),
40+
(min, 1, true) => cu([1 1 1 1 1;
41+
1 1 1 1 1]),
42+
(min, 0, false) => cu([1, 2, 1, 1, 2]),
43+
(min, 1, false) => cu([1 2 1 1 2;
44+
2 4 2 2 4]),
45+
(*, 0, true) => cu([3, 4, 5, 6, 7]),
46+
(*, 1, true) => cu([3 3 4 4 5;
47+
5 5 6 6 7]),
48+
(*, 0, false) => cu([3, 4, 48, 4, 6]),
49+
(*, 1, false) => cu([3 4 48 4 6;
50+
12 16 768 16 24]),
51+
(/, 0, true) => cu([0.75, 1., 0.3125, 1.5, 1.75]),
52+
(/, 1, true) => cu([0.75 0.75 0.25 1. 1.25;
53+
1.25 1.25 0.375 1.5 1.75]),
54+
(/, 0, false) => cu([1//3, 1//4, 1//48, 1//4, 1//6]),
55+
(/, 1, false) => cu([1//3 1//4 1//48 1//4 1//6;
56+
1//12 1//16 1//768 1//16 1//24]),
57+
(mean, 0, true) => cu([4., 5., 6., 7., 8.]),
58+
(mean, 1, true) => cu([4. 4. 5. 5. 6.;
59+
6. 6. 7. 7. 8.]),
60+
(mean, 0, false) => cu([2, 2, 3, 2.5, 2.5]),
61+
(mean, 1, false) => cu([2. 2. 3. 2.5 2.5;
62+
4. 4. 6. 5. 5.]),
63+
)
64+
65+
types = [CuArray{UInt32}, CuArray{UInt64},
66+
CuArray{Int32}, CuArray{Int64},
67+
CuArray{Float32}, CuArray{Float64}]
68+
69+
70+
@testset "scatter" begin
71+
for T = types
72+
@testset "$(T)" begin
73+
@testset "+" begin
74+
for idx = idxs, dims = [0, 1]
75+
mutated = true
76+
@test scatter!(+, T(dsts[dims]), T(srcs[(dims, mutated)]), idx) == T(res[(+, dims, mutated)])
77+
78+
mutated = false
79+
# @test scatter(+, srcs[(dims, mutated)], idx) == T(res[(+, dims, mutated)])
80+
end
81+
end
82+
83+
@testset "-" begin
84+
for idx = idxs, dims = [0, 1]
85+
mutated = true
86+
@test scatter!(-, T(dsts[dims]), T(srcs[(dims, mutated)]), idx) == T(res[(-, dims, mutated)])
87+
88+
mutated = false
89+
# @test scatter(-, srcs[(dims, mutated)], idx) == T(res[(-, dims, mutated)])
90+
end
91+
end
92+
93+
@testset "max" begin
94+
for idx = idxs, dims = [0, 1]
95+
mutated = true
96+
@test scatter!(max, T(dsts[dims]), T(srcs[(dims, mutated)]), idx) == T(res[(max, dims, mutated)])
97+
98+
mutated = false
99+
# @test scatter(max, srcs[(dims, mutated)], idx) == T(res[(max, dims, mutated)])
100+
end
101+
end
102+
103+
@testset "min" begin
104+
for idx = idxs, dims = [0, 1]
105+
mutated = true
106+
@test scatter!(min, T(dsts[dims]), T(srcs[(dims, mutated)]), idx) == T(res[(min, dims, mutated)])
107+
108+
mutated = false
109+
# @test scatter(min, srcs[(dims, mutated)], idx) == T(res[(min, dims, mutated)])
110+
end
111+
end
112+
end
113+
end
114+
115+
116+
for T = [CuArray{Float32}, CuArray{Float64}]
117+
@testset "$(T)" begin
118+
@testset "*" begin
119+
for idx = idxs, dims = [0, 1]
120+
mutated = true
121+
@test scatter!(*, T(dsts[dims]), T(srcs[(dims, mutated)]), idx) == T(res[(*, dims, mutated)])
122+
123+
mutated = false
124+
# @test scatter(*, srcs[(dims, mutated)], idx) == T(res[(*, dims, mutated)])
125+
end
126+
end
127+
128+
@testset "/" begin
129+
for idx = idxs, dims = [0, 1]
130+
mutated = true
131+
@test scatter!(/, T(dsts[dims]), T(srcs[(dims, mutated)].*2), idx) == T(res[(/, dims, mutated)])
132+
133+
mutated = false
134+
# @test scatter(/, srcs[(dims, mutated)], idx) == T(res[(/, dims, mutated)])
135+
end
136+
end
137+
138+
@testset "mean" begin
139+
for idx = idxs, dims = [0, 1]
140+
mutated = true
141+
@test scatter!(mean, T(dsts[dims]), T(srcs[(dims, mutated)]), idx) == T(res[(mean, dims, mutated)])
142+
143+
mutated = false
144+
# @test scatter(mean, srcs[(dims, mutated)], idx) == T(res[(mean, dims, mutated)])
145+
end
146+
end
147+
end
148+
end
149+
end

0 commit comments

Comments
 (0)