From e14d0acf2219afc4b8b309a47e8fc70915153538 Mon Sep 17 00:00:00 2001 From: Yueh-Hua Tu Date: Fri, 25 Dec 2020 16:57:50 +0800 Subject: [PATCH 1/3] add gather Co-authored-by: Carlo Lucibello --- src/NNlib.jl | 5 +-- src/gather.jl | 64 +++++++++++++++++++++++++++++++++ src/scatter.jl | 16 ++++----- test/gather.jl | 94 ++++++++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 5 +++ 5 files changed, 174 insertions(+), 10 deletions(-) create mode 100644 src/gather.jl create mode 100644 test/gather.jl diff --git a/src/NNlib.jl b/src/NNlib.jl index 759622d7c..24658016d 100644 --- a/src/NNlib.jl +++ b/src/NNlib.jl @@ -7,7 +7,7 @@ import ChainRulesCore: rrule using Base.Broadcast: broadcasted using Statistics: mean -const IntOrTuple = Union{Integer,Tuple} +const IntOrIntTuple = Union{Integer, NTuple{N,Integer} where N} const Numeric = Union{AbstractArray{<:T}, T} where {T<:Number} # Include APIs @@ -35,8 +35,9 @@ include("conv_bias_act.jl") include("pooling.jl") include("padding.jl") include("upsample.jl") -include("utils.jl") +include("gather.jl") include("scatter.jl") +include("utils.jl") ## Include implementations include("impl/padding_edges.jl") diff --git a/src/gather.jl b/src/gather.jl new file mode 100644 index 000000000..2229b455f --- /dev/null +++ b/src/gather.jl @@ -0,0 +1,64 @@ +export gather, gather! + +""" + gather!(dst, src, idx) + +Reverse operation of scatter. Gathers data from source `src` +and writes it in destination `dst` according to the index +array `idx`. +For each `k` in `CartesianIndices(idx)`, assign values to `dst` according to + + dst[:, ... , k] .= src[:, ... , idx[k]...] + +Notice that if `idx` is a vector containing integers, +and both `dst` and `src` are matrices, previous +expression simplifies to + + dst[:, k] .= src[:, idx[k]] + +and `k` will range over `1:length(idx)`. + +The elements of `idx` may be repeated. A single `src` column +can end up being copied into zero, one, or multiple `dst` columns. 
+ +# Arguments + +- `dst`: the destination where data would be assigned to. +- `src`: the source of the data to be assigned. +- `idx`: the mapping from source to destination. +""" +function gather!(dst::AbstractArray{Tdst,Ndst}, + src::AbstractArray{Tsrc,Nsrc}, + idx::AbstractArray{Tidx, Nidx}) where + {Tdst, Tsrc, Ndst, Nsrc, Nidx, Tidx <: IntOrIntTuple} + + M = typelength(Tidx) + Ndst - Nidx == Nsrc - M || throw(ArgumentError("Incompatible input shapes.")) + size(dst)[1:Ndst-Nidx] == size(src)[1:Ndst-Nidx] || throw(ArgumentError("Incompatible input shapes.")) + size(dst)[Ndst-Nidx+1:end] == size(idx) || throw(ArgumentError("Incompatible input shapes.")) + + coldst = ntuple(i -> Colon(), Ndst - Nidx) + colsrc = ntuple(i -> Colon(), Nsrc - M) + for k in CartesianIndices(idx) + view(dst, coldst..., k) .= view(src, colsrc..., idx[k]...) + end + return dst +end + +""" + gather(src, idx) + +Non-mutating version of [`gather!`](@ref). +""" +function gather(src::AbstractArray{Tsrc, Nsrc}, + idx::AbstractArray{Tidx, Nidx}) where + {Tsrc, Nsrc, Nidx, Tidx<:IntOrIntTuple} + + M = typelength(Tidx) + dstsize = (size(src)[1:Nsrc-M]..., size(idx)...) + dst = similar(src, Tsrc, dstsize) + return gather!(dst, src, idx) +end + +typelength(::Type{<:Number}) = 1 +typelength(::Type{<:NTuple{M}}) where M = M diff --git a/src/scatter.jl b/src/scatter.jl index 3550574b3..e0d2a4eb9 100644 --- a/src/scatter.jl +++ b/src/scatter.jl @@ -47,13 +47,13 @@ Once the dimensions match, arrays are aligned automatically. 
The value of `idx` function scatter!(op, dst::AbstractArray{Tdst,Ndst}, src::AbstractArray{Tsrc,Nsrc}, - idx::AbstractArray{Tidx,Nidx}) where {Tdst,Tsrc,Tidx<:IntOrTuple,Ndst,Nsrc,Nidx} + idx::AbstractArray{Tidx,Nidx}) where {Tdst,Tsrc,Tidx<:IntOrIntTuple,Ndst,Nsrc,Nidx} M = typelength(Tidx) dims = _check_dims(Ndst, Nsrc, M, Nidx) scatter!(op, dst, src, idx, Val(dims)) end -function scatter!(op, dst::AbstractArray{Tdst}, src::AbstractArray{Tsrc}, idx::AbstractArray{<:IntOrTuple}, +function scatter!(op, dst::AbstractArray{Tdst}, src::AbstractArray{Tsrc}, idx::AbstractArray{<:IntOrIntTuple}, dims::Val{N}) where {Tdst,Tsrc,N} colons = Base.ntuple(_->Colon(), dims) for k in CartesianIndices(idx) @@ -67,7 +67,7 @@ end function scatter!(op::typeof(mean), dst::AbstractArray{Tdst,Ndst}, src::AbstractArray{Tsrc,Nsrc}, - idx::AbstractArray{<:IntOrTuple,Nidx}) where {Tdst,Tsrc,Ndst,Nsrc,Nidx} + idx::AbstractArray{<:IntOrIntTuple,Nidx}) where {Tdst,Tsrc,Ndst,Nsrc,Nidx} Ns = scatter!(+, zero(dst), one.(src), idx) dst_ = scatter!(+, zero(dst), src, idx) dst .+= safe_div.(dst_, Ns) @@ -96,7 +96,7 @@ function scatter end for op in [+, -] @eval function scatter(op::typeof($op), src::AbstractArray{T,Nsrc}, - idx::AbstractArray{<:IntOrTuple,Nidx}) where {T,Nsrc,Nidx} + idx::AbstractArray{<:IntOrIntTuple,Nidx}) where {T,Nsrc,Nidx} dims = Nsrc - Nidx dstsize = (size(src)[1:dims]..., maximum_dims(idx)...) dst = similar(src, T, dstsize) @@ -108,7 +108,7 @@ end for op in [*, /] @eval function scatter(op::typeof($op), src::AbstractArray{T,Nsrc}, - idx::AbstractArray{<:IntOrTuple,Nidx}) where {T,Nsrc,Nidx} + idx::AbstractArray{<:IntOrIntTuple,Nidx}) where {T,Nsrc,Nidx} dims = Nsrc - Nidx dstsize = (size(src)[1:dims]..., maximum_dims(idx)...) 
dst = similar(src, T, dstsize) @@ -119,7 +119,7 @@ end function scatter(op::typeof(max), src::AbstractArray{T,Nsrc}, - idx::AbstractArray{<:IntOrTuple,Nidx}) where {T,Nsrc,Nidx} + idx::AbstractArray{<:IntOrIntTuple,Nidx}) where {T,Nsrc,Nidx} dims = Nsrc - Nidx dstsize = (size(src)[1:dims]..., maximum_dims(idx)...) dst = similar(src, T, dstsize) @@ -129,7 +129,7 @@ end function scatter(op::typeof(min), src::AbstractArray{T,Nsrc}, - idx::AbstractArray{<:IntOrTuple,Nidx}) where {T,Nsrc,Nidx} + idx::AbstractArray{<:IntOrIntTuple,Nidx}) where {T,Nsrc,Nidx} dims = Nsrc - Nidx dstsize = (size(src)[1:dims]..., maximum_dims(idx)...) dst = similar(src, T, dstsize) @@ -139,7 +139,7 @@ end function scatter(op::typeof(mean), src::AbstractArray{T,Nsrc}, - idx::AbstractArray{<:IntOrTuple,Nidx}) where {T,Nsrc,Nidx} + idx::AbstractArray{<:IntOrIntTuple,Nidx}) where {T,Nsrc,Nidx} FT = float(T) dims = Nsrc - Nidx dstsize = (size(src)[1:dims]..., maximum_dims(idx)...) diff --git a/test/gather.jl b/test/gather.jl new file mode 100644 index 000000000..81d5cdeed --- /dev/null +++ b/test/gather.jl @@ -0,0 +1,94 @@ +@testset "gather scalar index" begin + T = Float32 + + ## 1d src, 2d index of ints -> 2d output + src = T[3, 4, 5, 6, 7] + index = [1 2 3 4; + 4 2 1 3; + 3 5 5 3] + output = T[3 4 5 6; + 6 4 3 5; + 5 7 7 5] + + y = gather(src, index) + @test y isa Array{T,2} + @test size(y) == size(index) + @test y == output + @test gather!(T.(zero(index)), src, index) == output + @test_throws ArgumentError gather!(zeros(T, 3, 5), src, index) + + index2 = [1 2 3 4; + 4 2 1 3; + 3 6 5 3] + @test_throws BoundsError gather!(T.(zero(index)), src, index2) + + ## 1d src, 3d index of ints -> 3d output + src = T[3, 4, 5, 6, 7] + index = [1 2 3 4; + 4 2 1 3; + 3 5 5 3][:,:,1:1] + output = T[3 4 5 6; + 6 4 3 5; + 5 7 7 5][:,:,1:1] + + y = gather(src, index) + @test y isa Array{T,3} + @test size(y) == size(index) + @test y == output + + + ## 2d src, 2d index of ints -> 3d output + src = T[3 5 7 + 4 6 8] + 
index = [1 2 3; + 2 2 1; + 3 1 3] + + output = zeros(T, 2, 3, 3) + + output[:,:,1] = [3 5 7 + 4 6 8] + + output[:,:,2] = [5 5 3 + 6 6 4] + + output[:,:,3] = [7 3 7 + 8 4 8] + + y = gather(src, index) + M = NNlib.typelength(eltype(index)) + Nsrc = ndims(src) + @test y isa Array{T,3} + @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) + @test y == output +end + +@testset "gather tuple index" begin + T = Float32 + + ## 2d src, 1d index of 2-tuples -> 1d output + src = T[3 5 7 + 4 6 8] + + index = [(1,1), (1,2), (1,3), (2,1), (2,2), (2,3)] + + output = T[3, 5, 7, 4, 6, 8] + + y = gather(src, index) + M = NNlib.typelength(eltype(index)) + Nsrc = ndims(src) + @test y isa Array{T,1} + @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) + @test y == output + + ## 3d src, 2d index of 2-tuples -> 3d output + n1, nsrc, nidx = 2, 3, 6 + src = rand(Float32, n1, nsrc, nsrc) + index = [(rand(1:nsrc), rand(1:nsrc)) for i=1:nidx, j=1:nidx] + + y = gather(src, index) + M = NNlib.typelength(eltype(index)) + Nsrc = ndims(src) + @test y isa Array{T,3} + @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) 
+end diff --git a/test/runtests.jl b/test/runtests.jl index 5479581bc..077810caf 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -46,6 +46,10 @@ end include("upsample.jl") end +@testset "Gather" begin + include("gather.jl") +end + @testset "Scatter" begin include("scatter.jl") end @@ -53,3 +57,4 @@ end @testset "Utilities" begin include("utils.jl") end + From 920d7d1fa395a2d14a70909ccec81aebe4622a80 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 5 Mar 2021 08:10:21 +0100 Subject: [PATCH 2/3] more docs and cleanup --- Project.toml | 1 + src/gather.jl | 59 ++++++++++++++++++++++++++++++--------------------- 2 files changed, 36 insertions(+), 24 deletions(-) diff --git a/Project.toml b/Project.toml index 5b0923ad7..2985a7e02 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,7 @@ uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" version = "0.7.14" [deps] +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/gather.jl b/src/gather.jl index 2229b455f..906ce15d6 100644 --- a/src/gather.jl +++ b/src/gather.jl @@ -3,29 +3,24 @@ export gather, gather! """ gather!(dst, src, idx) -Reverse operation of scatter. Gathers data from source `src` -and writes it in destination `dst` according to the index -array `idx`. +Reverse operation of [`scatter!`](@ref). Gathers data from source `src` +and writes it in destination `dst` according to the index array `idx`. For each `k` in `CartesianIndices(idx)`, assign values to `dst` according to dst[:, ... , k] .= src[:, ... , idx[k]...] Notice that if `idx` is a vector containing integers, -and both `dst` and `src` are matrices, previous -expression simplifies to +and both `dst` and `src` are matrices, previous expression simplifies to dst[:, k] .= src[:, idx[k]] -and `k` will range over `1:length(idx)`. - -The elements of `idx` may be repeated. 
A single `src` column -can end up being copied into zero, one, or multiple `dst` columns. +and `k` will run over `1:length(idx)`. -# Arguments +The elements of `idx` can be integers or integer tuples and may be repeated. +A single `src` column can end up being copied into zero, one, +or multiple `dst` columns. -- `dst`: the destination where data would be assigned to. -- `src`: the source of the data to be assigned. -- `idx`: the mapping from source to destination. +See [`gather`](@ref) for an allocating version. """ function gather!(dst::AbstractArray{Tdst,Ndst}, src::AbstractArray{Tsrc,Nsrc}, @@ -33,22 +28,41 @@ function gather!(dst::AbstractArray{Tdst,Ndst}, {Tdst, Tsrc, Ndst, Nsrc, Nidx, Tidx <: IntOrIntTuple} M = typelength(Tidx) - Ndst - Nidx == Nsrc - M || throw(ArgumentError("Incompatible input shapes.")) - size(dst)[1:Ndst-Nidx] == size(src)[1:Ndst-Nidx] || throw(ArgumentError("Incompatible input shapes.")) - size(dst)[Ndst-Nidx+1:end] == size(idx) || throw(ArgumentError("Incompatible input shapes.")) + d = Ndst - Nidx + d == Nsrc - M || throw(ArgumentError("Incompatible input shapes.")) + size(dst)[1:d] == size(src)[1:d] || throw(ArgumentError("Incompatible input shapes.")) + size(dst)[d+1:end] == size(idx) || throw(ArgumentError("Incompatible input shapes.")) - coldst = ntuple(i -> Colon(), Ndst - Nidx) - colsrc = ntuple(i -> Colon(), Nsrc - M) + colons = ntuple(i -> Colon(), d) for k in CartesianIndices(idx) - view(dst, coldst..., k) .= view(src, colsrc..., idx[k]...) + view(dst, colons..., k) .= view(src, colons..., idx[k]...) end return dst end """ - gather(src, idx) + gather(src, idx) -> dst + +Reverse operation of [`scatter`](@ref). Gathers data from source `src` +and writes it in a destination `dst` according to the index +array `idx`. +For each `k` in `CartesianIndices(idx)`, assign values to `dst` +according to + + dst[:, ... , k] .= src[:, ... , idx[k]...] -Non-mutating version of [`gather!`](@ref). 
+Notice that if `idx` is a vector containing integers +and `src` is a matrix, the previous expression simplifies to + + dst[:, k] .= src[:, idx[k]] + +and `k` will run over `1:length(idx)`. + +The elements of `idx` can be integers or integer tuples and may be repeated. +A single `src` column can end up being copied into zero, one, +or multiple `dst` columns. + +See [`gather!`](@ref) for an in-place version. """ function gather(src::AbstractArray{Tsrc, Nsrc}, idx::AbstractArray{Tidx, Nidx}) where @@ -59,6 +73,3 @@ function gather(src::AbstractArray{Tsrc, Nsrc}, dst = similar(src, Tsrc, dstsize) return gather!(dst, src, idx) end - -typelength(::Type{<:Number}) = 1 -typelength(::Type{<:NTuple{M}}) where M = M From 73079044dd9180d48686819196b58f046e35b91c Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 5 Mar 2021 09:17:23 +0100 Subject: [PATCH 3/3] cleanup; getindex --- Project.toml | 1 - src/NNlib.jl | 2 +- src/gather.jl | 9 +++++++++ 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 2985a7e02..5b0923ad7 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,6 @@ uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" version = "0.7.14" [deps] -BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/NNlib.jl b/src/NNlib.jl index 24658016d..647eb400a 100644 --- a/src/NNlib.jl +++ b/src/NNlib.jl @@ -7,7 +7,7 @@ import ChainRulesCore: rrule using Base.Broadcast: broadcasted using Statistics: mean -const IntOrIntTuple = Union{Integer, NTuple{N,Integer} where N} +const IntOrIntTuple = Union{Integer, NTuple{N,<:Integer} where N} const Numeric = Union{AbstractArray{<:T}, T} where {T<:Number} # Include APIs diff --git a/src/gather.jl b/src/gather.jl index 906ce15d6..46984f180 100644 --- a/src/gather.jl +++ b/src/gather.jl @@ -73,3 +73,12 @@ function
gather(src::AbstractArray{Tsrc, Nsrc}, dst = similar(src, Tsrc, dstsize) return gather!(dst, src, idx) end + +# Simple implementation with getindex for integer index arrays. +# Performance is equivalent to the generic method above (which also handles the integer case); +# it is kept here to show the simple connection with getindex. +function gather(src::AbstractArray{Tsrc, Nsrc}, + idx::AbstractArray{<:Integer}) where {Tsrc, Nsrc} + colons = ntuple(i -> Colon(), Nsrc-1) + return src[colons..., idx] +end