Skip to content

Commit 90804b0

Browse files
committedApr 15, 2021
refactor APIs and _check_dims
1 parent af42a90 commit 90804b0

File tree

2 files changed

+27
-58
lines changed

2 files changed

+27
-58
lines changed
 

‎src/gather.jl

+3-28
Original file line numberDiff line numberDiff line change
@@ -20,34 +20,9 @@ or multiple `dst` columns.
2020
2121
See [`gather`](@ref) for an allocating version.
2222
"""
23-
function gather!(dst::AbstractArray{Tdst,Ndst},
24-
src::AbstractArray{Tsrc,Nsrc},
25-
idx::AbstractArray{Tidx,Nidx}) where {Tdst, Tsrc, Tidx<:IntOrIntTuple, Ndst, Nsrc, Nidx}
26-
27-
M = typelength(Tidx)
28-
d = Ndst - Nidx
29-
d == Nsrc - M || throw(ArgumentError("Incompatible input shapes."))
30-
size(dst)[1:d] == size(src)[1:d] || throw(ArgumentError("Incompatible input shapes."))
31-
size(dst)[d+1:end] == size(idx) || throw(ArgumentError("Incompatible input shapes."))
32-
33-
colons = ntuple(i -> Colon(), d)
34-
for k in CartesianIndices(idx)
35-
_view(dst, colons, k) .= _view(src, colons, idx, k)
36-
end
37-
return dst
38-
end
39-
40-
function gather!(dst::AbstractArray{Tdst,Ndst},
41-
src::AbstractArray{Tsrc,Nsrc},
42-
idx::AbstractArray{CartesianIndex{M},Nidx}) where
43-
{Tdst, Tsrc, Ndst, Nsrc, M, Nidx}
44-
45-
d = Ndst - Nidx
46-
d == Nsrc - M || throw(ArgumentError("Incompatible input shapes."))
47-
size(dst)[1:d] == size(src)[1:d] || throw(ArgumentError("Incompatible input shapes."))
48-
size(dst)[d+1:end] == size(idx) || throw(ArgumentError("Incompatible input shapes."))
49-
50-
colons = ntuple(i -> Colon(), d)
23+
function gather!(dst::AbstractArray, src::AbstractArray, idx::AbstractArray)
24+
dims = _check_dims(src, dst, idx)
25+
colons = ntuple(i -> Colon(), dims)
5126
for k in CartesianIndices(idx)
5227
_view(dst, colons, k) .= _view(src, colons, idx, k)
5328
end

‎src/scatter.jl

+24-30
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,27 @@
1212
typelength(::Type{<:Number}) = 1
1313
typelength(::Type{<:NTuple{M}}) where M = M
1414

15-
function _check_dims(Ndst, Nsrc, N, Nidx)
16-
@assert Ndst - N == Nsrc - Nidx "Incompatible input shapes of (dst, src, idx) = ($Ndst, $Nsrc, $Nidx)."
17-
dims = Ndst - N
18-
if dims < 0
19-
throw(ArgumentError("dims must be non-negative but got dims=$dims."))
20-
end
15+
function _check_dims(X::AbstractArray{Tx,Nx},
16+
Y::AbstractArray{Ty,Ny},
17+
idx::AbstractArray{Tidx,Nidx}) where
18+
{Tx,Ty,Tidx<:IntOrIntTuple,Nx,Ny,Nidx}
19+
M = typelength(Tidx)
20+
@assert Nx - M == Ny - Nidx "Incompatible input shapes of (X, Y, idx) = ($Nx, $Ny, $Nidx)."
21+
dims = Nx - M
22+
dims < 0 && throw(ArgumentError("dims must be non-negative but got dims=$dims."))
23+
size(X)[1:dims] == size(Y)[1:dims] || throw(ArgumentError("Incompatible input shapes."))
24+
size(Y)[dims+1:end] == size(idx) || throw(ArgumentError("Incompatible input shapes."))
25+
return dims
26+
end
27+
28+
function _check_dims(X::AbstractArray{Tx,Nx},
29+
Y::AbstractArray{Ty,Ny},
30+
idx::AbstractArray{CartesianIndex{M},Nidx}) where {Tx,Ty,Nx,Ny,M,Nidx}
31+
@assert Nx - M == Ny - Nidx "Incompatible input shapes of (dst, src, idx) = ($Nx, $Ny, $Nidx)."
32+
dims = Nx - M
33+
dims < 0 && throw(ArgumentError("dims must be non-negative but got dims=$dims."))
34+
size(X)[1:dims] == size(Y)[1:dims] || throw(ArgumentError("Incompatible input shapes."))
35+
size(Y)[dims+1:end] == size(idx) || throw(ArgumentError("Incompatible input shapes."))
2136
return dims
2237
end
2338

@@ -46,26 +61,8 @@ index of `dst` and the value of `idx` must indicate the last few dimensions of `
4661
Once the dimensions match, arrays are aligned automatically. The value of `idx` can be
4762
`Int` or `Tuple` type.
4863
"""
49-
function scatter!(op,
50-
dst::AbstractArray{Tdst,Ndst},
51-
src::AbstractArray{Tsrc,Nsrc},
52-
idx::AbstractArray{Tidx,Nidx}) where {Tdst,Tsrc,Tidx<:IntOrIntTuple,Ndst,Nsrc,Nidx}
53-
M = typelength(Tidx)
54-
dims = _check_dims(Ndst, Nsrc, M, Nidx)
55-
colons = Base.ntuple(_->Colon(), dims)
56-
for k in CartesianIndices(idx)
57-
dst_v = _view(dst, colons, idx, k)
58-
src_v = _view(src, colons, k)
59-
dst_v .= (op).(dst_v, src_v)
60-
end
61-
dst
62-
end
63-
64-
function scatter!(op,
65-
dst::AbstractArray{Tdst,Ndst},
66-
src::AbstractArray{Tsrc,Nsrc},
67-
idx::AbstractArray{CartesianIndex{M},Nidx}) where {Tdst,Ndst,Tsrc,Nsrc,M,Nidx}
68-
dims = _check_dims(Ndst, Nsrc, M, Nidx)
64+
function scatter!(op, dst::AbstractArray, src::AbstractArray, idx::AbstractArray)
65+
dims = _check_dims(dst, src, idx)
6966
colons = Base.ntuple(_->Colon(), dims)
7067
for k in CartesianIndices(idx)
7168
dst_v = _view(dst, colons, idx, k)
@@ -75,10 +72,7 @@ function scatter!(op,
7572
dst
7673
end
7774

78-
function scatter!(op::typeof(mean),
79-
dst::AbstractArray{Tdst,Ndst},
80-
src::AbstractArray{Tsrc,Nsrc},
81-
idx::AbstractArray{<:IntOrIntTuple,Nidx}) where {Tdst,Tsrc,Ndst,Nsrc,Nidx}
75+
function scatter!(op::typeof(mean), dst::AbstractArray, src::AbstractArray, idx::AbstractArray)
8276
Ns = scatter!(+, zero(dst), one.(src), idx)
8377
dst_ = scatter!(+, zero(dst), src, idx)
8478
dst .+= safe_div.(dst_, Ns)

0 commit comments

Comments
 (0)