@@ -2,37 +2,43 @@ ATM_OPS = Dict((+) => CUDA.atomic_add!, (-) => CUDA.atomic_sub!, (max) => CUDA.a
2
2
(* ) => CUDA. atomic_mul!, (/ ) => CUDA. atomic_div!, (& ) => CUDA. atomic_and!, (| ) => CUDA. atomic_or!)
3
3
4
4
for (op, atm_op) in ATM_OPS
5
- @eval function scatter_kernel! (op:: typeof ($ (op)), dst, src, idx, dims :: Val{0} , pre_rng, post_rng )
6
- li = threadIdx (). x + (blockIdx (). x - 1 ) * blockDim (). x
5
+ @eval function scatter_kernel! (op:: typeof ($ (op)), dst, src, idx)
6
+ index = threadIdx (). x + (blockIdx (). x - 1 ) * blockDim (). x
7
7
8
- @inbounds if li <= post_rng
9
- ind = CartesianIndices (idx)[li]
10
- dst_i = Base. _to_linear_index (dst, idx[li]. .. )
11
- $ (atm_op)(pointer (dst, dst_i), src[ind])
8
+ @inbounds if index <= length (idx)
9
+ i = Base. _to_linear_index (dst, idx[index]. .. )
10
+ $ (atm_op)(pointer (dst, i), src[index])
12
11
end
13
12
return nothing
14
13
end
15
14
16
- @eval function scatter_kernel! (op:: typeof ($ (op)), dst, src, idx, dims:: Val{1} , pre_rng, post_rng)
17
- li = threadIdx (). y + (blockIdx (). y - 1 ) * blockDim (). y
18
- i = threadIdx (). x + (blockIdx (). x - 1 ) * blockDim (). x
15
+ @eval function scatter_kernel! (op:: typeof ($ (op)), dst, src, idx, dims:: Val{N} , max_idx, max_dims_idx, dims_size) where {N}
16
+ index = threadIdx (). x + (blockIdx (). x - 1 ) * blockDim (). x
19
17
20
- @inbounds if li <= post_rng && i <= pre_rng
21
- j = CartesianIndices (idx)[li]
22
- dst_i = Base. _to_linear_index (dst, i, idx[li]. .. )
23
- $ (atm_op)(pointer (dst, dst_i), src[i, j])
18
+ @inbounds if index <= max_idx
19
+ j, k = divrem (index, max_dims_idx)
20
+ dims_i = CartesianIndices (dims_size)[k]
21
+ i = Base. _to_linear_index (dst, Tuple (dims_i)... , idx[j]. .. )
22
+ $ (atm_op)(pointer (dst, i), src[index])
24
23
end
25
24
return nothing
26
25
end
27
26
28
27
@eval function NNlib. scatter! (op:: typeof ($ (op)), dst:: CuArray{Tdst} , src:: CuArray{Tsrc} , idx:: CuArray{<:IntOrIntTuple} , dims:: Val{N} ) where {Tdst,Tsrc,N}
29
- pre_rng = prod (size (dst)[1 : N])
30
- post_rng = length (idx)
31
- thread_x = min (MAX_THREADS, pre_rng)
32
- thread_y = min (MAX_THREADS ÷ thread_x, post_rng)
33
- threads = (thread_x, thread_y)
34
- blocks = ceil .(Int, (pre_rng, post_rng) ./ threads)
35
- @cuda blocks= blocks threads= threads scatter_kernel! (op, dst, src, idx, dims, pre_rng, post_rng)
36
- return dst
28
+ if N == 0
29
+ max_idx = length (idx)
30
+ threads = min (MAX_THREADS, max_idx)
31
+ blocks = ceil (Int, max_idx / threads)
32
+ @cuda blocks= blocks threads= threads scatter_kernel! (op, dst, src, idx)
33
+ return dst
34
+ else
35
+ dims_size = size (dst)[1 : N]
36
+ max_dims_idx = prod (dims_size)
37
+ max_idx = max_dims_idx * length (idx)
38
+ threads = min (MAX_THREADS, max_idx)
39
+ blocks = ceil (Int, max_idx / threads)
40
+ @cuda blocks= blocks threads= threads scatter_kernel! (op, dst, src, idx, dims, max_idx, max_dims_idx, dims_size)
41
+ return dst
42
+ end
37
43
end
38
44
end
0 commit comments