diff --git a/src/NNlib.jl b/src/NNlib.jl index 759622d7c..647eb400a 100644 --- a/src/NNlib.jl +++ b/src/NNlib.jl @@ -7,7 +7,7 @@ import ChainRulesCore: rrule using Base.Broadcast: broadcasted using Statistics: mean -const IntOrTuple = Union{Integer,Tuple} +const IntOrIntTuple = Union{Integer, NTuple{N,<:Integer} where N} const Numeric = Union{AbstractArray{<:T}, T} where {T<:Number} # Include APIs @@ -35,8 +35,9 @@ include("conv_bias_act.jl") include("pooling.jl") include("padding.jl") include("upsample.jl") -include("utils.jl") +include("gather.jl") include("scatter.jl") +include("utils.jl") ## Include implementations include("impl/padding_edges.jl") diff --git a/src/gather.jl b/src/gather.jl new file mode 100644 index 000000000..46984f180 --- /dev/null +++ b/src/gather.jl @@ -0,0 +1,84 @@ +export gather, gather! + +""" + gather!(dst, src, idx) + +Reverse operation of [`scatter!`](@ref). Gathers data from source `src` +and writes it in destination `dst` according to the index array `idx`. +For each `k` in `CartesianIndices(idx)`, assign values to `dst` according to + + dst[:, ... , k] .= src[:, ... , idx[k]...] + +Notice that if `idx` is a vector containing integers, +and both `dst` and `src` are matrices, previous expression simplifies to + + dst[:, k] .= src[:, idx[k]] + +and `k` will run over `1:length(idx)`. + +The elements of `idx` can be integers or integer tuples and may be repeated. +A single `src` column can end up being copied into zero, one, +or multiple `dst` columns. + +See [`gather`](@ref) for an allocating version. +""" +function gather!(dst::AbstractArray{Tdst,Ndst}, + src::AbstractArray{Tsrc,Nsrc}, + idx::AbstractArray{Tidx, Nidx}) where + {Tdst, Tsrc, Ndst, Nsrc, Nidx, Tidx <: IntOrIntTuple} + + M = typelength(Tidx) + d = Ndst - Nidx + d == Nsrc - M || throw(ArgumentError("Incompatible input shapes.")) + size(dst)[1:d] == size(src)[1:d] || throw(ArgumentError("Incompatible input shapes.")) + size(dst)[d+1:end] == size(idx) || throw(ArgumentError("Incompatible input shapes.")) + + colons = ntuple(i -> Colon(), d) + for k in CartesianIndices(idx) + view(dst, colons..., k) .= view(src, colons..., idx[k]...) + end + return dst +end + +""" + gather(src, idx) -> dst + +Reverse operation of [`scatter`](@ref). Gathers data from source `src` +and writes it in a destination `dst` according to the index +array `idx`. +For each `k` in `CartesianIndices(idx)`, assign values to `dst` +according to + + dst[:, ... , k] .= src[:, ... , idx[k]...] + +Notice that if `idx` is a vector containing integers +and `src` is a matrix, previous expression simplifies to + + dst[:, k] .= src[:, idx[k]] + +and `k` will run over `1:length(idx)`. + +The elements of `idx` can be integers or integer tuples and may be repeated. +A single `src` column can end up being copied into zero, one, +or multiple `dst` columns. + +See [`gather!`](@ref) for an in-place version. +""" +function gather(src::AbstractArray{Tsrc, Nsrc}, + idx::AbstractArray{Tidx, Nidx}) where + {Tsrc, Nsrc, Nidx, Tidx<:IntOrIntTuple} + + M = typelength(Tidx) + dstsize = (size(src)[1:Nsrc-M]..., size(idx)...) + dst = similar(src, Tsrc, dstsize) + return gather!(dst, src, idx) +end + +# Simple implementation with getindex for integer array. +# Perf equivalent to the one above (which can also handle the integer case) +# leave it here to show the simple connection with getindex. +function gather(src::AbstractArray{Tsrc, Nsrc}, + idx::AbstractArray{<:Integer}) where {Tsrc, Nsrc} + colons = ntuple(i -> Colon(), Nsrc-1) + return src[colons..., idx] +end diff --git a/src/scatter.jl b/src/scatter.jl index 3550574b3..e0d2a4eb9 100644 --- a/src/scatter.jl +++ b/src/scatter.jl @@ -47,13 +47,13 @@ Once the dimensions match, arrays are aligned automatically. The value of `idx` function scatter!(op, dst::AbstractArray{Tdst,Ndst}, src::AbstractArray{Tsrc,Nsrc}, - idx::AbstractArray{Tidx,Nidx}) where {Tdst,Tsrc,Tidx<:IntOrTuple,Ndst,Nsrc,Nidx} + idx::AbstractArray{Tidx,Nidx}) where {Tdst,Tsrc,Tidx<:IntOrIntTuple,Ndst,Nsrc,Nidx} M = typelength(Tidx) dims = _check_dims(Ndst, Nsrc, M, Nidx) scatter!(op, dst, src, idx, Val(dims)) end -function scatter!(op, dst::AbstractArray{Tdst}, src::AbstractArray{Tsrc}, idx::AbstractArray{<:IntOrTuple}, +function scatter!(op, dst::AbstractArray{Tdst}, src::AbstractArray{Tsrc}, idx::AbstractArray{<:IntOrIntTuple}, dims::Val{N}) where {Tdst,Tsrc,N} colons = Base.ntuple(_->Colon(), dims) for k in CartesianIndices(idx) @@ -67,7 +67,7 @@ end function scatter!(op::typeof(mean), dst::AbstractArray{Tdst,Ndst}, src::AbstractArray{Tsrc,Nsrc}, - idx::AbstractArray{<:IntOrTuple,Nidx}) where {Tdst,Tsrc,Ndst,Nsrc,Nidx} + idx::AbstractArray{<:IntOrIntTuple,Nidx}) where {Tdst,Tsrc,Ndst,Nsrc,Nidx} Ns = scatter!(+, zero(dst), one.(src), idx) dst_ = scatter!(+, zero(dst), src, idx) dst .+= safe_div.(dst_, Ns) @@ -96,7 +96,7 @@ function scatter end for op in [+, -] @eval function scatter(op::typeof($op), src::AbstractArray{T,Nsrc}, - idx::AbstractArray{<:IntOrTuple,Nidx}) where {T,Nsrc,Nidx} + idx::AbstractArray{<:IntOrIntTuple,Nidx}) where {T,Nsrc,Nidx} dims = Nsrc - Nidx dstsize = (size(src)[1:dims]..., maximum_dims(idx)...) dst = similar(src, T, dstsize) @@ -108,7 +108,7 @@ end for op in [*, /] @eval function scatter(op::typeof($op), src::AbstractArray{T,Nsrc}, - idx::AbstractArray{<:IntOrTuple,Nidx}) where {T,Nsrc,Nidx} + idx::AbstractArray{<:IntOrIntTuple,Nidx}) where {T,Nsrc,Nidx} dims = Nsrc - Nidx dstsize = (size(src)[1:dims]..., maximum_dims(idx)...) dst = similar(src, T, dstsize) @@ -119,7 +119,7 @@ end function scatter(op::typeof(max), src::AbstractArray{T,Nsrc}, - idx::AbstractArray{<:IntOrTuple,Nidx}) where {T,Nsrc,Nidx} + idx::AbstractArray{<:IntOrIntTuple,Nidx}) where {T,Nsrc,Nidx} dims = Nsrc - Nidx dstsize = (size(src)[1:dims]..., maximum_dims(idx)...) dst = similar(src, T, dstsize) @@ -129,7 +129,7 @@ end function scatter(op::typeof(min), src::AbstractArray{T,Nsrc}, - idx::AbstractArray{<:IntOrTuple,Nidx}) where {T,Nsrc,Nidx} + idx::AbstractArray{<:IntOrIntTuple,Nidx}) where {T,Nsrc,Nidx} dims = Nsrc - Nidx dstsize = (size(src)[1:dims]..., maximum_dims(idx)...) dst = similar(src, T, dstsize) @@ -139,7 +139,7 @@ end function scatter(op::typeof(mean), src::AbstractArray{T,Nsrc}, - idx::AbstractArray{<:IntOrTuple,Nidx}) where {T,Nsrc,Nidx} + idx::AbstractArray{<:IntOrIntTuple,Nidx}) where {T,Nsrc,Nidx} FT = float(T) dims = Nsrc - Nidx dstsize = (size(src)[1:dims]..., maximum_dims(idx)...) diff --git a/test/gather.jl b/test/gather.jl new file mode 100644 index 000000000..81d5cdeed --- /dev/null +++ b/test/gather.jl @@ -0,0 +1,94 @@ +@testset "gather scalar index" begin + T = Float32 + + ## 1d src, 2d index of ints -> 2d output + src = T[3, 4, 5, 6, 7] + index = [1 2 3 4; + 4 2 1 3; + 3 5 5 3] + output = T[3 4 5 6; + 6 4 3 5; + 5 7 7 5] + + y = gather(src, index) + @test y isa Array{T,2} + @test size(y) == size(index) + @test y == output + @test gather!(T.(zero(index)), src, index) == output + @test_throws ArgumentError gather!(zeros(T, 3, 5), src, index) + + index2 = [1 2 3 4; + 4 2 1 3; + 3 6 5 3] + @test_throws BoundsError gather!(T.(zero(index)), src, index2) + + ## 1d src, 3d index of ints -> 3d output + src = T[3, 4, 5, 6, 7] + index = [1 2 3 4; + 4 2 1 3; + 3 5 5 3][:,:,1:1] + output = T[3 4 5 6; + 6 4 3 5; + 5 7 7 5][:,:,1:1] + + y = gather(src, index) + @test y isa Array{T,3} + @test size(y) == size(index) + @test y == output + + + ## 2d src, 2d index of ints -> 3d output + src = T[3 5 7 + 4 6 8] + index = [1 2 3; + 2 2 1; + 3 1 3] + + output = zeros(T, 2, 3, 3) + + output[:,:,1] = [3 5 7 + 4 6 8] + + output[:,:,2] = [5 5 3 + 6 6 4] + + output[:,:,3] = [7 3 7 + 8 4 8] + + y = gather(src, index) + M = NNlib.typelength(eltype(index)) + Nsrc = ndims(src) + @test y isa Array{T,3} + @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) + @test y == output +end + +@testset "gather tuple index" begin + T = Float32 + + ## 2d src, 1d index of 2-tuples -> 1d output + src = T[3 5 7 + 4 6 8] + + index = [(1,1), (1,2), (1,3), (2,1), (2,2), (2,3)] + + output = T[3, 5, 7, 4, 6, 8] + + y = gather(src, index) + M = NNlib.typelength(eltype(index)) + Nsrc = ndims(src) + @test y isa Array{T,1} + @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) + @test y == output + + ## 3d src, 2d index of 2-tuples -> 3d output + n1, nsrc, nidx = 2, 3, 6 + src = rand(Float32, n1, nsrc, nsrc) + index = [(rand(1:nsrc), rand(1:nsrc)) for i=1:nidx, j=1:nidx] + + y = gather(src, index) + M = NNlib.typelength(eltype(index)) + Nsrc = ndims(src) + @test y isa Array{T,3} + @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) +end diff --git a/test/runtests.jl b/test/runtests.jl index 5479581bc..077810caf 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -46,6 +46,10 @@ end include("upsample.jl") end +@testset "Gather" begin + include("gather.jl") +end + @testset "Scatter" begin include("scatter.jl") end @@ -53,3 +57,4 @@ end @testset "Utilities" begin include("utils.jl") end +