diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..ba39cc5
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+Manifest.toml
diff --git a/Manifest.toml b/Manifest.toml
index c9f84b6..dc33e60 100644
--- a/Manifest.toml
+++ b/Manifest.toml
@@ -167,9 +167,9 @@ uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
 
 [[NNlib]]
 deps = ["Adapt", "ChainRulesCore", "Compat", "LinearAlgebra", "Pkg", "Requires", "Statistics"]
-git-tree-sha1 = "723c0d5252bf95808f934b2384519dd325869f40"
+git-tree-sha1 = "80b8360670f445d88b3475e88b33bbcc92f7866e"
 uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-version = "0.7.18"
+version = "0.7.19"
 
 [[NetworkOptions]]
 uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
diff --git a/Project.toml b/Project.toml
index 6bca04b..7a1bb30 100644
--- a/Project.toml
+++ b/Project.toml
@@ -11,7 +11,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [compat]
 CUDA = "3.1"
-NNlib = "0.7"
+NNlib = "0.7.19"
 julia = "1.6"
 
 [extras]
diff --git a/src/NNlibCUDA.jl b/src/NNlibCUDA.jl
index 87dc168..b4cc05c 100644
--- a/src/NNlibCUDA.jl
+++ b/src/NNlibCUDA.jl
@@ -4,9 +4,12 @@ using NNlib
 using CUDA
 using Random, Statistics
 
+const IntOrIntTuple = Union{Integer, NTuple{N,<:Integer} where N}
+
 include("upsample.jl")
 include("activations.jl")
 include("batchedmul.jl")
+include("scatter.jl")
 include("cudnn/cudnn.jl")
 include("cudnn/conv.jl")
 include("cudnn/pooling.jl")
diff --git a/src/scatter.jl b/src/scatter.jl
new file mode 100644
index 0000000..6f81211
--- /dev/null
+++ b/src/scatter.jl
@@ -0,0 +1,70 @@
+# supported op: +, -, *, /, max, min, &, |, mean
+
+# Kernel used when `dims == 0`: one thread per element of `idx`. Thread `index`
+# combines `src[index]` into `dst[idx[index]...]` with `op`. The read-modify-write is
+# wrapped in `@atomic` because several `idx` entries may address the same `dst` element.
+function scatter_kernel!(op, dst, src, idx)
+    # 1-based global linear thread index.
+    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
+
+    # Surplus threads in the last block (index past the end of `idx`) do nothing.
+    @inbounds if index <= length(idx)
+        @atomic dst[idx[index]...] = op(dst[idx[index]...], src[index])
+    end
+    return nothing
+end
+
+# Kernel used when `dims > 0`: one thread per linear index in `1:max_idx`. The
+# zero-based thread index is split by `divrem` into `j` (position within `idx`) and
+# `k` (position within the leading `dims_size` dimensions, `max_dims_idx` in total).
+# Thread `index` atomically combines `src[index]` into
+# `dst[Tuple(dims_i)..., idx[j+1]...]` with `op`.
+function scatter_kernel!(op, dst, src, idx, max_idx, max_dims_idx, dims_size)
+    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
+
+    @inbounds if index <= max_idx
+        j, k = divrem(index-1, max_dims_idx)
+        dims_i = CartesianIndices(dims_size)[k+1]
+        @atomic dst[Tuple(dims_i)..., idx[j+1]...] = op(dst[Tuple(dims_i)..., idx[j+1]...], src[index])
+    end
+    return nothing
+end
+
+# Scatter `src` into `dst` at the positions named by `idx`, combining values with `op`;
+# `dst` is mutated and returned. `NNlib._check_dims` supplies `dims`: when it is 0 the
+# four-argument kernel runs over `length(idx)` threads, otherwise the seven-argument
+# kernel runs over `prod(size(dst)[1:dims]) * length(idx)` threads.
+function NNlib.scatter!(op, dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray)
+    dims = NNlib._check_dims(dst, src, idx)
+    args = if dims == 0
+        max_idx = length(idx)
+        op, dst, src, idx
+    else
+        dims_size = size(dst)[1:dims]
+        max_dims_idx = prod(dims_size)
+        max_idx = max_dims_idx * length(idx)
+        op, dst, src, idx, max_idx, max_dims_idx, dims_size
+    end
+
+    # Nothing to scatter (e.g. empty `idx`). Without this guard `threads` below would
+    # be 0 and `cld(max_idx, threads)` would throw a `DivideError`.
+    max_idx == 0 && return dst
+
+    kernel = @cuda launch=false scatter_kernel!(args...)
+    # Block size from `launch_configuration` (capped at 256 via `max_threads`), but
+    # never more threads than there is work.
+    config = launch_configuration(kernel.fun; max_threads=256)
+    threads = min(max_idx, config.threads)
+    blocks = cld(max_idx, threads)
+    kernel(args...; threads=threads, blocks=blocks)
+    return dst
+end
+
+# `mean` is built from two `+` scatters into zeroed copies of `dst`: `Ns` accumulates
+# a one per scattered source element and `dst_` accumulates the source values. Their
+# element-wise `NNlib.safe_div` result is then added onto `dst`.
+function NNlib.scatter!(op::typeof(mean), dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray)
+    Ns = NNlib.scatter!(+, zero(dst), one.(src), idx)
+    dst_ = NNlib.scatter!(+, zero(dst), src, idx)
+    dst .+= NNlib.safe_div.(dst_, Ns)
+    return dst
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 4624151..778b875 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -3,6 +3,7 @@ using NNlib
 using Zygote
 using NNlibCUDA
 using ForwardDiff: Dual
