Skip to content

Commit d36c18e

Browse files
committed
try to unify the scatter API
1 parent ce63fb3 commit d36c18e

File tree

2 files changed

+2
-1
lines changed

2 files changed

+2
-1
lines changed

lib/NNlibCUDA/src/scatter.jl

+1-1
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
ATM_OPS = Dict((+) => CUDA.atomic_add!, (-) => CUDA.atomic_sub!, (max) => CUDA.atomic_max!, (min) => CUDA.atomic_min!,
22
(*) => CUDA.atomic_mul!, (/) => CUDA.atomic_div!, (&) => CUDA.atomic_and!, (|) => CUDA.atomic_or!)
33

4-
function scatter!(op, dst::CuArray, src::CuArray, idx::CuArray{IntOrIntTuple})
4+
function scatter!(op, dst::CuArray{Tdst}, src::CuArray{Tsrc}, idx::CuArray{<:IntOrIntTuple}, dims::Val{N}) where {Tdst,Tsrc,N}
55
function kernel!(atm_op, dst, src, idx)
66
li = threadIdx().y + (blockIdx().y - 1) * blockDim().y
77
i = threadIdx().x + (blockIdx().x - 1) * blockDim().x

lib/NNlibCUDA/test/runtests.jl

+1
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@ using NNlib
33
using Zygote
44
using NNlibCUDA
55
using ForwardDiff: Dual
6+
using Statistics: mean
67
using CUDA
78
CUDA.allowscalar(false)
89

0 commit comments

Comments (0)