Skip to content

Commit 787fe7b

Browse files
committed
new indexing approach
1 parent 6c9096d commit 787fe7b

File tree

1 file changed

+27
-21
lines changed

1 file changed

+27
-21
lines changed

lib/NNlibCUDA/src/scatter.jl

+27-21
Original file line number | Diff line number | Diff line change
@@ -2,37 +2,43 @@ ATM_OPS = Dict((+) => CUDA.atomic_add!, (-) => CUDA.atomic_sub!, (max) => CUDA.a
22
(*) => CUDA.atomic_mul!, (/) => CUDA.atomic_div!, (&) => CUDA.atomic_and!, (|) => CUDA.atomic_or!)
33

44
for (op, atm_op) in ATM_OPS
5-
@eval function scatter_kernel!(op::typeof($(op)), dst, src, idx, dims::Val{0}, pre_rng, post_rng)
6-
li = threadIdx().x + (blockIdx().x - 1) * blockDim().x
5+
@eval function scatter_kernel!(op::typeof($(op)), dst, src, idx)
6+
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
77

8-
@inbounds if li <= post_rng
9-
ind = CartesianIndices(idx)[li]
10-
dst_i = Base._to_linear_index(dst, idx[li]...)
11-
$(atm_op)(pointer(dst, dst_i), src[ind])
8+
@inbounds if index <= length(idx)
9+
i = Base._to_linear_index(dst, idx[index]...)
10+
$(atm_op)(pointer(dst, i), src[index])
1211
end
1312
return nothing
1413
end
1514

16-
@eval function scatter_kernel!(op::typeof($(op)), dst, src, idx, dims::Val{1}, pre_rng, post_rng)
17-
li = threadIdx().y + (blockIdx().y - 1) * blockDim().y
18-
i = threadIdx().x + (blockIdx().x - 1) * blockDim().x
15+
@eval function scatter_kernel!(op::typeof($(op)), dst, src, idx, dims::Val{N}, max_idx, max_dims_idx, dims_size) where {N}
16+
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
1917

20-
@inbounds if li <= post_rng && i <= pre_rng
21-
j = CartesianIndices(idx)[li]
22-
dst_i = Base._to_linear_index(dst, i, idx[li]...)
23-
$(atm_op)(pointer(dst, dst_i), src[i, j])
18+
@inbounds if index <= max_idx
19+
# `index` is a 1-based thread index. Shift it to 0-based before splitting so that
# `j` (quotient) is the 0-based position in `idx` and `k` (remainder) is the 0-based
# offset inside the leading `dims_size` block; both are shifted back to 1-based when
# used as array indices. Without the shift, `k == 0` / `j == 0` index out of bounds.
j, k = divrem(index - 1, max_dims_idx)
20+
dims_i = CartesianIndices(dims_size)[k + 1]
21+
i = Base._to_linear_index(dst, Tuple(dims_i)..., idx[j + 1]...)
22+
$(atm_op)(pointer(dst, i), src[index])
2423
end
2524
return nothing
2625
end
2726

2827
@eval function NNlib.scatter!(op::typeof($(op)), dst::CuArray{Tdst}, src::CuArray{Tsrc}, idx::CuArray{<:IntOrIntTuple}, dims::Val{N}) where {Tdst,Tsrc,N}
29-
pre_rng = prod(size(dst)[1:N])
30-
post_rng = length(idx)
31-
thread_x = min(MAX_THREADS, pre_rng)
32-
thread_y = min(MAX_THREADS ÷ thread_x, post_rng)
33-
threads = (thread_x, thread_y)
34-
blocks = ceil.(Int, (pre_rng, post_rng) ./ threads)
35-
@cuda blocks=blocks threads=threads scatter_kernel!(op, dst, src, idx, dims, pre_rng, post_rng)
36-
return dst
28+
if N == 0
29+
max_idx = length(idx)
# Nothing to scatter: avoid a zero-thread launch and the 0 / 0 block computation.
max_idx == 0 && return dst
30+
threads = min(MAX_THREADS, max_idx)
31+
# Integer ceiling division; no Float64 round-trip.
blocks = cld(max_idx, threads)
32+
@cuda blocks=blocks threads=threads scatter_kernel!(op, dst, src, idx)
33+
return dst
34+
else
35+
dims_size = size(dst)[1:N]
36+
max_dims_idx = prod(dims_size)
37+
# One thread per (leading-dims position, idx entry) pair.
max_idx = max_dims_idx * length(idx)
# Nothing to scatter: avoid a zero-thread launch and the 0 / 0 block computation.
max_idx == 0 && return dst
38+
threads = min(MAX_THREADS, max_idx)
39+
# Integer ceiling division; no Float64 round-trip.
blocks = cld(max_idx, threads)
40+
@cuda blocks=blocks threads=threads scatter_kernel!(op, dst, src, idx, dims, max_idx, max_dims_idx, dims_size)
41+
return dst
42+
end
3743
end
3844
end

0 commit comments

Comments
 (0)