@@ -716,3 +716,151 @@ Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(ini
function Base.show(io::IO, m::Embedding)
  print(io, "Embedding(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
end
+
+
+ """
+     _splitat(data::AbstractVector, at::AbstractVector{Int})
+
+ Partitions `data` into a vector of views.
+
+ Each index `i in at` specifies that a view starts with `data[i]`.
+ These indices must be strictly increasing, and start at `1`.
+ The resulting views do not overlap, and are never empty.
+ The last view always ends with `data[end]`.
+
+ ### Example
+ ```jldoctest
+ julia> Flux._splitat(collect('A':'Z'), [1, 3, 4, 13])
+ 4-element Vector{SubArray{Char, 1, Vector{Char}, Tuple{UnitRange{Int64}}, true}}:
+  ['A', 'B']
+  ['C']
+  ['D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L']
+  ['M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z']
+ ```
+ """
+ function _splitat(data::AbstractVector, at::AbstractVector{<:Integer})
+   at[begin] == firstindex(data) || throw(ArgumentError("The first element in `at` must be 1."))
+   at[end] <= lastindex(data) || throw(ArgumentError("The last element in `at` must be at most the length of `data`."))
+   issorted(at, lt = <=) || throw(ArgumentError("`at` must be monotonically increasing with no duplicates."))
+   # Append a sentinel index one past the end of `data`, so every view, including the last,
+   # can be written as `iplus[n]:(iplus[n+1]-1)`.
+   iplus = vcat(at, lastindex(data)+1)
+   return [view(data, iplus[n]:(iplus[n+1]-1)) for n in eachindex(at)]
+ end
+
+ """
+     EmbeddingBag(in => out, reduction=mean; init=Flux.randn32)
+
+ A lookup table that stores embeddings of dimension `out` for a vocabulary of size `in`.
+ Differs from [`Embedding`](@ref) in that, instead of acting on a single vocabulary index,
+ it always acts on a vector of indices, which it calls a "bag".
+ Their individual embedding vectors are reduced to one, using `mean` or some other function.
+
+ Instead of acting on one "bag", such as `x::Vector{Int}`, the layer can also act on several:
+
+ * Acting on a vector of "bags", it produces a matrix whose columns are the reduced vectors.
+   More generally on `x::Array{Vector{Int}}`, its output is of size `(out, size(x)...)`.
+
+ * Any higher-rank array of integers is interpreted as a collection of "bags", each along the first dimension.
+   Thus the output is `mapslices(e, x; dims=1)` when `e::EmbeddingBag` and `x::Array{Int,N}`.
+   This method is more efficient, but requires that all "bags" have the same length.
+
+ * A vector of "bags" may also be produced by splitting a vector of indices at specified points.
+   For this case the layer takes two inputs, both vectors of integers. See details below.
+
+ The "bag" may equivalently be represented as a `OneHotMatrix`. A collection of these,
+ or one higher-rank `OneHotArray`, again produces a stack of embeddings. See details below.
+
+ # Examples
+ ```jldoctest
+ julia> vocab_size = 26; # embed into 3 dimensions, with non-random vectors:
+
+ julia> eb = EmbeddingBag(vocab_size => 3, init=Flux.identity_init(gain=100))
+ EmbeddingBag(26 => 3)  # 78 parameters
+
+ julia> eb([2]) # one bag of 1 item
+ 3-element Vector{Float32}:
+    0.0
+  100.0
+    0.0
+
+ julia> eb([3,3,1]) # one bag of 3 items, one mean embedding
+ 3-element Vector{Float32}:
+  33.333332
+   0.0
+  66.666664
+
+ julia> eb([[3,1,3], [2,1]]) # two bags
+ 3×2 Matrix{Float32}:
+  33.3333  50.0
+   0.0     50.0
+  66.6667   0.0
+
+ julia> eb([1 1 1 1; 1 2 3 4]) # 4 bags each of 2 items, eachcol([1 1 1 1; 1 2 3 4])
+ 3×4 Matrix{Float32}:
+  100.0  50.0  50.0  50.0
+    0.0  50.0   0.0   0.0
+    0.0   0.0  50.0   0.0
+
+ julia> eb(rand(1:26, 10, 5, 5)) |> size # 25 bags each of 10 items
+ (3, 5, 5)
+ ```
+
+ Another way to specify "many bags of many items" is to provide a vector `data` (each in `1:in`)
+ and a vector `at` stating where to split that up into "bags".
+ The first bag starts with `data[at[1]]`, the second at `data[at[2]]`, and so on,
+ with no overlaps and nothing left out (thus it requires `at[1]==1`).
+
+ ```jldoctest
+ julia> data = [11, 1, 12, 2, 13, 3, 14];
+
+ julia> Flux._splitat(data, [1, 4]) |> println # internal function, makes data[1:3], data[4:end]
+ [[11, 1, 12], [2, 13, 3, 14]]
+
+ julia> eb(data, [1, 4]) # two bags, of 3 and 4 items
+ 3×2 Matrix{Float32}:
+  33.3333   0.0
+   0.0     25.0
+   0.0     25.0
+ ```
+
+ Finally, each bag may also be represented as a [`OneHotMatrix`](@ref OneHotArrays.onehotbatch).
+
+ ```jldoctest
+ julia> eb(Flux.onehotbatch("bba", 'a':'z')) # same as [2,2,1], one bag of 3 items
+ 3-element Vector{Float32}:
+  33.333332
+  66.666664
+   0.0
+
+ julia> eb([Flux.onehotbatch("bba", 'a':'z'), Flux.onehotbatch("cc", 'a':'z')]) # two bags
+ 3×2 Matrix{Float32}:
+  33.3333    0.0
+  66.6667    0.0
+   0.0     100.0
+ ```
+ """
+ struct EmbeddingBag{F, W<:AbstractMatrix}
+   weight::W
+   reduction::F
+ end
+
+ @functor EmbeddingBag
+
+ EmbeddingBag((in, out)::Pair{<:Integer, <:Integer}, reduction::Function = mean; init = randn32) = EmbeddingBag(init(out, in), reduction)
+ EmbeddingBag(weight::AbstractMatrix) = EmbeddingBag(weight, mean)
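+ # For example, `EmbeddingBag(26 => 3)` creates random Float32 weights with `reduction = mean`,
+ # while `EmbeddingBag(rand(Float32, 3, 26), sum)` wraps an existing weight matrix and reduces with `sum`.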
+
+ (m::EmbeddingBag)(data::AbstractVector, at::AbstractVector) = m(_splitat(data, at))
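+ # In the next method, `Embedding(m.weight)(inds)` has size `(out, size(inds)...)`;
+ # reducing over `dims = 2` collapses each bag (the first dimension of `inds`),
+ # and `dropdims` removes the resulting singleton dimension.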
+ (m::EmbeddingBag)(inds::AbstractArray{<:Integer}) = dropdims(m.reduction(Embedding(m.weight)(inds), dims=2), dims=2)
+ (m::EmbeddingBag)(ind::Integer) = error("EmbeddingBag expects an array of indices, not just one")
+
+ (m::EmbeddingBag)(hot::AbstractArray{Bool}) = dropdims(m.reduction(Embedding(m.weight)(hot), dims=2), dims=2)
+ (m::EmbeddingBag)(hot::AbstractVector{Bool}) = error("EmbeddingBag not defined for a one-hot vector")
+
+ # These two could be stack(m, bags), but no AD support yet. (Gradient for weight quite inefficient here.)
+ (m::EmbeddingBag)(bags::AbstractVector{<:AbstractVector}) = reduce(hcat, m.(bags))
+ (m::EmbeddingBag)(bags::AbstractArray{<:AbstractVector}) = reshape(m(vec(bags)), :, size(bags)...)
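+ # For instance, `m([[1, 2], [3]])` broadcasts `m` over the two bags, giving two length-`out`
+ # vectors, which `reduce(hcat, ...)` concatenates into an `out × 2` matrix.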
+
+ (m::EmbeddingBag)(bags::AbstractArray{<:AbstractMatrix{Bool}}) = reshape(reduce(hcat, m.(vec(bags))), :, size(bags)...)
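+ # An array of one-hot "bags" is flattened with `vec`, each bag embedded and reduced to one column,
+ # then the result is reshaped back to `(out, size(bags)...)`.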
+
+ function Base.show(io::IO, m::EmbeddingBag)
+   print(io, "EmbeddingBag(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
+ end
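+
+ # A rough end-to-end sketch (assuming Flux's Zygote-based `gradient` is in scope):
+ #   eb = EmbeddingBag(26 => 3)
+ #   grads = gradient(m -> sum(m([2, 5, 7])), eb)  # one bag of three indices; gradients reach eb.weight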