12
12
typelength (:: Type{<:Number} ) = 1
13
13
typelength (:: Type{<:NTuple{M}} ) where M = M
14
14
15
- function _check_dims (Ndst, Nsrc, N, Nidx)
16
- @assert Ndst - N == Nsrc - Nidx " Incompatible input shapes of (dst, src, idx) = ($Ndst , $Nsrc , $Nidx )."
17
- dims = Ndst - N
18
- if dims < 0
19
- throw (ArgumentError (" dims must be non-negative but got dims=$dims ." ))
20
- end
15
+ function _check_dims (X:: AbstractArray{Tx,Nx} ,
16
+ Y:: AbstractArray{Ty,Ny} ,
17
+ idx:: AbstractArray{Tidx,Nidx} ) where
18
+ {Tx,Ty,Tidx<: IntOrIntTuple ,Nx,Ny,Nidx}
19
+ M = typelength (Tidx)
20
+ @assert Nx - M == Ny - Nidx " Incompatible input shapes of (X, Y, idx) = ($Nx , $Ny , $Nidx )."
21
+ dims = Nx - M
22
+ dims < 0 && throw (ArgumentError (" dims must be non-negative but got dims=$dims ." ))
23
+ size (X)[1 : dims] == size (Y)[1 : dims] || throw (ArgumentError (" Incompatible input shapes." ))
24
+ size (Y)[dims+ 1 : end ] == size (idx) || throw (ArgumentError (" Incompatible input shapes." ))
25
+ return dims
26
+ end
27
+
28
+ function _check_dims (X:: AbstractArray{Tx,Nx} ,
29
+ Y:: AbstractArray{Ty,Ny} ,
30
+ idx:: AbstractArray{CartesianIndex{M},Nidx} ) where {Tx,Ty,Nx,Ny,M,Nidx}
31
+ @assert Nx - M == Ny - Nidx " Incompatible input shapes of (dst, src, idx) = ($Nx , $Ny , $Nidx )."
32
+ dims = Nx - M
33
+ dims < 0 && throw (ArgumentError (" dims must be non-negative but got dims=$dims ." ))
34
+ size (X)[1 : dims] == size (Y)[1 : dims] || throw (ArgumentError (" Incompatible input shapes." ))
35
+ size (Y)[dims+ 1 : end ] == size (idx) || throw (ArgumentError (" Incompatible input shapes." ))
21
36
return dims
22
37
end
23
38
@@ -46,26 +61,8 @@ index of `dst` and the value of `idx` must indicate the last few dimensions of `
46
61
Once the dimensions match, arrays are aligned automatically. The value of `idx` can be
47
62
`Int` or `Tuple` type.
48
63
"""
49
- function scatter! (op,
50
- dst:: AbstractArray{Tdst,Ndst} ,
51
- src:: AbstractArray{Tsrc,Nsrc} ,
52
- idx:: AbstractArray{Tidx,Nidx} ) where {Tdst,Tsrc,Tidx<: IntOrIntTuple ,Ndst,Nsrc,Nidx}
53
- M = typelength (Tidx)
54
- dims = _check_dims (Ndst, Nsrc, M, Nidx)
55
- colons = Base. ntuple (_-> Colon (), dims)
56
- for k in CartesianIndices (idx)
57
- dst_v = _view (dst, colons, idx, k)
58
- src_v = _view (src, colons, k)
59
- dst_v .= (op). (dst_v, src_v)
60
- end
61
- dst
62
- end
63
-
64
- function scatter! (op,
65
- dst:: AbstractArray{Tdst,Ndst} ,
66
- src:: AbstractArray{Tsrc,Nsrc} ,
67
- idx:: AbstractArray{CartesianIndex{M},Nidx} ) where {Tdst,Ndst,Tsrc,Nsrc,M,Nidx}
68
- dims = _check_dims (Ndst, Nsrc, M, Nidx)
64
+ function scatter! (op, dst:: AbstractArray , src:: AbstractArray , idx:: AbstractArray )
65
+ dims = _check_dims (dst, src, idx)
69
66
colons = Base. ntuple (_-> Colon (), dims)
70
67
for k in CartesianIndices (idx)
71
68
dst_v = _view (dst, colons, idx, k)
@@ -75,10 +72,7 @@ function scatter!(op,
75
72
dst
76
73
end
77
74
78
- function scatter! (op:: typeof (mean),
79
- dst:: AbstractArray{Tdst,Ndst} ,
80
- src:: AbstractArray{Tsrc,Nsrc} ,
81
- idx:: AbstractArray{<:IntOrIntTuple,Nidx} ) where {Tdst,Tsrc,Ndst,Nsrc,Nidx}
75
+ function scatter! (op:: typeof (mean), dst:: AbstractArray , src:: AbstractArray , idx:: AbstractArray )
82
76
Ns = scatter! (+ , zero (dst), one .(src), idx)
83
77
dst_ = scatter! (+ , zero (dst), src, idx)
84
78
dst .+ = safe_div .(dst_, Ns)
0 commit comments