Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add scatter for CUDA support #1

Merged
merged 13 commits into from
May 2, 2021
3 changes: 3 additions & 0 deletions src/NNlibCUDA.jl
Original file line number Diff line number Diff line change
@@ -4,9 +4,12 @@ using NNlib
using CUDA
using Random, Statistics

# Either a single integer or a tuple of integers of any length.
# `scatter!` uses it to constrain the element type of its `idx` array.
const IntOrIntTuple = Union{Integer, NTuple{N,<:Integer} where N}

include("upsample.jl")
include("activations.jl")
include("batchedmul.jl")
include("scatter.jl")
include("cudnn/cudnn.jl")
include("cudnn/conv.jl")
include("cudnn/pooling.jl")
40 changes: 40 additions & 0 deletions src/scatter.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
for op in [+, -, *, /, max, min, &, |]
# GPU kernel: element-wise scatter for the operator currently being generated.
#
# Each thread handles one linear position `index` of `idx` and combines `src[index]`
# into `dst` at the location stored in `idx[index]` (splatted, so an entry may be a
# single integer or a tuple of integers).
#
# Arguments
# - `op`:  the reduction; used for dispatch only, the actual call is spliced in by `@eval`.
# - `dst`: destination array, updated in place.
# - `src`: source array, read at the same linear position as `idx`.
# - `idx`: array of destination indices.
#
# Returns `nothing`.
@eval function scatter_kernel!(op::typeof($(op)), dst, src, idx)
    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x

    # Threads whose global index lies past the end of `idx` have no work to do.
    @inbounds if index <= length(idx)
        # Different threads may be handed the same destination index, so the
        # read-modify-write has to be atomic. The macro is qualified on purpose: a
        # bare `@atomic` can clash with (or resolve to) `Base.@atomic` on Julia >= 1.7,
        # which does not accept array-element assignments.
        CUDA.@atomic dst[idx[index]...] = $(op)(dst[idx[index]...], src[index])
    end
    return nothing
end

# GPU kernel: scatter with leading dimensions, for the operator currently being generated.
#
# Each thread handles one linear position `index` in `1:max_idx`. The position is
# split into `j` (which entry of `idx`) and `k` (which position inside the leading
# block of shape `dims_size`); the destination slot is then
# `dst[<leading position>..., idx[j + 1]...]`.
#
# Arguments
# - `op`:           the reduction; used for dispatch only, the call is spliced in by `@eval`.
# - `dst`:          destination array, updated in place.
# - `src`:          source array, read at linear position `index`.
# - `idx`:          array of destination indices for the trailing dimensions of `dst`.
# - `dims`:         `Val{N}` tag; `N` is not read inside the kernel body.
# - `max_idx`:      number of linear positions to process.
# - `max_dims_idx`: number of positions in the leading block (divisor for `divrem`).
# - `dims_size`:    shape used to turn `k` into a Cartesian index.
#
# Returns `nothing`.
@eval function scatter_kernel!(op::typeof($(op)), dst, src, idx, dims::Val{N},
                               max_idx, max_dims_idx, dims_size) where {N}
    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x

    # Threads whose global index lies past `max_idx` have no work to do.
    @inbounds if index <= max_idx
        # Zero-based split: `j` selects the entry of `idx`, `k` the leading position.
        j, k = divrem(index - 1, max_dims_idx)
        dims_i = CartesianIndices(dims_size)[k + 1]
        # Atomic because several threads can map to the same destination slot.
        # Qualified on purpose: a bare `@atomic` can clash with (or resolve to)
        # `Base.@atomic` on Julia >= 1.7, which rejects array-element assignments.
        CUDA.@atomic dst[Tuple(dims_i)..., idx[j + 1]...] = $(op)(dst[Tuple(dims_i)..., idx[j + 1]...], src[index])
    end
    return nothing
end

