From 60585e2d660f56005d173492e91b54458d481d1d Mon Sep 17 00:00:00 2001
From: Yueh-Hua Tu <a504082002@gmail.com>
Date: Wed, 9 Jun 2021 18:31:57 +0800
Subject: [PATCH] support CUDA gradients for scatter
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

add count_indices for CuArray

add CUDA kernel for divide_by_counts!

add NNlib.∇scatter_src for CUDA gradients

support AD for scatter mean on CUDA

support AD for scatter with * and / on CUDA
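
A minimal end-to-end sketch of the new gradient path (assuming Zygote.jl as
the AD backend; the array values and variable names below are illustrative
only and not part of this patch):

    using CUDA, NNlib, NNlibCUDA, Zygote

    src = cu(Float32[3, 4, 5, 6, 7])
    idx = cu([1, 2, 3, 4, 4])   # destination slots; positions 4 and 5 share a slot

    # forward pass: multiplicative scatter aggregates src into 4 destination slots
    y = NNlib.scatter(*, src, idx)

    # reverse pass: Zygote goes through NNlib's scatter rrule, which now dispatches
    # to the CUDA methods of NNlib.∇scatter_src added in src/scatter.jl
    g = Zygote.gradient(s -> sum(NNlib.scatter(*, s, idx)), src)[1]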
---
 Manifest.toml    | 64 ++++++++++++++++--------------------------
 Project.toml     |  2 +-
 src/NNlibCUDA.jl |  1 +
 src/scatter.jl   | 72 ++++++++++++++++++++++++++++++++++++++++++++++++
 src/utils.jl     | 56 +++++++++++++++++++++++++++++++++++++
 test/scatter.jl  | 72 ++++++++++--------------------------------------
 6 files changed, 167 insertions(+), 100 deletions(-)
 create mode 100644 src/utils.jl

diff --git a/Manifest.toml b/Manifest.toml
index 22b48b9..54291a3 100644
--- a/Manifest.toml
+++ b/Manifest.toml
@@ -33,22 +33,22 @@ uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
 version = "0.4.1"
 
 [[CUDA]]
-deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "DataStructures", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "MacroTools", "Memoize", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"]
-git-tree-sha1 = "364179416eabc34c9ca32126a6bdb431680c3bad"
+deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "DataStructures", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"]
+git-tree-sha1 = "82b2811f5888465d96b38c7bb12d8fb9c25838e1"
 uuid = "052768ef-5323-5732-b1bb-66c8b64840ba"
-version = "3.2.1"
+version = "3.3.1"
 
 [[ChainRulesCore]]
 deps = ["Compat", "LinearAlgebra", "SparseArrays"]
-git-tree-sha1 = "04dd5ce9f9d7b9b14559b00a7eb5be7528f56b82"
+git-tree-sha1 = "be770c08881f7bb928dfd86d1ba83798f76cf62a"
 uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
-version = "0.10.2"
+version = "0.10.9"
 
 [[Compat]]
 deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
-git-tree-sha1 = "e4e2b39db08f967cc1360951f01e8a75ec441cab"
+git-tree-sha1 = "dc7dedc2c2aa9faf59a55c622760a25cbefbe941"
 uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
-version = "3.30.0"
+version = "3.31.0"
 
 [[CompilerSupportLibraries_jll]]
 deps = ["Artifacts", "Libdl"]
@@ -73,10 +73,10 @@ deps = ["Random", "Serialization", "Sockets"]
 uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
 
 [[DocStringExtensions]]
-deps = ["LibGit2", "Markdown", "Pkg", "Test"]
-git-tree-sha1 = "9d4f64f79012636741cf01133158a54b24924c32"
+deps = ["LibGit2"]
+git-tree-sha1 = "a32185f5428d3986f47c2ab78b1f216d5e6cc96f"
 uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
-version = "0.8.4"
+version = "0.8.5"
 
 [[Downloads]]
 deps = ["ArgTools", "LibCURL", "NetworkOptions"]
@@ -89,15 +89,15 @@ version = "0.1.3"
 
 [[GPUArrays]]
 deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization", "Statistics"]
-git-tree-sha1 = "df5b8569904c5c10e84c640984cfff054b18c086"
+git-tree-sha1 = "ececbf05f8904c92814bdbd0aafd5540b0bf2e9a"
 uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
-version = "6.4.1"
+version = "7.0.1"
 
 [[GPUCompiler]]
-deps = ["DataStructures", "ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "Serialization", "TimerOutputs", "UUIDs"]
-git-tree-sha1 = "42d635f6d87af125b86288df3819f805fb4d851a"
+deps = ["DataStructures", "ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "TimerOutputs", "UUIDs"]
+git-tree-sha1 = "222c6cdb888ec24795936d6829aa978691def60e"
 uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
-version = "0.11.5"
+version = "0.12.3"
 
 [[InteractiveUtils]]
 deps = ["Markdown"]
@@ -111,9 +111,9 @@ version = "1.3.0"
 
 [[LLVM]]
 deps = ["CEnum", "Libdl", "Printf", "Unicode"]
-git-tree-sha1 = "b499c68a45249b0385585c62f4a9b62b5db8e691"
+git-tree-sha1 = "f57ac3fd2045b50d3db081663837ac5b4096947e"
 uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
-version = "3.7.1"
+version = "3.9.0"
 
 [[LazyArtifacts]]
 deps = ["Artifacts", "Pkg"]
@@ -151,12 +151,6 @@ version = "0.2.4"
 [[Logging]]
 uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
 
-[[MacroTools]]
-deps = ["Markdown", "Random"]
-git-tree-sha1 = "6a8a2a625ab0dea913aba95c11370589e0239ff0"
-uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
-version = "0.5.6"
-
 [[Markdown]]
 deps = ["Base64"]
 uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
@@ -165,12 +159,6 @@ uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
 deps = ["Artifacts", "Libdl"]
 uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
 
-[[Memoize]]
-deps = ["MacroTools"]
-git-tree-sha1 = "2b1dfcba103de714d31c033b5dacc2e4a12c7caa"
-uuid = "c03570c3-d221-55d1-a50c-7939bbd78826"
-version = "0.4.4"
-
 [[Mmap]]
 uuid = "a63ad114-7e13-5084-954f-fe012c677804"
 
@@ -179,9 +167,9 @@ uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
 
 [[NNlib]]
 deps = ["Adapt", "ChainRulesCore", "Compat", "LinearAlgebra", "Pkg", "Requires", "Statistics"]
-git-tree-sha1 = "0bf1fbb9dc557f2af9fb7e1337366d69de0dc78c"
+git-tree-sha1 = "7e6f31cfa39b1ff1c541cc8580b14b0ff4ba22d0"
 uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-version = "0.7.21"
+version = "0.7.23"
 
 [[NetworkOptions]]
 uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
@@ -221,9 +209,9 @@ uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 
 [[Random123]]
 deps = ["Libdl", "Random", "RandomNumbers"]
-git-tree-sha1 = "7c6710c8198fd4444b5eb6a3840b7d47bd3593c5"
+git-tree-sha1 = "0e8b146557ad1c6deb1367655e052276690e71a3"
 uuid = "74087812-796a-5b5d-8853-05524746bad3"
-version = "1.3.1"
+version = "1.4.2"
 
 [[RandomNumbers]]
 deps = ["Random", "Requires"]
@@ -245,12 +233,6 @@ version = "1.1.3"
 [[SHA]]
 uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
 
-[[Scratch]]
-deps = ["Dates"]
-git-tree-sha1 = "ad4b278adb62d185bbcb6864dc24959ab0627bf6"
-uuid = "6c6a2e73-6563-6170-7368-637461726353"
-version = "1.0.3"
-
 [[Serialization]]
 uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
 
@@ -289,9 +271,9 @@ uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [[TimerOutputs]]
 deps = ["ExprTools", "Printf"]
-git-tree-sha1 = "bf8aacc899a1bd16522d0350e1e2310510d77236"
+git-tree-sha1 = "9f494bc54b4c31404a9eff449235836615929de1"
 uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
-version = "0.5.9"
+version = "0.5.10"
 
 [[UUIDs]]
 deps = ["Random", "SHA"]
diff --git a/Project.toml b/Project.toml
index a2d48f6..6eebadd 100644
--- a/Project.toml
+++ b/Project.toml
@@ -11,7 +11,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [compat]
 CUDA = "3.3.1"
-NNlib = "0.7.21"
+NNlib = "0.7.23"
 julia = "1.6"
 
 [extras]
diff --git a/src/NNlibCUDA.jl b/src/NNlibCUDA.jl
index f98c08f..0fc5626 100644
--- a/src/NNlibCUDA.jl
+++ b/src/NNlibCUDA.jl
@@ -11,6 +11,7 @@ include("activations.jl")
 include("batchedmul.jl")
 include("scatter.jl")
 include("gather.jl")
+include("utils.jl")
 include("cudnn/cudnn.jl")
 include("cudnn/conv.jl")
 include("cudnn/pooling.jl")
diff --git a/src/scatter.jl b/src/scatter.jl
index d08fc32..122b85f 100644
--- a/src/scatter.jl
+++ b/src/scatter.jl
@@ -46,3 +46,75 @@ function NNlib.scatter!(op::typeof(mean), dst::AnyCuArray, src::AnyCuArray, idx:
     dst .+= NNlib.safe_div.(dst_, Ns)
     return dst
 end
+
+
+## Gradients
+
+function ∇scatter_src_kernel!(op, Δsrc, src, idx, rev_idx, max_idx, T)
+    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
+
+    @inbounds if index <= max_idx
+        cart_j = CartesianIndices(idx)[index]
+        # get the indices of all src values aggregated into the same destination (including this one)
+        inds = rev_idx[idx[cart_j]...]
+        # multiply all values to be aggregated but not itself
+        x = one(T)
+        for k in inds
+            x *= src[k]
+        end
+        x /= src[cart_j]
+        # apply `op` to `Δsrc[cart_j]` and `x`
+        Δsrc[cart_j] = op(Δsrc[cart_j], x)
+    end
+    return nothing
+end
+
+function ∇scatter_src_kernel!(op, Δsrc, src, idx, rev_idx, pre_cart_idx, max_dims_idx, max_idx, T)
+    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
+
+    @inbounds if index <= max_idx
+        i, j = fldmod1(index, max_dims_idx)
+        cart_i = CartesianIndices(idx)[i]
+        cart_j = pre_cart_idx[j]
+        # get the indices of all src values aggregated into the same destination (including this one)
+        inds = rev_idx[idx[cart_i]...]
+        # product of all those values; this element's own value is divided back out below
+        x = one(T)
+        for k in inds
+            jk = Base._to_linear_index(src, Tuple(cart_j)..., Tuple(k)...)
+            x *= src[jk]
+        end
+        x /= src[index]
+        # apply `op` to `Δsrc[index]` and `x`
+        Δsrc[index] = op(Δsrc[index], x)
+    end
+    return nothing
+end
+
+function NNlib.∇scatter_src(op::Union{typeof(*),typeof(/)}, Δ, dst,
+                            src::AnyCuArray{Tsrc,Nsrc}, 
+                            idx::AnyCuArray{Tidx,Nidx}) where {Tsrc,Tidx,Nsrc,Nidx}
+    dims = Nsrc - Nidx
+    Δsrc = NNlib.modify_src(op, NNlib.gather(Δ, idx), src)
+    rev_idx = NNlib.reverse_indices(idx)
+    rev_idx = CuArray(map(CUDA.cudaconvert, rev_idx))
+    
+    if dims == 0
+        max_idx = length(idx)
+        args = op, Δsrc, src, idx, rev_idx, max_idx, Tsrc
+    else
+        pre_cart_idx = CartesianIndices(axes(src)[1:dims])
+        max_dims_idx = length(pre_cart_idx)
+        max_idx = max_dims_idx * length(idx)
+        args = op, Δsrc, src, idx, rev_idx, pre_cart_idx, max_dims_idx, max_idx, Tsrc
+    end
+
+    kernel = @cuda launch=false ∇scatter_src_kernel!(args...)
+    config = launch_configuration(kernel.fun; max_threads=256)
+    threads = min(max_idx, config.threads)
+    blocks = cld(max_idx, threads)
+    kernel(args...; threads=threads, blocks=blocks)
+
+    CUDA.unsafe_free!(rev_idx)
+    return Δsrc
+end
diff --git a/src/utils.jl b/src/utils.jl
new file mode 100644
index 0000000..1139d67
--- /dev/null
+++ b/src/utils.jl
@@ -0,0 +1,56 @@
+function NNlib.count_indices(idx::AnyCuArray)
+    dst_counts = length.(NNlib.reverse_indices(idx))
+    src_counts = NNlib.gather(cu(dst_counts), idx)
+    return src_counts
+end
+
+function divide_kernel!(xs, ys, max_idx)
+    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
+
+    @inbounds if index <= max_idx
+        xs[index] = xs[index] / ys[index]
+    end
+    return nothing
+end
+
+function divide_kernel!(xs, counts, max_idx, max_dims_idx, dims_size)
+    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
+
+    @inbounds if index <= max_idx
+        j, k = divrem(index-1, max_dims_idx)
+        dims_i = Tuple(CartesianIndices(dims_size)[k+1])
+        @atomic xs[dims_i..., j+1] = xs[dims_i..., j+1] / counts[j+1]
+    end
+    return nothing
+end
+
+function NNlib.divide_by_counts!(xs::AnyCuArray{T}, idx::AnyCuArray, dims) where {T}
+    counts = CuArray{T}(NNlib.count_indices(idx))
+    args = if dims == 0
+        max_idx = length(idx)
+        xs, counts, max_idx
+    else
+        dims_size = size(xs)[1:dims]
+        max_dims_idx = prod(dims_size)
+        max_idx = prod(size(xs))
+        xs, counts, max_idx, max_dims_idx, dims_size
+    end
+
+    kernel = @cuda launch=false divide_kernel!(args...)
+    config = launch_configuration(kernel.fun; max_threads=256)
+    threads = min(max_idx, config.threads)
+    blocks = cld(max_idx, threads)
+    kernel(args...; threads=threads, blocks=blocks)
+    return xs
+end
+
+function NNlib.reverse_indices(idx::AnyCuArray{<:Any,N}) where N
+    max_dims = NNlib.maximum_dims(idx)
+    T = CartesianIndex{N}
+    rev = Array{Vector{T}}(undef, max_dims...)
+    for i in eachindex(rev)
+        rev[i] = T[]
+    end
+    NNlib.reverse_indices!(rev, idx)
+    return map(cu, rev)
+end
diff --git a/test/scatter.jl b/test/scatter.jl
index 16b6e43..088284e 100644
--- a/test/scatter.jl
+++ b/test/scatter.jl
@@ -17,50 +17,6 @@ idxs = [
         (4,) (2,) (1,) (3,);
         (3,) (5,) (5,) (3,)]),  # tuple index
 ]
-res = Dict(
-    (+, 0, true) => cu([5, 6, 9, 8, 9]),
-    (+, 1, true) => cu([5 5 8 6 7;
-                        7 7 10 8 9]),
-    (+, 0, false) => cu([4, 4, 12, 5, 5]),
-    (+, 1, false) => cu([4 4 12 5 5;
-                         8 8 24 10 10]),
-    (-, 0, true) => cu([1, 2, 1, 4, 5]),
-    (-, 1, true) => cu([1 1 0 2 3;
-                        3 3 2 4 5]),
-    (-, 0, false) => cu([-4, -4, -12, -5, -5]),
-    (-, 1, false) => cu([-4 -4 -12 -5 -5;
-                         -8 -8 -24 -10 -10]),
-    (max, 0, true) => cu([3, 4, 5, 6, 7]),
-    (max, 1, true) => cu([3 3 4 4 5;
-                          5 5 6 6 7]),
-    (max, 0, false) => cu([3, 2, 4, 4, 3]),
-    (max, 1, false) => cu([3 2 4 4 3;
-                           6 4 8 8 6]),
-    (min, 0, true) => cu([1, 1, 1, 1, 1]),
-    (min, 1, true) => cu([1 1 1 1 1;
-                          1 1 1 1 1]),
-    (min, 0, false) => cu([1, 2, 1, 1, 2]),
-    (min, 1, false) => cu([1 2 1 1 2;
-                           2 4 2 2 4]),
-    (*, 0, true) => cu([3, 4, 5, 6, 7]),
-    (*, 1, true) => cu([3 3 4 4 5;
-                        5 5 6 6 7]),
-    (*, 0, false) => cu([3, 4, 48, 4, 6]),
-    (*, 1, false) => cu([3 4 48 4 6;
-                        12 16 768 16 24]),
-    (/, 0, true) => cu([0.75, 1., 0.3125, 1.5, 1.75]),
-    (/, 1, true) => cu([0.75 0.75 0.25 1. 1.25;
-                        1.25 1.25 0.375 1.5 1.75]),
-    (/, 0, false) => cu([1//3, 1//4, 1//48, 1//4, 1//6]),
-    (/, 1, false) => cu([1//3 1//4 1//48 1//4 1//6;
-                         1//12 1//16 1//768 1//16 1//24]),
-    (mean, 0, true) => cu([4., 5., 6., 7., 8.]),
-    (mean, 1, true) => cu([4. 4. 5. 5. 6.;
-                           6. 6. 7. 7. 8.]),
-    (mean, 0, false) => cu([2, 2, 3, 2.5, 2.5]),
-    (mean, 1, false) => cu([2. 2. 3. 2.5 2.5;
-                            4. 4. 6. 5. 5.]),
-)
 
 types = [CuArray{Int32}, CuArray{Int64}, CuArray{Float32}, CuArray{Float64}]
 
@@ -71,40 +27,40 @@ types = [CuArray{Int32}, CuArray{Int64}, CuArray{Float32}, CuArray{Float64}]
             @testset "+" begin
                 for idx = idxs, dims = [0, 1]
                     mutated = true
-                    @test NNlib.scatter!(+, T(copy(dsts[dims])), T(srcs[(dims, mutated)]), idx) == T(res[(+, dims, mutated)])
+                    gputest((dst, src) -> NNlib.scatter!(+, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true)
 
                     mutated = false
-                    @test NNlib.scatter(+, T(srcs[(dims, mutated)]), idx) == T(res[(+, dims, mutated)])
+                    gputest(src -> NNlib.scatter(+, src, idx), T(srcs[(dims, mutated)]), checkgrad=true)
                 end
             end
 
             @testset "-" begin
                 for idx = idxs, dims = [0, 1]
                     mutated = true
-                    @test NNlib.scatter!(-, T(copy(dsts[dims])), T(srcs[(dims, mutated)]), idx) == T(res[(-, dims, mutated)])
+                    gputest((dst, src) -> NNlib.scatter!(-, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true)
 
                     mutated = false
-                    @test NNlib.scatter(-, T(srcs[(dims, mutated)]), idx) == T(res[(-, dims, mutated)])
+                    gputest(src -> NNlib.scatter(-, src, idx), T(srcs[(dims, mutated)]), checkgrad=true)
                 end
             end
 
             @testset "max" begin
                 for idx = idxs, dims = [0, 1]
                     mutated = true
-                    @test NNlib.scatter!(max, T(copy(dsts[dims])), T(srcs[(dims, mutated)]), idx) == T(res[(max, dims, mutated)])
+                    gputest((dst, src) -> NNlib.scatter!(max, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true)
 
                     mutated = false
-                    @test NNlib.scatter(max, T(srcs[(dims, mutated)]), idx) == T(res[(max, dims, mutated)])
+                    gputest(src -> NNlib.scatter(max, src, idx), T(srcs[(dims, mutated)]), checkgrad=true)
                 end
             end
 
             @testset "min" begin
                 for idx = idxs, dims = [0, 1]
                     mutated = true
-                    @test NNlib.scatter!(min, T(copy(dsts[dims])), T(srcs[(dims, mutated)]), idx) == T(res[(min, dims, mutated)])
+                    gputest((dst, src) -> NNlib.scatter!(min, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true)
 
                     mutated = false
-                    @test NNlib.scatter(min, T(srcs[(dims, mutated)]), idx) == T(res[(min, dims, mutated)])
+                    gputest(src -> NNlib.scatter(min, src, idx), T(srcs[(dims, mutated)]), checkgrad=true)
                 end
             end
         end
@@ -116,30 +72,30 @@ types = [CuArray{Int32}, CuArray{Int64}, CuArray{Float32}, CuArray{Float64}]
             @testset "*" begin
                 for idx = idxs, dims = [0, 1]
                     mutated = true
-                    @test NNlib.scatter!(*, T(copy(dsts[dims])), T(srcs[(dims, mutated)]), idx) == T(res[(*, dims, mutated)])
+                    gputest((dst, src) -> NNlib.scatter!(*, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true)
 
                     mutated = false
-                    @test NNlib.scatter(*, T(srcs[(dims, mutated)]), idx) == T(res[(*, dims, mutated)])
+                    gputest(src -> NNlib.scatter(*, src, idx), T(srcs[(dims, mutated)]), checkgrad=true)
                 end
             end
 
             @testset "/" begin
                 for idx = idxs, dims = [0, 1]
                     mutated = true
-                    @test NNlib.scatter!(/, T(copy(dsts[dims])), T(srcs[(dims, mutated)].*2), idx) == T(res[(/, dims, mutated)])
+                    gputest((dst, src) -> NNlib.scatter!(/, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true)
 
                     mutated = false
-                    @test NNlib.scatter(/, T(srcs[(dims, mutated)]), idx) == T(res[(/, dims, mutated)])
+                    gputest(src -> NNlib.scatter(/, src, idx), T(srcs[(dims, mutated)]), checkgrad=true)
                 end
             end
 
             @testset "mean" begin
                 for idx = idxs, dims = [0, 1]
                     mutated = true
-                    @test NNlib.scatter!(mean, T(copy(dsts[dims])), T(srcs[(dims, mutated)]), idx) == T(res[(mean, dims, mutated)])
+                    gputest((dst, src) -> NNlib.scatter!(mean, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true)
 
                     mutated = false
-                    @test NNlib.scatter(mean, T(srcs[(dims, mutated)]), idx) == T(res[(mean, dims, mutated)])
+                    gputest(src -> NNlib.scatter(mean, src, idx), T(srcs[(dims, mutated)]), checkgrad=true)
                 end
             end
         end