Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

add gather #280

Merged
merged 3 commits into from
Mar 12, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/NNlib.jl
Original file line number Diff line number Diff line change
@@ -7,7 +7,7 @@ import ChainRulesCore: rrule
using Base.Broadcast: broadcasted
using Statistics: mean

const IntOrTuple = Union{Integer,Tuple}
const IntOrIntTuple = Union{Integer, NTuple{N,<:Integer} where N}
const Numeric = Union{AbstractArray{<:T}, T} where {T<:Number}

# Include APIs
@@ -35,8 +35,9 @@ include("conv_bias_act.jl")
include("pooling.jl")
include("padding.jl")
include("upsample.jl")
include("utils.jl")
include("gather.jl")
include("scatter.jl")
include("utils.jl")

## Include implementations
include("impl/padding_edges.jl")
84 changes: 84 additions & 0 deletions src/gather.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
export gather, gather!

"""
gather!(dst, src, idx)

Reverse operation of [`scatter!`](@ref). Gathers data from source `src`
and writes it in destination `dst` according to the index array `idx`.
For each `k` in `CartesianIndices(idx)`, assign values to `dst` according to

dst[:, ... , k] .= src[:, ... , idx[k]...]

Notice that if `idx` is a vector containing integers,
and both `dst` and `src` are matrices, previous expression simplifies to

dst[:, k] .= src[:, idx[k]]

and `k` will run over `1:length(idx)`.

The elements of `idx` can be integers or integer tuples and may be repeated.
A single `src` column can end up being copied into zero, one,
or multiple `dst` columns.

See [`gather`](@ref) for an allocating version.
"""
function gather!(dst::AbstractArray{Tdst,Ndst},
src::AbstractArray{Tsrc,Nsrc},
idx::AbstractArray{Tidx, Nidx}) where
{Tdst, Tsrc, Ndst, Nsrc, Nidx, Tidx <: IntOrIntTuple}

M = typelength(Tidx)
d = Ndst - Nidx
d == Nsrc - M || throw(ArgumentError("Incompatible input shapes."))
size(dst)[1:d] == size(src)[1:d] || throw(ArgumentError("Incompatible input shapes."))
size(dst)[d+1:end] == size(idx) || throw(ArgumentError("Incompatible input shapes."))

colons = ntuple(i -> Colon(), d)
for k in CartesianIndices(idx)
view(dst, colons..., k) .= view(src, colons..., idx[k]...)
end
return dst
end

"""
gather(src, idx) -> dst

Reverse operation of [`scatter`](@ref). Gathers data from source `src`
and writes it in a destination `dst` according to the index
array `idx`.
For each `k` in `CartesianIndices(idx)`, assign values to `dst`
according to

dst[:, ... , k] .= src[:, ... , idx[k]...]

Notice that if `idx` is a vector containing integers
and `src` is a matrix, previous expression simplifies to

dst[:, k] .= src[:, idx[k]]

and `k` will run over `1:length(idx)`.

The elements of `idx` can be integers or integer tuples and may be repeated.
A single `src` column can end up being copied into zero, one,
or multiple `dst` columns.

See [`gather!`](@ref) for an in-place version.
"""
function gather(src::AbstractArray{Tsrc, Nsrc},
idx::AbstractArray{Tidx, Nidx}) where
{Tsrc, Nsrc, Nidx, Tidx<:IntOrIntTuple}

M = typelength(Tidx)
dstsize = (size(src)[1:Nsrc-M]..., size(idx)...)
dst = similar(src, Tsrc, dstsize)
return gather!(dst, src, idx)
end

# Simple implementation with getindex for integer array.
# Perf equivalent to the one above (which can also handle the integer case)
# leave it here to show the simple connection with getindex.
function gather(src::AbstractArray{Tsrc, Nsrc},
idx::AbstractArray{<:Integer}) where {Tsrc, Nsrc}
colons = ntuple(i -> Colon(), Nsrc-1)
return src[colons..., idx]
end
16 changes: 8 additions & 8 deletions src/scatter.jl
Original file line number Diff line number Diff line change
@@ -47,13 +47,13 @@ Once the dimensions match, arrays are aligned automatically. The value of `idx`
function scatter!(op,
dst::AbstractArray{Tdst,Ndst},
src::AbstractArray{Tsrc,Nsrc},
idx::AbstractArray{Tidx,Nidx}) where {Tdst,Tsrc,Tidx<:IntOrTuple,Ndst,Nsrc,Nidx}
idx::AbstractArray{Tidx,Nidx}) where {Tdst,Tsrc,Tidx<:IntOrIntTuple,Ndst,Nsrc,Nidx}
M = typelength(Tidx)
dims = _check_dims(Ndst, Nsrc, M, Nidx)
scatter!(op, dst, src, idx, Val(dims))
end