# CUDA method of `NNlib.scatter!` for the operator currently being generated.
#
# Combines the values of `src` into `dst` (in place) at the locations given by `idx`.
#
# Arguments
# - `op`:   the reduction (dispatch is on its type).
# - `dst`:  destination `CuArray`, mutated and returned.
# - `src`:  source `CuArray`.
# - `idx`:  `CuArray` whose elements are integers or tuples of integers.
# - `dims`: `Val{N}`. With `N == 0` the element-wise kernel is used; otherwise the
#           first `N` dimensions of `dst` are enumerated in full for every entry of
#           `idx` (presumably `src` carries matching leading dimensions — confirm
#           against the generic NNlib method).
#
# Returns `dst`.
@eval function NNlib.scatter!(op::typeof($(op)), dst::CuArray{Tdst}, src::CuArray{Tsrc},
                              idx::CuArray{<:IntOrIntTuple},
                              dims::Val{N}) where {Tdst,Tsrc,N}
    # Pick the kernel argument list; `max_idx` is the total number of work items.
    args = if N == 0
        max_idx = length(idx)
        op, dst, src, idx
    else
        dims_size = size(dst)[1:N]
        max_dims_idx = prod(dims_size)
        max_idx = max_dims_idx * length(idx)
        op, dst, src, idx, dims, max_idx, max_dims_idx, dims_size
    end

    # Nothing to scatter (empty `idx` or a zero-sized leading dimension). Without this
    # guard `threads` would be 0 and the block count would be computed from 0 / 0.
    max_idx == 0 && return dst

    # Compile without launching so the occupancy API can suggest a thread count.
    kernel = @cuda launch=false scatter_kernel!(args...)
    config = launch_configuration(kernel.fun; max_threads=256)
    threads = Base.min(max_idx, config.threads)
    # Integer ceiling division: enough blocks to cover every work item.
    blocks = cld(max_idx, threads)
    kernel(args...; threads=threads, blocks=blocks)
    return dst
end
end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@ using NNlib
using Zygote
using NNlibCUDA
using ForwardDiff: Dual
using Statistics: mean
using CUDA
CUDA.allowscalar(false)