+using Statistics: mean
 using CUDA
 
 CUDA.allowscalar(false)
@@ -16,4 +17,5 @@ if CUDA.has_cuda()
     include("pooling.jl")
     include("softmax.jl")
     include("batchnorm.jl")
+    include("scatter.jl")
 end
diff --git a/test/scatter.jl b/test/scatter.jl
new file mode 100644
index 0000000..16b6e43
--- /dev/null
+++ b/test/scatter.jl
@@ -0,0 +1,147 @@
+dsts = Dict(
+    0 => cu([3, 4, 5, 6, 7]),
+    1 => cu([3 3 4 4 5;
+             5 5 6 6 7]),
+)
+srcs = Dict(
+    (0, true) => cu(ones(Int, 3, 4)),
+    (0, false) => cu(ones(Int, 3) * collect(1:4)'),
+    (1, true) => cu(ones(Int, 2, 3, 4)),
+    (1, false) => cu([1, 2] .* reshape(ones(Int, 3) * collect(1:4)', 1,3,4)),
+)
+idxs = [
+    cu([1 2 3 4;
+        4 2 1 3;
+        3 5 5 3]),  # integer index
+    cu([(1,) (2,) (3,) (4,);
+        (4,) (2,) (1,) (3,);
+        (3,) (5,) (5,) (3,)]),  # tuple index
+]
+res = Dict(
+    (+, 0, true) => cu([5, 6, 9, 8, 9]),
+    (+, 1, true) => cu([5 5 8 6 7;
+                        7 7 10 8 9]),
+    (+, 0, false) => cu([4, 4, 12, 5, 5]),
+    (+, 1, false) => cu([4 4 12 5 5;
+                         8 8 24 10 10]),
+    (-, 0, true) => cu([1, 2, 1, 4, 5]),
+    (-, 1, true) => cu([1 1 0 2 3;
+                        3 3 2 4 5]),
+    (-, 0, false) => cu([-4, -4, -12, -5, -5]),
+    (-, 1, false) => cu([-4 -4 -12 -5 -5;
+                         -8 -8 -24 -10 -10]),
+    (max, 0, true) => cu([3, 4, 5, 6, 7]),
+    (max, 1, true) => cu([3 3 4 4 5;
+                          5 5 6 6 7]),
+    (max, 0, false) => cu([3, 2, 4, 4, 3]),
+    (max, 1, false) => cu([3 2 4 4 3;
+                           6 4 8 8 6]),
+    (min, 0, true) => cu([1, 1, 1, 1, 1]),
+    (min, 1, true) => cu([1 1 1 1 1;
+                          1 1 1 1 1]),
+    (min, 0, false) => cu([1, 2, 1, 1, 2]),
+    (min, 1, false) => cu([1 2 1 1 2;
+                           2 4 2 2 4]),
+    (*, 0, true) => cu([3, 4, 5, 6, 7]),
+    (*, 1, true) => cu([3 3 4 4 5;
+                        5 5 6 6 7]),
+    (*, 0, false) => cu([3, 4, 48, 4, 6]),
+    (*, 1, false) => cu([3 4 48 4 6;
+                         12 16 768 16 24]),
+    (/, 0, true) => cu([0.75, 1., 0.3125, 1.5, 1.75]),
+    (/, 1, true) => cu([0.75 0.75 0.25 1. 1.25;
+                        1.25 1.25 0.375 1.5 1.75]),
+    (/, 0, false) => cu([1//3, 1//4, 1//48, 1//4, 1//6]),
+    (/, 1, false) => cu([1//3 1//4 1//48 1//4 1//6;
+                         1//12 1//16 1//768 1//16 1//24]),
+    (mean, 0, true) => cu([4., 5., 6., 7., 8.]),
+    (mean, 1, true) => cu([4. 4. 5. 5. 6.;
+                           6. 6. 7. 7. 8.]),
+    (mean, 0, false) => cu([2, 2, 3, 2.5, 2.5]),
+    (mean, 1, false) => cu([2. 2. 3. 2.5 2.5;
+                            4. 4. 6. 5. 5.]),
+)
+
+types = [CuArray{Int32}, CuArray{Int64}, CuArray{Float32}, CuArray{Float64}]
+
+
+@testset "scatter" begin
+    for T = types
+        @testset "$(T)" begin
+            @testset "+" begin
+                for idx = idxs, dims = [0, 1]
+                    mutated = true
+                    @test NNlib.scatter!(+, T(copy(dsts[dims])), T(srcs[(dims, mutated)]), idx) == T(res[(+, dims, mutated)])
+
+                    mutated = false
+                    @test NNlib.scatter(+, T(srcs[(dims, mutated)]), idx) == T(res[(+, dims, mutated)])
+                end
+            end
+
+            @testset "-" begin
+                for idx = idxs, dims = [0, 1]
+                    mutated = true
+                    @test NNlib.scatter!(-, T(copy(dsts[dims])), T(srcs[(dims, mutated)]), idx) == T(res[(-, dims, mutated)])
+
+                    mutated = false
+                    @test NNlib.scatter(-, T(srcs[(dims, mutated)]), idx) == T(res[(-, dims, mutated)])
+                end
+            end
+
+            @testset "max" begin
+                for idx = idxs, dims = [0, 1]
+                    mutated = true
+                    @test NNlib.scatter!(max, T(copy(dsts[dims])), T(srcs[(dims, mutated)]), idx) == T(res[(max, dims, mutated)])
+
+                    mutated = false
+                    @test NNlib.scatter(max, T(srcs[(dims, mutated)]), idx) == T(res[(max, dims, mutated)])
+                end
+            end
+
+            @testset "min" begin
+                for idx = idxs, dims = [0, 1]
+                    mutated = true
+                    @test NNlib.scatter!(min, T(copy(dsts[dims])), T(srcs[(dims, mutated)]), idx) == T(res[(min, dims, mutated)])
+
+                    mutated = false
+                    @test NNlib.scatter(min, T(srcs[(dims, mutated)]), idx) == T(res[(min, dims, mutated)])
+                end
+            end
+        end
+    end
+
+
+    for T = [CuArray{Float32}, CuArray{Float64}]
+        @testset "$(T)" begin
+            @testset "*" begin
+                for idx = idxs, dims = [0, 1]
+                    mutated = true
+                    @test NNlib.scatter!(*, T(copy(dsts[dims])), T(srcs[(dims, mutated)]), idx) == T(res[(*, dims, mutated)])
+
+                    mutated = false
+                    @test NNlib.scatter(*, T(srcs[(dims, mutated)]), idx) == T(res[(*, dims, mutated)])
+                end
+            end
+
+            @testset "/" begin
+                for idx = idxs, dims = [0, 1]
+                    mutated = true
+                    @test NNlib.scatter!(/, T(copy(dsts[dims])), T(srcs[(dims, mutated)].*2), idx) == T(res[(/, dims, mutated)])
+
+                    mutated = false
+                    @test NNlib.scatter(/, T(srcs[(dims, mutated)]), idx) == T(res[(/, dims, mutated)])
+                end
+            end
+
+            @testset "mean" begin
+                for idx = idxs, dims = [0, 1]
+                    mutated = true
+                    @test NNlib.scatter!(mean, T(copy(dsts[dims])), T(srcs[(dims, mutated)]), idx) == T(res[(mean, dims, mutated)])
+
+                    mutated = false
+                    @test NNlib.scatter(mean, T(srcs[(dims, mutated)]), idx) == T(res[(mean, dims, mutated)])
+                end
+            end
+        end
+    end
+end