function scatter!(op, dst::AbstractArray{Tdst}, src::AbstractArray{Tsrc}, idx::AbstractArray{<:IntOrTuple},
function scatter!(op, dst::AbstractArray{Tdst}, src::AbstractArray{Tsrc}, idx::AbstractArray{<:IntOrIntTuple},
dims::Val{N}) where {Tdst,Tsrc,N}
colons = Base.ntuple(_->Colon(), dims)
for k in CartesianIndices(idx)
@@ -67,7 +67,7 @@ end
function scatter!(op::typeof(mean),
dst::AbstractArray{Tdst,Ndst},
src::AbstractArray{Tsrc,Nsrc},
idx::AbstractArray{<:IntOrTuple,Nidx}) where {Tdst,Tsrc,Ndst,Nsrc,Nidx}
idx::AbstractArray{<:IntOrIntTuple,Nidx}) where {Tdst,Tsrc,Ndst,Nsrc,Nidx}
Ns = scatter!(+, zero(dst), one.(src), idx)
dst_ = scatter!(+, zero(dst), src, idx)
dst .+= safe_div.(dst_, Ns)
@@ -96,7 +96,7 @@ function scatter end
for op in [+, -]
@eval function scatter(op::typeof($op),
src::AbstractArray{T,Nsrc},
idx::AbstractArray{<:IntOrTuple,Nidx}) where {T,Nsrc,Nidx}
idx::AbstractArray{<:IntOrIntTuple,Nidx}) where {T,Nsrc,Nidx}
dims = Nsrc - Nidx
dstsize = (size(src)[1:dims]..., maximum_dims(idx)...)
dst = similar(src, T, dstsize)
@@ -108,7 +108,7 @@ end
for op in [*, /]
@eval function scatter(op::typeof($op),
src::AbstractArray{T,Nsrc},
idx::AbstractArray{<:IntOrTuple,Nidx}) where {T,Nsrc,Nidx}
idx::AbstractArray{<:IntOrIntTuple,Nidx}) where {T,Nsrc,Nidx}
dims = Nsrc - Nidx
dstsize = (size(src)[1:dims]..., maximum_dims(idx)...)
dst = similar(src, T, dstsize)
@@ -119,7 +119,7 @@ end

function scatter(op::typeof(max),
src::AbstractArray{T,Nsrc},
idx::AbstractArray{<:IntOrTuple,Nidx}) where {T,Nsrc,Nidx}
idx::AbstractArray{<:IntOrIntTuple,Nidx}) where {T,Nsrc,Nidx}
dims = Nsrc - Nidx
dstsize = (size(src)[1:dims]..., maximum_dims(idx)...)
dst = similar(src, T, dstsize)
@@ -129,7 +129,7 @@ end

function scatter(op::typeof(min),
src::AbstractArray{T,Nsrc},
idx::AbstractArray{<:IntOrTuple,Nidx}) where {T,Nsrc,Nidx}
idx::AbstractArray{<:IntOrIntTuple,Nidx}) where {T,Nsrc,Nidx}
dims = Nsrc - Nidx
dstsize = (size(src)[1:dims]..., maximum_dims(idx)...)
dst = similar(src, T, dstsize)
@@ -139,7 +139,7 @@ end

function scatter(op::typeof(mean),
src::AbstractArray{T,Nsrc},
idx::AbstractArray{<:IntOrTuple,Nidx}) where {T,Nsrc,Nidx}
idx::AbstractArray{<:IntOrIntTuple,Nidx}) where {T,Nsrc,Nidx}
FT = float(T)
dims = Nsrc - Nidx
dstsize = (size(src)[1:dims]..., maximum_dims(idx)...)
94 changes: 94 additions & 0 deletions test/gather.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
@testset "gather scalar index" begin
T = Float32

## 1d src, 2d index of ints -> 2d output
src = T[3, 4, 5, 6, 7]
index = [1 2 3 4;
4 2 1 3;
3 5 5 3]
output = T[3 4 5 6;
6 4 3 5;
5 7 7 5]

y = gather(src, index)
@test y isa Array{T,2}
@test size(y) == size(index)
@test y == output
@test gather!(T.(zero(index)), src, index) == output
@test_throws ArgumentError gather!(zeros(T, 3, 5), src, index)

index2 = [1 2 3 4;
4 2 1 3;
3 6 5 3]
@test_throws BoundsError gather!(T.(zero(index)), src, index2)

## 1d src, 3d index of ints -> 3d output
src = T[3, 4, 5, 6, 7]
index = [1 2 3 4;
4 2 1 3;
3 5 5 3][:,:,1:1]
output = T[3 4 5 6;
6 4 3 5;
5 7 7 5][:,:,1:1]

y = gather(src, index)
@test y isa Array{T,3}
@test size(y) == size(index)
@test y == output


## 2d src, 2d index of ints -> 3d output
src = T[3 5 7
4 6 8]
index = [1 2 3;
2 2 1;
3 1 3]

output = zeros(T, 2, 3, 3)

output[:,:,1] = [3 5 7
4 6 8]

output[:,:,2] = [5 5 3
6 6 4]

output[:,:,3] = [7 3 7
8 4 8]

y = gather(src, index)
M = NNlib.typelength(eltype(index))
Nsrc = ndims(src)
@test y isa Array{T,3}
@test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)
@test y == output
end

@testset "gather tuple index" begin
T = Float32

## 2d src, 1d index of 2-tuples -> 1d output
src = T[3 5 7
4 6 8]

index = [(1,1), (1,2), (1,3), (2,1), (2,2), (2,3)]

output = T[3, 5, 7, 4, 6, 8]

y = gather(src, index)
M = NNlib.typelength(eltype(index))
Nsrc = ndims(src)
@test y isa Array{T,1}
@test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)
@test y == output

## 3d src, 2d index of 2-tuples -> 3d output
n1, nsrc, nidx = 2, 3, 6
src = rand(Float32, n1, nsrc, nsrc)
index = [(rand(1:nsrc), rand(1:nsrc)) for i=1:nidx, j=1:nidx]

y = gather(src, index)
M = NNlib.typelength(eltype(index))
Nsrc = ndims(src)
@test y isa Array{T,3}
@test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)
end
5 changes: 5 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -46,10 +46,15 @@ end
include("upsample.jl")
end

@testset "Gather" begin
include("gather.jl")
end

@testset "Scatter" begin
include("scatter.jl")
end

@testset "Utilities" begin
include("utils.jl")
end