@@ -16,4 +17,5 @@ if CUDA.has_cuda()
include("pooling.jl")
include("softmax.jl")
include("batchnorm.jl")
include("scatter.jl")
end
149 changes: 149 additions & 0 deletions test/scatter.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Destination fixtures, keyed by the `dims` value the test loops iterate over
# (looked up as `dsts[dims]`): a length-5 vector for 0 and a 2x5 matrix for 1.
dsts = Dict(
0 => cu([3, 4, 5, 6, 7]),
1 => cu([3 3 4 4 5;
5 5 6 6 7]),
)
# Source fixtures, keyed by `(dims, mutated)`.
# The `mutated == true` entries are all-ones arrays and are the ones passed to
# `NNlib.scatter!` in the tests below; the `mutated == false` entries are only
# referenced by the commented-out non-mutating `scatter` checks.
srcs = Dict(
(0, true) => cu(ones(Int, 3, 4)),
(0, false) => cu(ones(Int, 3) * collect(1:4)'),
(1, true) => cu(ones(Int, 2, 3, 4)),
(1, false) => cu([1, 2] .* reshape(ones(Int, 3) * collect(1:4)', 1,3,4)),
)
# Index fixtures: the same 3x4 pattern of destination positions expressed twice,
# once with plain integers and once with 1-tuples. Several positions repeat, so
# multiple source elements land on the same destination slot.
idxs = [
cu([1 2 3 4;
4 2 1 3;
3 5 5 3]), # integer index
cu([(1,) (2,) (3,) (4,);
(4,) (2,) (1,) (3,);
(3,) (5,) (5,) (3,)]), # tuple index
]
# Expected results, keyed by `(op, dims, mutated)`.
# `mutated == true` entries are what `NNlib.scatter!` should leave in the matching
# `dsts[dims]` after scattering `srcs[(dims, true)]`; note that the `/` test scatters
# the source multiplied by 2, so its entries correspond to dividing by 2 repeatedly.
# `mutated == false` entries are only referenced by the commented-out `scatter` checks.
res = Dict(
(+, 0, true) => cu([5, 6, 9, 8, 9]),
(+, 1, true) => cu([5 5 8 6 7;
7 7 10 8 9]),
(+, 0, false) => cu([4, 4, 12, 5, 5]),
(+, 1, false) => cu([4 4 12 5 5;
8 8 24 10 10]),
(-, 0, true) => cu([1, 2, 1, 4, 5]),
(-, 1, true) => cu([1 1 0 2 3;
3 3 2 4 5]),
(-, 0, false) => cu([-4, -4, -12, -5, -5]),
(-, 1, false) => cu([-4 -4 -12 -5 -5;
-8 -8 -24 -10 -10]),
(max, 0, true) => cu([3, 4, 5, 6, 7]),
(max, 1, true) => cu([3 3 4 4 5;
5 5 6 6 7]),
(max, 0, false) => cu([3, 2, 4, 4, 3]),
(max, 1, false) => cu([3 2 4 4 3;
6 4 8 8 6]),
(min, 0, true) => cu([1, 1, 1, 1, 1]),
(min, 1, true) => cu([1 1 1 1 1;
1 1 1 1 1]),
(min, 0, false) => cu([1, 2, 1, 1, 2]),
(min, 1, false) => cu([1 2 1 1 2;
2 4 2 2 4]),
(*, 0, true) => cu([3, 4, 5, 6, 7]),
(*, 1, true) => cu([3 3 4 4 5;
5 5 6 6 7]),
(*, 0, false) => cu([3, 4, 48, 4, 6]),
(*, 1, false) => cu([3 4 48 4 6;
12 16 768 16 24]),
(/, 0, true) => cu([0.75, 1., 0.3125, 1.5, 1.75]),
(/, 1, true) => cu([0.75 0.75 0.25 1. 1.25;
1.25 1.25 0.375 1.5 1.75]),
(/, 0, false) => cu([1//3, 1//4, 1//48, 1//4, 1//6]),
(/, 1, false) => cu([1//3 1//4 1//48 1//4 1//6;
1//12 1//16 1//768 1//16 1//24]),
(mean, 0, true) => cu([4., 5., 6., 7., 8.]),
(mean, 1, true) => cu([4. 4. 5. 5. 6.;
6. 6. 7. 7. 8.]),
(mean, 0, false) => cu([2, 2, 3, 2.5, 2.5]),
(mean, 1, false) => cu([2. 2. 3. 2.5 2.5;
4. 4. 6. 5. 5.]),
)

# Array types the fixtures are converted to for the `+`, `-`, `max` and `min` tests.
# The `*`, `/` and `mean` tests use their own, floating-point-only list.
types = [CuArray{UInt32}, CuArray{UInt64},
CuArray{Int32}, CuArray{Int64},
CuArray{Float32}, CuArray{Float64}]


# Checks the mutating `NNlib.scatter!` on the GPU against the precomputed `res` table.
# Only the `mutated == true` fixtures are exercised; the non-mutating `scatter`
# variant is not tested here.
@testset "scatter" begin
    mutated = true

    # Operators checked for every element type in `types`.
    all_type_ops = (+, -, max, min)

    # Operators checked for floating-point element types only, each paired with
    # the transformation applied to the source fixture before conversion
    # (`/` scatters the source multiplied by 2).
    float_only_cases = (
        (*, identity),
        (/, s -> s .* 2),
        (mean, identity),
    )

    for T in types
        @testset "$(T)" begin
            for op in all_type_ops
                @testset "$(op)" begin
                    for idx in idxs, dims in (0, 1)
                        @test NNlib.scatter!(op, T(dsts[dims]), T(srcs[(dims, mutated)]), idx) == T(res[(op, dims, mutated)])
                    end
                end
            end
        end
    end

    for T in [CuArray{Float32}, CuArray{Float64}]
        @testset "$(T)" begin
            for (op, prepare) in float_only_cases
                @testset "$(op)" begin
                    for idx in idxs, dims in (0, 1)
                        @test NNlib.scatter!(op, T(dsts[dims]), T(prepare(srcs[(dims, mutated)])), idx) == T(res[(op, dims, mutated)])
                    end
                end
            end
        end
    end
end