
Commit dfea43c

Authored by mcognetta, CarloLucibello, darsnack, mcabbott
Add EmbeddingBag (#2031)
* embedding bag
* doc fix
* Apply suggestions from code review
  Co-authored-by: Carlo Lucibello <carlo.lucibello@gmail.com>
* Remove references to `Statistics`. Statistics is imported by Flux so we can just call `mean` rather than `Statistics.mean`.
* non mutating bag and onehot changes
* better docs and todo
* input/offset docs
* doctest
* Apply suggestions from code review
  Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com>
  Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
* reduce docs
* broadcast to map
* remove extra doc example line
* add _splitat
* rename input/offset
* minor docs
* Apply suggestions from code review
* Update test/layers/basic.jl
* Update test/layers/basic.jl
* Update test/layers/basic.jl
* typo
* docstring
* Apply suggestions from code review

---------

Co-authored-by: Carlo Lucibello <carlo.lucibello@gmail.com>
Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com>
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
1 parent ccf87bb

File tree

4 files changed: +225 -1 lines changed

- docs/src/models/layers.md
- src/layers/basic.jl
- src/layers/show.jl
- test/layers/basic.jl

docs/src/models/layers.md (+1)

@@ -91,6 +91,7 @@ These layers accept an index, and return a vector (or several indices, and several vectors).
 
 ```@docs
 Flux.Embedding
+Flux.EmbeddingBag
 ```
 
 ## [Dataflow Layers, or Containers](@id man-dataflow-layers)

src/layers/basic.jl (+148)

@@ -716,3 +716,151 @@ Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(out, in))
 function Base.show(io::IO, m::Embedding)
   print(io, "Embedding(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
 end
+
+
+"""
+    _splitat(data::AbstractVector, at::AbstractVector{Int})
+
+Partitions `data` into a vector of views.
+
+Each index `i in at` specifies that a view starts with `data[i]`.
+These indices must be strictly increasing, and start at `1`.
+The resulting views do not overlap, and are never empty.
+The last view always ends with `data[end]`.
+
+### Example
+```jldoctest
+julia> Flux._splitat(collect('A':'Z'), [1, 3, 4, 13])
+4-element Vector{SubArray{Char, 1, Vector{Char}, Tuple{UnitRange{Int64}}, true}}:
+ ['A', 'B']
+ ['C']
+ ['D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L']
+ ['M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z']
+```
+"""
+function _splitat(data::AbstractVector, at::AbstractVector{<:Integer})
+  at[begin] == firstindex(data) || throw(ArgumentError("The first element in `at` must be 1."))
+  at[end] <= lastindex(data) || throw(ArgumentError("The last element in `at` must be at most the length of `data`."))
+  issorted(at, lt = <=) || throw(ArgumentError("`at` must be monotonically increasing with no duplicates."))
+  iplus = vcat(at, lastindex(data)+1)
+  return [view(data, iplus[n]:(iplus[n+1]-1)) for n in eachindex(at)]
+end
+
+"""
+    EmbeddingBag(in => out, reduction=mean; init=Flux.randn32)
+
+A lookup table that stores embeddings of dimension `out` for a vocabulary of size `in`.
+Differs from [`Embedding`](@ref) in that, instead of acting on a single vocabulary index,
+it always acts on a vector of indices, which it calls a "bag".
+Their individual embedding vectors are reduced to one, using `mean` or some other function.
+
+Instead of acting on one "bag", such as `x::Vector{Int}`, the layer can also act on several:
+
+* Acting on a vector of "bags", it produces a matrix whose columns are the reduced vectors.
+  More generally on `x::Array{Vector{Int}}`, its output is of size `(out, size(x)...)`.
+
+* Any higher-rank array of integers is interpreted as a collection of "bags" each along the first dimension.
+  Thus the output is `mapslices(e, x; dims=1)` when `e::EmbeddingBag` and `x::Array{Int,N}`.
+  This method is more efficient, but requires that all "bags" have the same length.
+
+* A vector of "bags" may also be produced by splitting a vector of indices at specified points.
+  For this case the layer takes two inputs, both vectors of integers. See details below.
+
+The "bag" may equivalently be represented as a `OneHotMatrix`. A collection of these,
+or one higher-rank `OneHotArray`, again produce a stack of embeddings. See details below.
+
+# Examples
+```jldoctest
+julia> vocab_size = 26;  # embed into 3 dimensions, with non-random vectors:
+
+julia> eb = EmbeddingBag(vocab_size => 3, init=Flux.identity_init(gain=100))
+EmbeddingBag(26 => 3)  # 78 parameters
+
+julia> eb([2])  # one bag of 1 item
+3-element Vector{Float32}:
+   0.0
+ 100.0
+   0.0
+
+julia> eb([3,3,1])  # one bag of 3 items, one mean embedding
+3-element Vector{Float32}:
+ 33.333332
+  0.0
+ 66.666664
+
+julia> eb([[3,1,3], [2,1]])  # two bags
+3×2 Matrix{Float32}:
+ 33.3333  50.0
+  0.0     50.0
+ 66.6667   0.0
+
+julia> eb([1 1 1 1; 1 2 3 4])  # 4 bags each of 2 items, eachcol([1 1 1 1; 1 2 3 4])
+3×4 Matrix{Float32}:
+ 100.0  50.0  50.0  50.0
+   0.0  50.0   0.0   0.0
+   0.0   0.0  50.0   0.0
+
+julia> eb(rand(1:26, 10, 5, 5)) |> size  # 25 bags each of 10 items
+(3, 5, 5)
+```
+
+Another way to specify "many bags of many items" is to provide a vector `data` (each in `1:in`)
+and a vector `at` stating where to split that up into "bags".
+The first bag starts with `data[at[1]]`, the second at `data[at[2]]`, and so on,
+with no overlaps and nothing left out (thus it requires `at[1]==1`).
+
+```jldoctest
+julia> data = [11, 1, 12, 2, 13, 3, 14];
+
+julia> Flux._splitat(data, [1, 4]) |> println  # internal function, makes data[1:3], data[4:end]
+[[11, 1, 12], [2, 13, 3, 14]]
+
+julia> eb(data, [1, 4])  # two bags, of 3 and 4 items
+3×2 Matrix{Float32}:
+ 33.3333   0.0
+  0.0     25.0
+  0.0     25.0
+```
+
+Finally, each bag may also be represented as a [`OneHotMatrix`](@ref OneHotArrays.onehotbatch).
+
+```jldoctest
+julia> eb(Flux.onehotbatch("bba", 'a':'z'))  # same as [2,2,1], one bag of 3 items
+3-element Vector{Float32}:
+ 33.333332
+ 66.666664
+  0.0
+
+julia> eb([Flux.onehotbatch("bba", 'a':'z'), Flux.onehotbatch("cc", 'a':'z')])  # two bags
+3×2 Matrix{Float32}:
+ 33.3333    0.0
+ 66.6667    0.0
+  0.0     100.0
+```
+"""
+struct EmbeddingBag{F, W<:AbstractMatrix}
+  weight::W
+  reduction::F
+end
+
+@functor EmbeddingBag
+
+EmbeddingBag((in, out)::Pair{<:Integer, <:Integer}, reduction::Function = mean; init = randn32) = EmbeddingBag(init(out, in), reduction)
+EmbeddingBag(weight::AbstractMatrix) = EmbeddingBag(weight, mean)
+
+(m::EmbeddingBag)(data::AbstractVector, at::AbstractVector) = m(_splitat(data, at))
+(m::EmbeddingBag)(inds::AbstractArray{<:Integer}) = dropdims(m.reduction(Embedding(m.weight)(inds), dims=2), dims=2)
+(m::EmbeddingBag)(ind::Integer) = error("EmbeddingBag expects an array of indices, not just one")
+
+(m::EmbeddingBag)(hot::AbstractArray{Bool}) = dropdims(m.reduction(Embedding(m.weight)(hot), dims=2), dims=2)
+(m::EmbeddingBag)(hot::AbstractVector{Bool}) = error("EmbeddingBag not defined for a one-hot vector")
+
+# These two could be stack(m, bags), but no AD support yet. (Gradient for weight quite inefficient here.)
+(m::EmbeddingBag)(bags::AbstractVector{<:AbstractVector}) = reduce(hcat, m.(bags))
+(m::EmbeddingBag)(bags::AbstractArray{<:AbstractVector}) = reshape(m(vec(bags)), :, size(bags)...)
+
+(m::EmbeddingBag)(bags::AbstractArray{<:AbstractMatrix{Bool}}) = reshape(reduce(hcat, m.(vec(bags))), :, size(bags)...)
+
+function Base.show(io::IO, m::EmbeddingBag)
+  print(io, "EmbeddingBag(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
+end
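To make the new layer's reduction semantics concrete, here is a minimal usage sketch. It is not part of this commit; the vocabulary size, dimensions, and token indices are made-up values for illustration, and it assumes a Flux build containing this change:

```julia
using Flux

# Hypothetical sizes: a 100-token vocabulary embedded into 8 dimensions.
vocab_size, embed_dim, nclasses = 100, 8, 2

model = Chain(
    EmbeddingBag(vocab_size => embed_dim),  # mean-reduces each bag to one vector
    Dense(embed_dim => nclasses),
    softmax,
)

# Two "documents" of different lengths, each a bag of token indices.
bags = [[5, 17, 99], [42, 42]]

y = model(bags)                  # one column of class probabilities per bag
@assert size(y) == (nclasses, 2)
```

Because the bags are reduced to fixed-size vectors before the `Dense` layer, variable-length inputs need no padding.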

src/layers/show.jl (+1 -1)

@@ -59,7 +59,7 @@ _show_children(p::Parallel) = (p.connection, p.layers...)
 _show_children(f::PairwiseFusion) = (f.connection, f.layers...)
 
 for T in [
-  :Conv, :ConvTranspose, :CrossCor, :Dense, :Scale, :Bilinear, :Embedding,
+  :Conv, :ConvTranspose, :CrossCor, :Dense, :Scale, :Bilinear, :Embedding, :EmbeddingBag,
   :BatchNorm, :LayerNorm, :InstanceNorm, :GroupNorm,
 ]
   @eval function Base.show(io::IO, m::MIME"text/plain", x::$T)
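The effect of this one-line change is simply to register `EmbeddingBag` for the standalone `text/plain` printing that this loop generates, so the REPL display includes a parameter count. The expected output below is taken from the docstring example in this commit rather than re-run:

```julia
julia> using Flux

julia> EmbeddingBag(26 => 3)  # 26 embedding vectors of 3 Float32s each
EmbeddingBag(26 => 3)  # 78 parameters
```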

test/layers/basic.jl (+75)

@@ -338,6 +338,81 @@ import Flux: activations
     y3 = m(x3)
     @test size(y3) == (embed_size, 3, 4)
   end
+
+  @testset "EmbeddingBag" begin
+
+    # test _splitat
+    data = [1, 2, 3, 4, 5, 6, 7, 8, 9]
+    offsets_good = [1, 3, 6]
+    offsets_each = [1, 2, 3, 4, 5, 6, 7, 8, 9]
+    offsets_just_one = [1]
+    offsets_all_but_last = [1, 9]
+
+    @test Flux._splitat(data, offsets_good) == [[1, 2], [3, 4, 5], [6, 7, 8, 9]]
+    @test Flux._splitat(data, offsets_each) == [[1], [2], [3], [4], [5], [6], [7], [8], [9]]
+    @test Flux._splitat(data, offsets_just_one) == [[1, 2, 3, 4, 5, 6, 7, 8, 9]]
+    @test Flux._splitat(data, offsets_all_but_last) == [[1, 2, 3, 4, 5, 6, 7, 8], [9]]
+
+    offsets_non_monotonic = [1, 2, 2, 5]
+    offsets_non_sorted = [1, 5, 2]
+    offsets_non_one = [2, 3, 5]
+    offsets_too_large = [1, 5, 11]
+
+    @test_throws ArgumentError Flux._splitat(data, offsets_non_monotonic)
+    @test_throws ArgumentError Flux._splitat(data, offsets_non_sorted)
+    @test_throws ArgumentError Flux._splitat(data, offsets_non_one)
+    @test_throws ArgumentError Flux._splitat(data, offsets_too_large)
+
+    @testset for reduction in [sum, Statistics.mean, maximum]
+      vocab_size, embed_size = 10, 4
+      emb_bag = Flux.EmbeddingBag(vocab_size => embed_size, reduction)
+      emb = Flux.Embedding(emb_bag.weight)
+      @test size(emb_bag.weight) == (embed_size, vocab_size)
+      @test_throws ErrorException emb_bag(2)
+
+      # single bag (input as a vector)
+      x = rand(1:vocab_size, 3)
+      y = emb_bag(x)
+      z = vec(reduction(emb(x), dims=2))
+      @test y isa Vector{Float32}
+      @test y ≈ z
+
+      # PyTorch-style `input`/`offset` bagging
+      @test emb_bag([1,3,2,4,5,7], [1,3,5]) ≈ emb_bag([[1,3], [2,4], [5,7]])
+      @test emb_bag([1,3,2,4,5,7], [1,3,5]) ≈ emb_bag([1 2 5; 3 4 7])
+      @test_throws ArgumentError emb_bag([1,2,3,4,5,6], [2, 4])
+      @test_throws ArgumentError emb_bag([1,2,3,4,5,6], [1, 12])
+
+      # docstring example
+      @test emb_bag([1,2,3,4,5,6,7,8,9,10], [1,5,6,8]) ≈ emb_bag([[1,2,3,4], [5], [6,7], [8,9,10]])
+
+      # multiple bags (input as a vector of vectors)
+      x = [rand(1:vocab_size, 3) for _ in 1:4]
+      y = emb_bag(x)
+      z = reduce(hcat, reduction.(emb.(x), dims=2))
+      @test y isa Matrix{Float32}
+      @test y ≈ z
+
+      # multiple bags (input as a matrix)
+      x = rand(1:vocab_size, (3, 5))
+      xvec = collect(eachcol(x))
+      y = emb_bag(x)
+      z = reduce(hcat, reduction.(emb.(xvec), dims=2))
+      @test y ≈ emb_bag(xvec)
+      @test y ≈ z
+
+      # a one-hot matrix is a bag, but a one-hot vector is not.
+      @test_throws ErrorException emb_bag(Flux.OneHotVector(3, vocab_size))
+
+      i2 = rand(1:vocab_size, 3)
+      x2 = Flux.OneHotMatrix(i2, vocab_size)
+      y2 = emb_bag(x2)
+      z2 = emb(i2)
+      @test y2 isa Vector{Float32}
+      @test y2 ≈ vec(reduction(z2, dims=2))
+      @test_throws DimensionMismatch emb_bag(Flux.OneHotMatrix(1:5, 1000))
+    end
+  end
 end
 
 @testset "second derivatives" begin
