diff --git a/src/data/collection.jl b/src/data/collection.jl index c794b24..26c6bab 100644 --- a/src/data/collection.jl +++ b/src/data/collection.jl @@ -1,5 +1,34 @@ -# for now, we load collections in memory. -# will be good to implement on-disk data structures too. +# TODO: implement on-disk collections, and the case where pids are not necessarily sorted and can be arbitrary +""" + Collection(path::String) + +A wrapper around a collection of documents, which stores the underlying collection as a `Vector{String}`. + +# Arguments + +- `path::String`: A path to the document dataset. It is assumed that `path` refers to a CSV file. Each line of the + the CSV file should be of the form `pid \\t document`, where `pid` is the integer index of the document. `pid`s should be in the range ``[1, N]``, where ``N`` is the number of documents, and should be sorted. + +# Examples + +Here's an example which loads a small subset of the LoTTe dataset defined in `short_collections.tsv` (see the `examples` folder in the package). + +```julia-repl +julia> using ColBERT; + +julia> dataroot = "downloads/lotte"; + +julia> dataset = "lifestyle"; + +julia> datasplit = "dev"; + +julia> path = joinpath(dataroot, dataset, datasplit, "short_collection.tsv") +"downloads/lotte/lifestyle/dev/short_collection.tsv" + +julia> collection = Collection(path) +Collection at downloads/lotte/lifestyle/dev/short_collection.tsv with 10 passages. +``` +""" struct Collection path::String data::Vector{String} @@ -12,10 +41,52 @@ function Collection(path::String) Collection(path, file.text) end +""" + get_chunksize(collection::Collection, nranks::Int) + +Determine the size of chunks used to store the index, based on the size of the `collection` and the number of available GPUs. + +# Arguments + +- `collection::Collection`: The underlying collection of documents. +- `nranks::Int`: Number of available GPUs to compute the index. At this point, the package only supports `nranks = 1`. + +# Examples + +Continuing from the example from the [`Collection`](@ref) constructor: + +```julia-repl +julia> get_chunksize(collection, 1) +11 +``` +""" function get_chunksize(collection::Collection, nranks::Int) Int(min(25000, 1 + floor(length(collection.data) / nranks))) end +""" + enumerate_batches(collection::Collection; [chunksize, nranks]) + +Batch the `collection` into chunks containing tuples of the form `(chunk_idx, offset, passages)`, where `chunk_idx` is the index of the chunk, `offset` is the index of the first passsage in the chunk, and `passages` is a `Vector{String}` containing the passages in the chunk. + +# Arguments + +- `collection::Collection`: The collection to batch. +- `chunksize::Union{Int, Missing}`: The chunksize to use to batch the collection. Default `missing`. If this is `missing`, then `chunksize` is determined using [`get_chunksize`](@ref) based on the `collection` and `nranks`. +- `nranks::Union{Int, Missing}`: The number of available GPUs. Default `missing`. Currently the package only supports `nranks = 1`. + +The `collection` is batched into chunks of uniform size (with the last chunk potentially having a smaller size). + +# Examples + +Continuing from the example in the [`Collection`](@ref) constructor. 
+ +```julia-repl +julia> enumerate_batches(collection; nranks = 1); + +julia> enumerate_batches(collection; chunksize = 3); +``` +""" function enumerate_batches( collection::Collection; chunksize::Union{Int, Missing} = missing, nranks::Union{Int, Missing} = missing) @@ -42,3 +113,7 @@ function enumerate_batches( end batches end + +function Base.show(io::IO, collection::Collection) + print(io, "Collection at $(collection.path) with $(length(collection.data)) passages.") +end diff --git a/src/indexing/codecs/residual.jl b/src/indexing/codecs/residual.jl index 481f900..d73cc24 100644 --- a/src/indexing/codecs/residual.jl +++ b/src/indexing/codecs/residual.jl @@ -1,5 +1,24 @@ using .ColBERT: ColBERTConfig +""" + ResidualCodec(config::ColBERTConfig, centroids::Matrix{Float64}, avg_residual::Float64, bucket_cutoffs::Vector{Float64}, bucket_weights::Vector{Float64}) + +A struct that represents a compressor for ColBERT embeddings. + +It stores information about the configuration of the model, the centroids used to quantize the residuals, the average residual value, and the cutoffs and weights used to determine which buckets each residual belongs to. + +# Arguments + +- `config`: A [`ColBERTConfig`](@ref), representing all configuration parameters related to various ColBERT components. +- `centroids`: A matrix of centroids used to quantize the residuals. Has shape `(D, N)`, where `D` is the embedding dimension and `N` is the number of clusters. +- `avg_residual`: The average residual value. +- `bucket_cutoffs`: A vector of cutoff values used to determine which buckets each residual belongs to. +- `bucket_weights`: A vector of weights used to determine the importance of each bucket. + +# Returns + +A `ResidualCodec` object. +""" mutable struct ResidualCodec config::ColBERTConfig centroids::Matrix{Float64} @@ -8,6 +27,21 @@ mutable struct ResidualCodec bucket_weights::Vector{Float64} end +""" + compress_into_codes(codec::ResidualCodec, embs::Matrix{Float64}) + +Compresses a matrix of embeddings into a vector of codes using the given [`ResidualCodec`](@ref), where the code for each embedding is its nearest centroid ID. + +# Arguments + +- `codec`: The [`ResidualCodec`](@ref) used to compress the embeddings. +- `embs`: The matrix of embeddings to be compressed. + +# Returns + +A vector of codes, where each code corresponds to the nearest centroid ID for the embedding. +``` +""" function compress_into_codes(codec::ResidualCodec, embs::Matrix{Float64}) codes = [] @@ -20,9 +54,24 @@ function compress_into_codes(codec::ResidualCodec, embs::Matrix{Float64}) offset += bsize end + @assert length(codes) == size(embs)[2] codes end +""" + binarize(codec::ResidualCodec, residuals::Matrix{Float64}) + +Convert a matrix of residual vectors into a matrix of integer residual vector using `nbits` bits (specified by the underlying `config`). + +# Arguments + +- `codec`: A [`ResidualCodec`](@ref) object containing the compression information. +- `residuals`: The matrix of residuals to be converted. + +# Returns + +A matrix of compressed integer residual vectors. 
+""" function binarize(codec::ResidualCodec, residuals::Matrix{Float64}) dim = codec.config.doc_settings.dim nbits = codec.config.indexing_settings.nbits @@ -46,6 +95,22 @@ function binarize(codec::ResidualCodec, residuals::Matrix{Float64}) residuals_packed = reshape(residuals_packed, (Int(dim / 8) * nbits, num_embeddings)) # reshape back to get compressions for each embedding end +""" + compress(codec::ResidualCodec, embs::Matrix{Float64}) + +Compress a matrix of embeddings into a compact representation using the specified [`ResidualCodec`](@ref). + +All embeddings are compressed to their nearest centroid IDs and their quantized residual vectors (where the quantization is done in `nbits` bits, specified by the `config` of `codec`). If `emb` denotes an embedding and `centroid` is is nearest centroid, the residual vector is defined to be `emb - centroid`. + +# Arguments + +- `codec`: A [`ResidualCodec`](@ref) object containing the centroids and other parameters for the compression algorithm. +- `embs`: The input embeddings to be compressed. + +# Returns + +A tuple containing a vector of codes and the compressed residuals matrix. +""" function compress(codec::ResidualCodec, embs::Matrix{Float64}) codes, residuals = Vector{Int}(), Vector{Matrix{UInt8}}() @@ -65,6 +130,21 @@ function compress(codec::ResidualCodec, embs::Matrix{Float64}) codes, residuals end +""" + load_codes(codec::ResidualCodec, chunk_idx::Int) + +Load the codes from disk for a given chunk index. The codes are stored in the file `.codes.jld2` located inside the +`index_path` provided by the configuration. + +# Arguments + +- `codec`: The [`ResidualCodec`](@ref) object containing the compression information. +- `chunk_idx`: The chunk index for which the codes should be loaded. + +# Returns + +A vector of codes for the specified chunk. +""" function load_codes(codec::ResidualCodec, chunk_idx::Int) codes_path = joinpath(codec.config.indexing_settings.index_path, "$(chunk_idx).codes.jld2") codes = load(codes_path, "codes") diff --git a/src/indexing/collection_encoder.jl b/src/indexing/collection_encoder.jl index 733bebf..add9ea7 100644 --- a/src/indexing/collection_encoder.jl +++ b/src/indexing/collection_encoder.jl @@ -1,10 +1,44 @@ using ..ColBERT: ColBERTConfig +""" + CollectionEncoder(config::ColBERTConfig, checkpoint::Checkpoint) + +Structure to represent an encoder used to encode document passages to their corresponding embeddings. + +# Arguments + +- `config`: The underlying [`ColBERTConfig`](@ref). +- `checkpoint`: The [`Checkpoint`](@ref) used by the model. + +# Returns + +A [`CollectionEncoder`](@ref). + +""" struct CollectionEncoder config::ColBERTConfig checkpoint::Checkpoint end +""" + encode_passages(encoder::CollectionEncoder, passages::Vector{String}) + +Encode a list of passages using `encoder`. + +The given `passages` are run through the underlying BERT model and the linear layer to generate the embeddings, after doing relevant document-specific preprocessing. See [`docFromText`](@ref) for more details. + +# Arguments + +- `encoder`: The encoder used to encode the passages. +- `passages`: A list of strings representing the passages to be encoded. + +# Returns + +A tuple `embs, doclens` where: + +- `embs::Matrix{Float64}`: The full embedding matrix. Of shape `(D, N)`, where `D` is the embedding dimension and `N` is the total number of embeddings across all the passages. +- `doclens::Vector{Int}`: A vector of document lengths for each passage, i.e the total number of attended tokens for each document passage. 
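+
+# Examples
+
+A sketch of a typical call (here `encoder` is assumed to be a [`CollectionEncoder`](@ref) built from the `config` and `checkpoint` used in the other examples, with the default embedding dimension of `128`):
+
+```julia-repl
+julia> passages = ["hello world", "thank you!"];
+
+julia> embs, doclens = encode_passages(encoder, passages);   # size(embs) == (128, sum(doclens)), length(doclens) == 2
+```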
+""" function encode_passages(encoder::CollectionEncoder, passages::Vector{String}) @info "Encoding $(length(passages)) passages." diff --git a/src/indexing/collection_indexer.jl b/src/indexing/collection_indexer.jl index 1eb43eb..29c26f8 100644 --- a/src/indexing/collection_indexer.jl +++ b/src/indexing/collection_indexer.jl @@ -1,5 +1,19 @@ using .ColBERT: ColBERTConfig, CollectionEncoder, ResidualCodec +""" + CollectionIndexer(config::ColBERTConfig, encoder::CollectionEncoder, saver::IndexSaver) + +Structure which performs all the index-building operations, including sampling initial centroids, clustering, computing document embeddings, compressing and building the `ivf`. + +# Arguments +- `config`: The [`ColBERTConfig`](@ref) used to build the model. +- `encoder`: The [`CollectionEncoder`](@ref) to be used for encoding documents. +- `saver`: The [`IndexSaver`](@ref), responsible for saving the index to disk. + +# Returns + +A [`CollectionIndexer`](@ref) object, containing all indexing-related information. See the [`setup`](@ref), [`train`](@ref), [`index`](@ref) and [`finalize`](@ref) functions for building the index. +""" mutable struct CollectionIndexer config::ColBERTConfig encoder::CollectionEncoder @@ -35,6 +49,17 @@ function CollectionIndexer(config::ColBERTConfig, encoder::CollectionEncoder, sa ) end +""" + _sample_pids(indexer::CollectionIndexer) + +Sample PIDs from the collection to be used to compute clusters using a ``k``-means clustering algorithm. + +# Arguments +- `indexer`: The collection indexer object containing the collection of passages to be indexed. + +# Returns +A `Set` of `Int`s containing the sampled PIDs. +""" function _sample_pids(indexer::CollectionIndexer) num_passages = length(indexer.config.resource_settings.collection.data) typical_doclen = 120 @@ -46,6 +71,23 @@ function _sample_pids(indexer::CollectionIndexer) sampled_pids end +""" + _sample_embeddings(indexer::CollectionIndexer, sampled_pids::Set{Int}) + +Compute embeddings for the PIDs sampled by [`_sample_pids`](@ref), compute the average document length using the embeddings, and save the sampled embeddings to disk. + +The embeddings for the sampled documents are saved in a file named `sample.jld2` with it's path specified by the indexing directory. This embedding array has shape `(D, N)`, where `D` is the embedding dimension (`128`, after applying the linear layer of the ColBERT model) and `N` is the total number of embeddings over all documents. + +Sample the passages with `pid` in `sampled_pids` from the `collection` and compute the average passage length. The function returns a tuple containing the embedded passages and the average passage length. + +# Arguments +- `indexer`: An instance of `CollectionIndexer`. +- `sampled_pids`: Set of PIDs sampled by [`_sample_pids`](@ref). + +# Returns + +The average document length (i.e number of attended tokens) computed from the sampled documents. 
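+
+# Examples
+
+A sketch of the usual call sequence (assuming a fully constructed `indexer`; see [`CollectionIndexer`](@ref)):
+
+```julia-repl
+julia> sampled_pids = _sample_pids(indexer);
+
+julia> avg_doclen_est = _sample_embeddings(indexer, sampled_pids);   # also writes sample.jld2 to the indexing directory
+```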
+""" function _sample_embeddings(indexer::CollectionIndexer, sampled_pids::Set{Int}) # collect all passages with pids in sampled_pids collection = indexer.config.resource_settings.collection @@ -53,6 +95,9 @@ function _sample_embeddings(indexer::CollectionIndexer, sampled_pids::Set{Int}) local_sample = collection.data[sorted_sampled_pids] local_sample_embs, local_sample_doclens = encode_passages(indexer.encoder, local_sample) + @debug "Local sample embeddings shape: $(size(local_sample_embs)), \t Local sample doclens: $(local_sample_doclens)" + @assert size(local_sample_embs)[2] == sum(local_sample_doclens) + indexer.num_sample_embs = size(local_sample_embs)[2] indexer.avg_doclen_est = length(local_sample_doclens) > 0 ? sum(local_sample_doclens) / length(local_sample_doclens) : 0 @@ -64,6 +109,17 @@ function _sample_embeddings(indexer::CollectionIndexer, sampled_pids::Set{Int}) indexer.avg_doclen_est end +""" + _save_plan(indexer::CollectionIndexer) + +Save the indexing plan to a JSON file. + +Information about the number of chunks, number of clusters, estimated number of embeddings over all documents and the estimated average document length is saved to a file named `plan.json`, with directory specified by the indexing directory. + +# Arguments + +- `indexer`: The `CollectionIndexer` object that contains the index plan to be saved. +""" function _save_plan(indexer::CollectionIndexer) @info "Saving the index plan to $(indexer.plan_path)." # TODO: export the config as json as well @@ -80,6 +136,16 @@ function _save_plan(indexer::CollectionIndexer) end end +""" + setup(indexer::CollectionIndexer) + +Initialize `indexer` by computing some indexing-specific estimates and save the indexing plan to disk. + +The number of chunks into which the document embeddings will be stored (`indexer.num_chunks`) is simply computed using the number of documents and the size of a chunk obtained from [`get_chunksize`](@ref). A bunch of pids used for initializing the centroids for the embedding clusters are sampled using the [`_sample_pids`](@ref) and [`_sample_embeddings`](@ref) functions, and these samples are used to calculate the average document lengths and the estimated number of embeddings which will be computed across all documents. Finally, the number of clusters (`indexer.num_partitions`) to be used for indexing is computed, and is proportional to ``16\\sqrt{\\text{Estimated number of embeddings}}``, and the indexing plan is saved to `plan.json` (see [`_save_plan`](@ref)) in the indexing directory. + +# Arguments +- `indexer::CollectionIndexer`: The indexer to be initialized. +""" function setup(indexer::CollectionIndexer) collection = indexer.config.resource_settings.collection indexer.num_chunks = Int(ceil(length(collection.data) / get_chunksize(collection, indexer.config.run_settings.nranks))) @@ -99,10 +165,25 @@ function setup(indexer::CollectionIndexer) _save_plan(indexer) end +""" + _concatenate_and_split_sample(indexer::CollectionIndexer) + +Randomly shuffle and split the sampled embeddings. + +The sample embeddings saved by the [`setup`](@ref) function are loaded, shuffled randomly, and then split into a `sample` and a `sample_heldout` set, with `sample_heldout` containing a `0.05` fraction of the original sampled embeddings. + +# Arguments +- `indexer`: The [`CollectionIndexer`](@ref). + +# Returns + +The tuple `sample, sample_heldout`. 
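+
+# Examples
+
+A sketch (assuming [`setup`](@ref) has already been run on `indexer`, so that `sample.jld2` exists in the indexing directory):
+
+```julia-repl
+julia> sample, sample_heldout = _concatenate_and_split_sample(indexer);   # sample_heldout holds roughly 5% of the shuffled embeddings, capped at 50000
+```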
+""" function _concatenate_and_split_sample(indexer::CollectionIndexer) # load the sample embeddings sample_path = joinpath(indexer.config.indexing_settings.index_path, "sample.jld2") sample = load(sample_path, "local_sample_embs") + @debug "Original sample shape: $(size(sample))" # randomly shuffle embeddings num_local_sample_embs = size(sample)[2] @@ -112,9 +193,26 @@ function _concatenate_and_split_sample(indexer::CollectionIndexer) heldout_fraction = 0.05 heldout_size = Int(floor(min(50000, heldout_fraction * num_local_sample_embs))) sample, sample_heldout = sample[:, 1:(num_local_sample_embs - heldout_size)], sample[:, num_local_sample_embs - heldout_size + 1:num_local_sample_embs] + + @debug "Split sample sizes: sample size: $(size(sample)), \t sample_heldout size: $(size(sample_heldout))" sample, sample_heldout end +""" + _compute_avg_residuals(indexer::CollectionIndexer, centroids::Matrix{Float64}, heldout::Matrix{Float64}) + +Compute the average residuals and other statistics of the held-out sample embeddings. + +# Arguments + +- `indexer`: The underlying [`CollectionIndexer`](@ref). +- `centroids`: A matrix containing the centroids of the computed using a ``k``-means clustering algorithm on the sampled embeddings. Has shape `(D, indexer.num_partitions)`, where `D` is the embedding dimension (`128`) and `indexer.num_partitions` is the number of clusters. +- `heldout`: A matrix containing the held-out embeddings, computed using [`_concatenate_and_split_sample`](@ref). + +# Returns + +A tuple `bucket_cutoffs, bucket_weights, avg_residual`. +""" function _compute_avg_residuals(indexer::CollectionIndexer, centroids::Matrix{Float64}, heldout::Matrix{Float64}) compressor = ResidualCodec(indexer.config, centroids, 0.0, Vector{Float64}(), Vector{Float64}()) codes = compress_into_codes(compressor, heldout) # get centroid codes @@ -137,9 +235,22 @@ function _compute_avg_residuals(indexer::CollectionIndexer, centroids::Matrix{Fl bucket_cutoffs, bucket_weights, mean(avg_residual) end +""" + train(indexer::CollectionIndexer) + +Train a [`CollectionIndexer`](@ref) by computing centroids using a ``k``-means clustering algorithn, and store the compression information on disk. + +Average residuals and other compression data is computed via the [`_compute_avg_residuals`](@ref) function, and the codec is saved on disk using [`save_codec`](@ref). + +# Arguments + +- `indexer::CollectionIndexer`: The [`CollectionIndexer`](@ref) to be trained. +""" function train(indexer::CollectionIndexer) sample, heldout = _concatenate_and_split_sample(indexer) centroids = kmeans(sample, indexer.num_partitions, maxiter = indexer.config.indexing_settings.kmeans_niters, display = :iter).centers + @assert size(centroids)[2] == indexer.num_partitions + bucket_cutoffs, bucket_weights, avg_residual = _compute_avg_residuals(indexer, centroids, heldout) @info "avg_residual = $(avg_residual)" @@ -148,6 +259,18 @@ function train(indexer::CollectionIndexer) save_codec(indexer.saver) end +""" + index(indexer::CollectionIndexer; chunksize::Union{Int, Missing} = missing) + +Build the index using `indexer`. + +The documents are processed in batches of size `chunksize` (see [`enumerate_batches`](@ref)). Embeddings and document lengths are computed for each batch (see [`encode_passages`](@ref)), and they are saved to disk along with relevant metadata (see [`save_chunk`](@ref)). + +# Arguments + +- `indexer`: The [`CollectionIndexer`](@ref) used to build the index. +- `chunksize`: Size of a chunk into which the index is to be stored. 
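+
+# Examples
+
+A sketch of the typical calls (assuming `indexer` has already been through [`setup`](@ref) and [`train`](@ref), so that the codec can be loaded from disk):
+
+```julia-repl
+julia> index(indexer);   # chunksize is derived via get_chunksize
+
+julia> index(indexer; chunksize = 3);   # or use an explicit chunk size
+```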
+""" function index(indexer::CollectionIndexer; chunksize::Union{Int, Missing} = missing) load_codec!(indexer.saver) # load the codec objects batches = enumerate_batches(indexer.config.resource_settings.collection, chunksize = chunksize, nranks = indexer.config.run_settings.nranks) @@ -160,6 +283,16 @@ function index(indexer::CollectionIndexer; chunksize::Union{Int, Missing} = miss end end +""" + finalize(indexer::CollectionIndexer) + +Finalize the indexing process by saving all files, collecting embedding ID offsets, building IVF, and updating metadata. + +See [`_check_all_files_are_saved`](@ref), [`_collect_embedding_id_offset`](@ref), [`_build_ivf`](@ref) and [`_update_metadata`](@ref) for more details. + +# Arguments +- `indexer::CollectionIndexer`: The [`CollectionIndexer`](@ref) used to finalize the indexing process. +""" function finalize(indexer::CollectionIndexer) _check_all_files_are_saved(indexer) _collect_embedding_id_offset(indexer) diff --git a/src/indexing/index_saver.jl b/src/indexing/index_saver.jl index 2e80f4a..d994839 100644 --- a/src/indexing/index_saver.jl +++ b/src/indexing/index_saver.jl @@ -1,10 +1,31 @@ using .ColBERT: ResidualCodec +""" + IndexSaver(config::ColBERTConfig, codec::Union{Missing, ResidualCodec} = missing) + +A structure to load/save various indexing components. + +# Arguments + +- `config`: A [`ColBERTConfig`](@ref). +- `codec`: A codec to encode and decode the embeddings. +""" Base.@kwdef mutable struct IndexSaver config::ColBERTConfig codec::Union{Missing, ResidualCodec} = missing end +""" + load_codec!(saver::IndexSaver) + +Load a codec from disk into `saver`. + +The path of of the codec is inferred from the config stored in `saver`. + +# Arguments + +- `saver`: An [`IndexSaver`](@ref) into which the codec is to be loaded. +""" function load_codec!(saver::IndexSaver) index_path = saver.config.indexing_settings.index_path centroids = load(joinpath(index_path, "centroids.jld2"), "centroids") @@ -13,6 +34,22 @@ function load_codec!(saver::IndexSaver) saver.codec = ResidualCodec(saver.config, centroids, avg_residual, buckets["bucket_cutoffs"], buckets["bucket_weights"]) end +""" + save_codec(saver::IndexSaver) + +Save the codec used by the `saver` to disk. + +This will create three files in the directory specified by the indexing path: + - `centroids.jld2` containing the centroids. + - `avg_residual.jld2` containing the average residual. + - `buckets.jld2` containing the bucket cutoffs and weights. + +Also see [`train`](@ref). + +# Arguments + +- `saver::IndexSaver`: The index saver to use. +""" function save_codec(saver::IndexSaver) index_path = saver.config.indexing_settings.index_path centroids_path = joinpath(index_path, "centroids.jld2") @@ -31,10 +68,26 @@ function save_codec(saver::IndexSaver) ) end +""" + save_chunk(saver::IndexSaver, chunk_idx::Int, offset::Int, embs::Matrix{Float64}, doclens::Vector{Int}) + +Save a single chunk of compressed embeddings and their relevant metadata to disk. + +The codes and compressed residuals for the chunk are saved in files named `.codec.jld2`. The document lengths are saved in a file named `doclens..jld2`. Relevant metadata, including number of documents in the chunk, number of embeddings and the passage offsets are saved in a file named `.metadata.json`. + +# Arguments + +- `saver`: The [`IndexSaver`](@ref) containing relevant information to save the chunk. +- `chunk_idx`: The index of the current chunk being saved. +- `offset`: The offset in the original document collection where this chunk starts. 
+- `embs`: The embeddings matrix for the current chunk. +- `doclens`: The document lengths vector for the current chunk. +""" function save_chunk(saver::IndexSaver, chunk_idx::Int, offset::Int, embs::Matrix{Float64}, doclens::Vector{Int}) codes, residuals = compress(saver.codec, embs) path_prefix = joinpath(saver.config.indexing_settings.index_path, string(chunk_idx)) - + @assert length(codes) == size(embs)[2] + # saving the compressed embeddings codes_path = "$(path_prefix).codes.jld2" residuals_path = "$(path_prefix).residuals.jld2" @@ -62,6 +115,20 @@ function save_chunk(saver::IndexSaver, chunk_idx::Int, offset::Int, embs::Matrix end end +""" + check_chunk_exists(saver::IndexSaver, chunk_idx::Int) + +Check if the index chunk exists for the given `chunk_idx`. + +# Arguments + +- `saver`: The `IndexSaver` object that contains the indexing settings. +- `chunk_idx`: The index of the chunk to check. + +# Returns + +A boolean indicating whether all relevant files for the chunk exist. +""" function check_chunk_exists(saver::IndexSaver, chunk_idx::Int) index_path = saver.config.indexing_settings.index_path path_prefix = joinpath(index_path, string(chunk_idx)) diff --git a/src/infra/config.jl b/src/infra/config.jl index 58fbadf..ea36e67 100644 --- a/src/infra/config.jl +++ b/src/infra/config.jl @@ -1,3 +1,75 @@ +""" + ColBERTConfig(run_settings::RunSettings, tokenizer_settings::TokenizerSettings, resource_settings::ResourceSettings, doc_settings::DocSettings, query_settings::QuerySettings, indexing_settings::IndexingSettings, search_settings::SearchSettings) + +Structure containing config for running and training various components. + +# Arguments + +- `run_settings`: Sets the [`RunSettings`](@ref). +- `tokenizer_settings`: Sets the [`TokenizerSettings`](@ref). +- `resource_settings`: Sets the [`ResourceSettings`](@ref). +- `doc_settings`: Sets the [`DocSettings`](@ref). +- `query_settings`: Sets the [`QuerySettings`](@ref). +- `indexing_settings`: Sets the [`IndexingSettings`](@ref). +- `search_settings`: Sets the [`SearchSettings`](@ref). + +# Returns + +A [`ColBERTConfig`](@ref) object. + +# Examples + +The relevant files for this example can be found in the `examples/` folder of the project root. 
+ +```julia-repl +julia> dataroot = "downloads/lotte" + +julia> dataset = "lifestyle" + +julia> datasplit = "dev" + +julia> path = joinpath(dataroot, dataset, datasplit, "short_collection.tsv") + +julia> collection = Collection(path) + +julia> length(collection.data) + +julia> nbits = 2 # encode each dimension with 2 bits + +julia> doc_maxlen = 300 # truncate passages at 300 tokens + +julia> checkpoint = "colbert-ir/colbertv2.0" # the HF checkpoint + +julia> index_root = "experiments/notebook/indexes" + +julia> index_name = "short_\$(dataset).\$(datasplit).\$(nbits)bits" + +julia> index_path = joinpath(index_root, index_name) + +julia> config = ColBERTConfig( + RunSettings( + experiment="notebook", + ), + TokenizerSettings(), + ResourceSettings( + checkpoint=checkpoint, + collection=collection, + index_name=index_name, + ), + DocSettings( + doc_maxlen=doc_maxlen, + ), + QuerySettings(), + IndexingSettings( + index_path=index_path, + index_bsize=3, + nbits=nbits, + kmeans_niters=20, + ), + SearchSettings(), + ); +``` +""" Base.@kwdef struct ColBERTConfig run_settings::RunSettings tokenizer_settings::TokenizerSettings diff --git a/src/infra/settings.jl b/src/infra/settings.jl index 64669c2..b2954b8 100644 --- a/src/infra/settings.jl +++ b/src/infra/settings.jl @@ -1,5 +1,23 @@ using ..ColBERT: Collection +""" + RunSettings([root, experiment, index_root, name, rank, nranks]) + +Structure holding all the settings necessary to describe the run environment. + +# Arguments + +- `root`: The root directory for the run. Default is an `"experiments"` folder in the current working directory. +- `experiment`: The name of the run. Default is `"default"`. +- `index_root`: The root directory for storing index. For now, there is no need to specify this as it is determined by the indexing component. +- `name`: The name of the run. Default is the current date and time. +- `rank`: The index of the running GPU. Default is `0`. For now, the package only allows this to be `0`. +- `nranks`: The number of GPUs used in the run. Default is `1`. For now, the package only supports one GPU. + +# Returns + +A `RunSettings` object. +""" Base.@kwdef struct RunSettings root::String = joinpath(pwd(), "experiments") experiment::String = "default" @@ -9,6 +27,22 @@ Base.@kwdef struct RunSettings nranks::Int = 1 end +""" + TokenizerSettings([query_token_id, doc_token_id, query_token, doc_token]) + +Structure to represent settings for the tokenization of queries and documents. + +# Arguments + +- `query_token_id`: Unique identifier for query tokens (defaults to `[unused0]`). +- `doc_token_id`: Unique identifier for document tokens (defaults to `[unused1]`). +- `query_token`: Token used to represent a query token (defaults to `[Q]`). +- `doc_token`: Token used to represent a document token (defaults to `[D]`). + +# Returns + +A `TokenizerSettings` object. +""" Base.@kwdef struct TokenizerSettings query_token_id::String = "[unused0]" doc_token_id::String = "[unused1]" @@ -16,6 +50,22 @@ Base.@kwdef struct TokenizerSettings doc_token::String = "[D]" end +""" + ResourceSettings([checkpoint, collection, queries, index_name]) + +Structure to represent resource settings. + +# Arguments + +- `checkpoint`: The path to the HuggingFace checkpoint of the underlying ColBERT model. +- `collection`: The underlying collection of documents +- `queries`: The underlying collection of queries. +- `index_name`: The name of the index. + +# Returns + +A `ResourceSettings` object. 
+""" Base.@kwdef struct ResourceSettings checkpoint::Union{Nothing, String} = nothing collection::Union{Nothing, Collection} = nothing @@ -23,18 +73,63 @@ Base.@kwdef struct ResourceSettings index_name::Union{Nothing, String} = nothing end +""" + DocSettings([dim, doc_maxlen, mask_punctuation]) + +Structure that defines the settings used for generating document embeddings. + +# Arguments + +- `dim`: The dimension of the document embedding space. Default is 128. +- `doc_maxlen`: The maximum length of a document before it is trimmed to fit. Default is 220. +- `mask_punctuation`: Whether or not to mask punctuation characters tokens in the document. Default is true. + +# Returns + +A `DocSettings` object. +""" Base.@kwdef struct DocSettings dim::Int = 128 doc_maxlen::Int = 220 mask_punctuation::Bool = true end +""" + QuerySettings([query_maxlen, attend_to_mask_tokens, interaction]) + +A structure representing the query settings used by the ColBERT model. + +# Arguments + +- `query_maxlen`: The maximum length of queries after which they are trimmed. +- `attend_to_mask_tokens`: Whether or not to attend to mask tokens in the query. Default value is false. +- `interaction`: The type of interaction used to compute the scores for the queries. Default value is "colbert". + +# Returns + +A `QuerySettings` object. +""" Base.@kwdef struct QuerySettings query_maxlen::Int = 32 attend_to_mask_tokens::Bool = false interaction::String = "colbert" end +""" + IndexingSettings([index_path, index_bsize, nbits, kmeans_niters]) + +Structure containing settings for indexing. + +# Arguments +- `index_path`: Path to save the index files. +- `index_bsize::Int`: Batch size used for some parts of indexing. +- `nbits::Int`: Number of bits used to compress residuals. +- `kmeans_niters::Int`: Number of iterations used for k-means clustering. + +# Returns + +An `IndexingSettings` object. +""" Base.@kwdef struct IndexingSettings index_path::Union{Nothing, String} = nothing index_bsize::Int = 64 diff --git a/src/modelling/checkpoint.jl b/src/modelling/checkpoint.jl index 5a89c87..eeaddd2 100644 --- a/src/modelling/checkpoint.jl +++ b/src/modelling/checkpoint.jl @@ -1,5 +1,86 @@ using ..ColBERT: DocTokenizer, ColBERTConfig +""" + BaseColBERT(; bert::Transformers.HuggingFace.HGFBertModel, linear::Transformers.Layers.Dense, tokenizer::Transformers.TextEncoders.AbstractTransformerTextEncoder) + +A struct representing the BERT model, linear layer, and the tokenizer used to compute embeddings for documents and queries. + +# Arguments +- `bert`: The pre-trained BERT model used to generate the embeddings. +- `linear`: The linear layer used to project the embeddings to a specific dimension. +- `tokenizer`: The tokenizer to used by the BERT model. + +# Returns + +A [`BaseColBERT`](@ref) object. + +# Examples + +The `config` in the below example is taken from the example in [`ColBERTConfig`](@ref). 
+ +```julia-repl +julia> base_colbert = BaseColBERT(checkpoint, config); + +julia> base_colbert.bert +HGFBertModel( + Chain( + CompositeEmbedding( + token = Embed(768, 30522), # 23_440_896 parameters + position = ApplyEmbed(.+, FixedLenPositionEmbed(768, 512)), # 393_216 parameters + segment = ApplyEmbed(.+, Embed(768, 2), Transformers.HuggingFace.bert_ones_like), # 1_536 parameters + ), + DropoutLayer( + LayerNorm(768, ϵ = 1.0e-12), # 1_536 parameters + ), + ), + Transformer<12>( + PostNormTransformerBlock( + DropoutLayer( + SelfAttention( + MultiheadQKVAttenOp(head = 12, p = nothing), + Fork<3>(Dense(W = (768, 768), b = true)), # 1_771_776 parameters + Dense(W = (768, 768), b = true), # 590_592 parameters + ), + ), + LayerNorm(768, ϵ = 1.0e-12), # 1_536 parameters + DropoutLayer( + Chain( + Dense(σ = NNlib.gelu, W = (768, 3072), b = true), # 2_362_368 parameters + Dense(W = (3072, 768), b = true), # 2_360_064 parameters + ), + ), + LayerNorm(768, ϵ = 1.0e-12), # 1_536 parameters + ), + ), # Total: 192 arrays, 85_054_464 parameters, 324.477 MiB. + Branch{(:pooled,) = (:hidden_state,)}( + BertPooler(Dense(σ = NNlib.tanh_fast, W = (768, 768), b = true)), # 590_592 parameters + ), +) # Total: 199 arrays, 109_482_240 parameters, 417.664 MiB. + +julia> base_colbert.linear +Dense(W = (768, 128), b = true) # 98_432 parameters + +julia> base_colbert.tokenizer +BertTextEncoder( +├─ TextTokenizer(MatchTokenization(WordPieceTokenization(bert_uncased_tokenizer, WordPiece(vocab_size = 30522, unk = [UNK], max_char = 100)), 5 patterns)), +├─ vocab = Vocab{String, SizedArray}(size = 30522, unk = [UNK], unki = 101), +├─ startsym = [CLS], +├─ endsym = [SEP], +├─ padsym = [PAD], +├─ trunc = 512, +└─ process = Pipelines: + ╰─ target[token] := TextEncodeBase.nestedcall(string_getvalue, source) + ╰─ target[token] := Transformers.TextEncoders.grouping_sentence(target.token) + ╰─ target[(token, segment)] := SequenceTemplate{String}([CLS]: Input[1]: [SEP]: (Input[2]: [SEP]:)...)(target.token) + ╰─ target[attention_mask] := (NeuralAttentionlib.LengthMask ∘ Transformers.TextEncoders.getlengths(512))(target.token) + ╰─ target[token] := TextEncodeBase.trunc_and_pad(512, [PAD], tail, tail)(target.token) + ╰─ target[token] := TextEncodeBase.nested2batch(target.token) + ╰─ target[segment] := TextEncodeBase.trunc_and_pad(512, 1, tail, tail)(target.segment) + ╰─ target[segment] := TextEncodeBase.nested2batch(target.segment) + ╰─ target := (target.token, target.segment, target.attention_mask) + +``` +""" struct BaseColBERT bert::Transformers.HuggingFace.HGFBertModel linear::Transformers.Layers.Dense @@ -17,6 +98,62 @@ function BaseColBERT(checkpoint::String, config::ColBERTConfig) BaseColBERT(bert_model, linear, tokenizer) end +""" + Checkpoint(model::BaseColBERT, doc_tokenizer::DocTokenizer, colbert_config::ColBERTConfig) + +A wrapper for [`BaseColBERT`](@ref), which includes a [`ColBERTConfig`](@ref) and tokenization-specific functions via the [`DocTokenizer`](@ref) type. + +If the config's [`DocSettings`](@ref) are configured to mask punctuations, then the `skiplist` property of the created [`Checkpoint`](@ref) will be set to a list of token IDs of punctuations. + +# Arguments +- `model`: The [`BaseColBERT`](@ref) to be wrapped. +- `doc_tokenizer`: A [`DocTokenizer`](@ref) used for functions related to document tokenization. +- `colbert_config`: The underlying [`ColBERTConfig`](@ref). + +# Returns +The created [`Checkpoint`](@ref). 
+ +# Examples + +Continuing from the example for [`BaseColBERT`](@ref): + +```julia-repl +julia> checkPoint = Checkpoint(base_colbert, DocTokenizer(base_colbert.tokenizer, config), config); + +julia> checkPoint.skiplist # by default, all punctuations +32-element Vector{Int64}: + 1000 + 1001 + 1002 + 1003 + 1004 + 1005 + 1006 + 1007 + 1008 + 1009 + 1010 + 1011 + 1012 + 1013 + ⋮ + 1028 + 1029 + 1030 + 1031 + 1032 + 1033 + 1034 + 1035 + 1036 + 1037 + 1064 + 1065 + 1066 + 1067 + +``` +""" struct Checkpoint model::BaseColBERT doc_tokenizer::DocTokenizer @@ -34,6 +171,46 @@ function Checkpoint(model::BaseColBERT, doc_tokenizer::DocTokenizer, colbert_con Checkpoint(model, doc_tokenizer, colbert_config, skiplist) end +""" + mask_skiplist(tokenizer::Transformers.TextEncoders.AbstractTransformerTextEncoder, integer_ids::AbstractArray, skiplist::Union{Missing, Vector{Int}}) + +Create a mask for the given `integer_ids`, based on the provided `skiplist`. +If the `skiplist` is not missing, then any token IDs in the list will be filtered out along with the padding token. +Otherwise, all tokens are included in the mask. + +# Arguments + +- `tokenizer`: The underlying tokenizer. +- `integer_ids`: An `Array` of token IDs for the documents. +- `skiplist`: A list of token IDs to skip in the mask. + +# Returns +An array of booleans indicating whether the corresponding token ID is included in the mask or not. The array has the same shape as `integer_ids`, i.e `(L, N)`, where `L` is the maximum length of any document in `integer_ids` and `N` is the number of documents. + +# Examples + +Continuing with the example for [`tensorize`](@ref) and the `skiplist` from the example in [`Checkpoint`](@ref). + +```julia-repl +julia> mask_skiplist(checkPoint.model.tokenizer, integer_ids, checkPoint.skiplist) +14×4 BitMatrix: + 1 1 1 1 + 1 1 1 1 + 1 1 1 1 + 1 1 1 1 + 1 0 0 1 + 0 1 0 1 + 0 0 0 1 + 0 0 0 0 + 0 0 0 1 + 0 0 0 1 + 0 0 0 1 + 0 0 0 1 + 0 0 0 1 + 0 0 0 1 + +``` +""" function mask_skiplist(tokenizer::Transformers.TextEncoders.AbstractTransformerTextEncoder, integer_ids::AbstractArray, skiplist::Union{Missing, Vector{Int}}) if !ismissing(skiplist) filter = token_id -> !(token_id in skiplist) && token_id != TextEncodeBase.lookup(tokenizer.vocab, tokenizer.padsym) @@ -43,6 +220,47 @@ function mask_skiplist(tokenizer::Transformers.TextEncoders.AbstractTransformerT filter.(integer_ids) end +""" + doc(checkpoint::Checkpoint, integer_ids::AbstractArray, integer_mask::AbstractArray) + +Compute the hidden state of the BERT and linear layers of ColBERT. + +# Arguments + +- `checkpoint`: The [`Checkpoint`](@ref) containing the layers to compute the embeddings. +- `integer_ids`: An array of token IDs to be fed into the BERT model. +- `integer_mask`: An array of corresponding attention masks. Should have the same shape as `integer_ids`. + +# Returns + +A tuple `D, mask`, where: + +- `D` is an array containing the normalized embeddings for each token in each document. It has shape `(D, L, N)`, where `D` is the embedding dimension (`128` for the linear layer of ColBERT), and `(L, N)` is the shape of `integer_ids`, i.e `L` is the maximum length of any document and `N` is the total number of documents. +- `mask` is an array containing attention masks for all documents, after masking out any tokens in the `skiplist` of `checkpoint`. It has shape `(1, L, N)`, where `(L, N)` is the same as described above. 
+ +# Examples + +Continuing from the example in [`tensorize`](@ref) and [`Checkpoint`](@ref): + +```julia-repl +julia> D, mask = doc(checkPoint, integer_ids, integer_mask); + +julia> mask +1×14×4 BitArray{3}: +[:, :, 1] = + 1 1 1 1 1 0 0 0 0 0 0 0 0 0 + +[:, :, 2] = + 1 1 1 1 0 1 0 0 0 0 0 0 0 0 + +[:, :, 3] = + 1 1 1 1 0 0 0 0 0 0 0 0 0 0 + +[:, :, 4] = + 1 1 1 1 1 1 1 0 1 1 1 1 1 1 + +``` +""" function doc(checkpoint::Checkpoint, integer_ids::AbstractArray, integer_mask::AbstractArray) D = checkpoint.model.bert((token=integer_ids, attention_mask=NeuralAttentionlib.GenericSequenceMask(integer_mask))).hidden_state D = checkpoint.model.linear(D) @@ -55,10 +273,85 @@ function doc(checkpoint::Checkpoint, integer_ids::AbstractArray, integer_mask::A D, mask end +""" + docFromText(checkpoint::Checkpoint, docs::Vector{String}, bsize::Union{Missing, Int}) + +Get ColBERT embeddings for `docs` using `checkpoint`. + +This function also applies ColBERT-style document pre-processing for each document in `docs`. + +# Arguments + +- `checkpoint`: A [`Checkpoint`](@ref) to be used to compute embeddings. +- `docs`: A list of documents to get the embeddings for. +- `bsize`: A batch size for processing documents in batches. + +# Returns + +A tuple `embs, doclens`, where `embs` is an array of embeddings and `doclens` is a `Vector` of document lengths. The array `embs` has shape `(D, N)`, where `D` is the embedding dimension (`128` for ColBERT's linear layer) and `N` is the total number of embeddings across all documents in `docs`. + +# Examples + +Continuing from the example in [`Checkpoint`](@ref): + +```julia-repl +julia> docs = [ + "hello world", + "thank you!", + "a", + "this is some longer text, so length should be longer", +]; + +julia> embs, doclens = docFromText(checkPoint, docs, config.indexing_settings.index_bsize) +(Float32[0.07590997 0.00056472444 … -0.09958261 -0.03259005; 0.08413661 -0.016337946 … -0.061889287 -0.017708546; … ; -0.11584533 0.016651645 … 0.0073241345 0.09233974; 0.043868616 0.084660925 … -0.0294838 -0.08536169], [5 5 4 13]) + +julia> embs +128×27 Matrix{Float32}: + 0.07591 0.000564724 … -0.0811892 -0.0995826 -0.0325901 + 0.0841366 -0.0163379 -0.0118506 -0.0618893 -0.0177085 + -0.0301104 -0.0128125 0.0138397 -0.0573847 0.177861 + 0.0375673 0.216562 -0.110819 0.00425483 -0.00131543 + 0.0252677 0.151702 -0.0272065 0.0350983 -0.0381015 + 0.00608629 -0.0415363 … 0.122848 0.0747104 0.0836627 + -0.185256 -0.106582 0.0352982 -0.0405874 -0.064156 + -0.0816655 -0.142809 0.0565001 -0.134649 0.00380807 + 0.00471224 0.00444499 0.0112827 0.0253297 0.0665076 + -0.121564 -0.189994 0.0151938 -0.119054 -0.0980481 + 0.157599 0.0919844 … 0.0330667 0.0205288 0.0184296 + 0.0132481 -0.0430333 0.0404867 0.0575921 0.101702 + 0.0695787 0.0281928 -0.0378472 -0.053183 -0.123457 + -0.0933986 -0.0390347 0.0279156 0.0309749 0.00298161 + 0.0458561 0.0729707 0.103661 0.00905471 0.127777 + 0.00452597 0.05959 … 0.148845 0.0569492 0.293592 + ⋮ ⋱ ⋮ + 0.0510929 -0.138272 -0.00646483 -0.0171806 -0.0618908 + 0.128495 0.181198 -0.00408871 0.0274591 0.0343185 + -0.0961544 -0.0223997 0.0117907 -0.0813832 0.038232 + 0.0285498 0.0556695 … -0.0139291 -0.14533 -0.0176019 + 0.011212 -0.164717 0.071643 -0.0662124 0.164667 + -0.00178153 0.0600864 0.120243 0.0490749 0.0562548 + -0.0261783 0.0343851 0.0469064 0.040038 -0.0536367 + -0.0696538 -0.020624 0.0441996 0.0842775 0.0567261 + -0.0940356 -0.106123 … 0.00334512 0.00795235 -0.0439883 + 0.0567849 -0.0312434 -0.113022 0.0616158 -0.0738149 + -0.0143086 0.105833 -0.142671 
-0.0430241 -0.0831739 + 0.044704 0.0783603 -0.0413787 0.0315282 -0.171445 + 0.129225 0.112544 0.120684 0.107231 0.119762 + 0.000207455 -0.124472 … -0.0930788 -0.0519733 0.0837618 + -0.115845 0.0166516 0.0577464 0.00732413 0.0923397 + 0.0438686 0.0846609 -0.0967041 -0.0294838 -0.0853617 + +julia> doclens +1×4 Matrix{Int64}: + 5 5 4 13 + +``` +""" function docFromText(checkpoint::Checkpoint, docs::Vector{String}, bsize::Union{Missing, Int}) if ismissing(bsize) - integer_ids, integer_mask = tensorize(checkpoint.doc_tokenizer, checkpoint.model.tokenizer, docs, bsize) - doc(checkpoint, integer_ids, integer_mask) + # integer_ids, integer_mask = tensorize(checkpoint.doc_tokenizer, checkpoint.model.tokenizer, docs, bsize) + # doc(checkpoint, integer_ids, integer_mask) + error("Currently bsize cannot be missing!") else text_batches, reverse_indices = tensorize(checkpoint.doc_tokenizer, checkpoint.model.tokenizer, docs, bsize) batches = [doc(checkpoint, integer_ids, integer_mask) for (integer_ids, integer_mask) in text_batches] diff --git a/src/modelling/tokenization/doc_tokenization.jl b/src/modelling/tokenization/doc_tokenization.jl index f8e8f97..e4910e3 100644 --- a/src/modelling/tokenization/doc_tokenization.jl +++ b/src/modelling/tokenization/doc_tokenization.jl @@ -1,5 +1,19 @@ using ...ColBERT: ColBERTConfig +""" + DocTokenizer(tokenizer::Transformers.TextEncoders.AbstractTransformerTextEncoder, config::ColBERTConfig) + +Construct a `DocTokenizer` from a given tokenizer and configuration. The resulting structure supports functions to perform CoLBERT-style document operations on document texts. + +# Arguments + +- `tokenizer`: A tokenizer that has been trained on the BERT vocabulary. Fetched from HuggingFace. +- `config`: The underlying [`ColBERTConfig`](@ref). + +# Returns + +A `DocTokenizer` object. +""" struct DocTokenizer D_marker_token_id::Int config::ColBERTConfig @@ -10,16 +24,144 @@ function DocTokenizer(tokenizer::Transformers.TextEncoders.AbstractTransformerTe DocTokenizer(D_marker_token_id, config) end +""" + tensorize(doc_tokenizer::DocTokenizer, tokenizer::Transformers.TextEncoders.AbstractTransformerTextEncoder, batch_text::Vector{String}, bsize::Union{Missing, Int}) + +Convert a collection of documents to tensors in the ColBERT format. + +This function adds the document marker token at the beginning of each document and then converts the text data into integer IDs and masks using the `tokenizer`. The returned objects are determined by the `bsize` argument. More specifically: + +- If `bsize` is missing, then a tuple `integer_ids, integer_mask` is returned, where `integer_ids` is an `Array` of token IDs for the modified documents, and `integer_mask` is an `Array` of attention masks for each document. +- If `bsize` is not missing, then more optimizing operations are performed on the documents. First, the arrays of token IDs and attention masks are sorted by document lengths (this is for more efficient use of GPUs on the batches; see [`_sort_by_length`](@ref)), and a list `reverse_indices` is computed, which remembers the original order of the documents (to reorder them later). The arrays of token IDs and attention masks are then batched into batches of size `bsize` (see [`_split_into_batches`](@ref)). Finally, the batches along with the list of `reverse_indices` are returned. + +# Arguments + +- `doc_tokenizer`: An instance of the `DocTokenizer` type. This object contains information about the document marker token ID. 
+- `tokenizer`: The tokenizer which is used to convert text data into integer IDs. +- `batch_text`: A document texts that will be converted into tensors of token IDs. +- `bsize`: The size of the batches to split the `batch_text` into. Can also be `missing`. + +# Returns + +If `bsize` is `missing`, then a tuple is returned, which contains: + +- `integer_ids`: An `Array` of integer IDs representing the token IDs of the documents in the input collection. It has shape `(L, N)`, where `L` is the length of the largest document in `batch_text` (i.e the document with the largest number of tokens), and `N` is the number of documents in the batch. +- `integer_mask`: An `Array` of bits representing the attention mask for each document. It has shape `(L, N)`, the same as `integer_ids`. + +If `bsize` is not `missing`, then a tuple containing the following is returned: + +- `batches`: A `Vector` of tuples of arrays of token IDs and masks, sorted in the order of document lengths. Each array in each tuple has shape `(L, N)`, where `L` is the length of the largest document in `batch_text`, and `N` is the number of documents in the batch being considered. +- `reverse_indices`: A `Vector` containing the indices of the documents in their original order. + +# Examples + +```julia-repl +julia> base_colbert = BaseColBERT("colbert-ir/colbertv2.0", config); + +julia> tokenizer = base_colbert.tokenizer; + +julia> doc_tokenizer = DocTokenizer(tokenizer, config); + +julia> batch_text = [ + "hello world", + "thank you!", + "a", + "this is some longer text, so length should be longer", +]; + +julia> integer_ids, integer_mask = tensorize(doc_tokenizer, tokenizer, batch_text, missing); # no batching + +julia> integer_ids +14×4 reinterpret(Int32, ::Matrix{PrimitiveOneHot.OneHot{0x0000773a}}): + 102 102 102 102 + 3 3 3 3 + 7593 4068 1038 2024 + 2089 2018 103 2004 + 103 1000 1 2071 + 1 103 1 2937 + 1 1 1 3794 + 1 1 1 1011 + 1 1 1 2062 + 1 1 1 3092 + 1 1 1 2324 + 1 1 1 2023 + 1 1 1 2937 + 1 1 1 103 + +julia> integer_mask +14×4 Matrix{Bool}: + 1 1 1 1 + 1 1 1 1 + 1 1 1 1 + 1 1 1 1 + 1 1 0 1 + 0 1 0 1 + 0 0 0 1 + 0 0 0 1 + 0 0 0 1 + 0 0 0 1 + 0 0 0 1 + 0 0 0 1 + 0 0 0 1 + 0 0 0 1 + +julia> batch_text = [ + "hello world", + "thank you!", + "a", + "this is some longer text, so length should be longer", + "this is an even longer document. this is some longer text, so length should be longer", +]; + +julia> batches, reverse_indices = tensorize(doc_tokenizer, tokenizer, batch_text, 3) +2-element Vector{Tuple{AbstractArray, AbstractMatrix}}: + (Int32[102 102 102; 3 3 3; … ; 1 1 1; 1 1 1], Bool[1 1 1; 1 1 1; … ; 0 0 0; 0 0 0]) + (Int32[102 102; 3 3; … ; 1 2937; 1 103], Bool[1 1; 1 1; … ; 0 1; 0 1]) + +julia> batches[1][1] # this time they are sorted by length +21×3 Matrix{Int32}: + 102 102 102 + 3 3 3 + 1038 7593 4068 + 103 2089 2018 + 1 103 1000 + 1 1 103 + 1 1 1 + 1 1 1 + 1 1 1 + 1 1 1 + 1 1 1 + 1 1 1 + 1 1 1 + 1 1 1 + 1 1 1 + 1 1 1 + 1 1 1 + 1 1 1 + 1 1 1 + 1 1 1 + 1 1 1 + +julia> reverse_indices # the original order +5-element Vector{Int64}: + 2 + 3 + 1 + 4 + 5 + +``` +""" function tensorize(doc_tokenizer::DocTokenizer, tokenizer::Transformers.TextEncoders.AbstractTransformerTextEncoder, batch_text::Vector{String}, bsize::Union{Missing, Int}) # placeholder for [D] marker token batch_text = [". 
" * doc for doc in batch_text] - vocabsize = length(tokenizer.vocab.list) # getting the integer ids and masks encoded_text = Transformers.TextEncoders.encode(tokenizer, batch_text) ids, mask = encoded_text.token, encoded_text.attention_mask integer_ids = reinterpret(Int32, ids) integer_mask = NeuralAttentionlib.getmask(mask, ids)[1, :, :] + @assert isequal(size(integer_ids), size(integer_mask)) # adding the [D] marker token ID integer_ids[2, :] .= doc_tokenizer.D_marker_token_id @@ -30,18 +172,29 @@ function tensorize(doc_tokenizer::DocTokenizer, tokenizer::Transformers.TextEnco # we sort passages by length to do batch packing for more efficient use of the GPU integer_ids, integer_mask, reverse_indices = _sort_by_length(integer_ids, integer_mask, bsize) batches = _split_into_batches(integer_ids, integer_mask, bsize) + @assert length(reverse_indices) == length(batch_text) batches, reverse_indices end end """ - _sort_by_length(ids::AbstractMatrix, mask::AbstractMatrix, bsize::Int) + _sort_by_length(integer_ids::AbstractMatrix, integer_mask::AbstractMatrix, bsize::Int) -Sort sentences by number of attended tokens, if the number of sentences is larger than bsize. If the number of passages (first dimension of `ids`) is atmost -than `bsize`, the `ids`, `mask`, and a list `Vector(1:size(ids)[1])` is returned as a three-tuple. Otherwise, -the passages are first sorted by the number of attended tokens (figured out from `mask`), and then the the sorted arrays -`ids` and `mask` are returned, along with a reversed list of indices, i.e a mapping from passages to their indice in the sorted list. +Sort sentences by number of attended tokens, if the number of sentences is larger than `bsize`. + +# Arguments + +- `integer_ids`: The token IDs of documents to be sorted. +- `integer_mask`: The attention masks of the documents to be sorted (attention masks are just bits). +- `bsize`: The size of batches to be considered. + +# Returns + +Depending upon `bsize`, the following are returned: + +- If the number of documents (second dimension of `integer_ids`) is atmost `bsize`, then the `integer_ids` and `integer_mask` are returned unchanged. +- If the number of documents is larger than `bsize`, then the passages are first sorted by the number of attended tokens (figured out from the `integer_mask`), and then the sorted arrays `integer_ids`, `integer_mask` are returned, along with a list of `reverse_indices`, i.e a mapping from the documents to their indices in the original order. """ function _sort_by_length(integer_ids::AbstractMatrix, integer_mask::AbstractMatrix, bsize::Int) batch_size = size(integer_ids)[2] @@ -51,16 +204,25 @@ function _sort_by_length(integer_ids::AbstractMatrix, integer_mask::AbstractMatr end lengths = vec(sum(integer_mask; dims = 1)) # number of attended tokens in each passage - indices = sortperm(lengths) # get the indices which will sort lengths - reverse_indices = sortperm(indices) # invert the indices list + indices = sortperm(lengths) # get the indices which will sort lengths + reverse_indices = sortperm(indices) # invert the indices list integer_ids[:, indices], integer_mask[:, indices], reverse_indices end """ - _split_into_batches(integer_ids::AbstractArray, integer_mask::AbstractMatrix, bsize::Int)::Vector{Tuple{AbstractArray, AbstractMatrix, Int}} + _split_into_batches(integer_ids::AbstractArray, integer_mask::AbstractMatrix, bsize::Int) Split the given `integer_ids` and `integer_mask` into batches of size `bsize`. 
+ +# Arguments + +- `integer_ids`: The array of token IDs to batch. +- `integer_mask`: The array of attention masks to batch. + +# Returns + +Batches of token IDs and attention masks, with each batch having size `bsize` (with the possibility of the last batch being smaller). """ function _split_into_batches(integer_ids::AbstractArray, integer_mask::AbstractMatrix, bsize::Int) batch_size = size(integer_ids)[2] @@ -70,12 +232,3 @@ function _split_into_batches(integer_ids::AbstractArray, integer_mask::AbstractM end batches end - -# tokenizer = base_colbert.tokenizer -# batch_text = [ -# "hello world", -# "thank you!", -# "a", -# "this is some longer text, so length should be longer", -# ] -# bsize = 2 diff --git a/src/utils/utils.jl b/src/utils/utils.jl index aaee0ae..a2c9068 100644 --- a/src/utils/utils.jl +++ b/src/utils/utils.jl @@ -1,3 +1,18 @@ +""" + batch(group::Vector, bsize::Int; [provide_offset::Bool = false]) + +Create batches of data from `group`. + +Each batch is a subvector of `group` with length equal to `bsize`. If `provide_offset` is true, each batch will be a tuple containing both the offset and the subvector, otherwise only the subvector will be returned. + +# Arguments +- `group::Vector`: The input vector from which to create batches. +- `bsize::Int`: The size of each batch. +- `provide_offset::Bool = false`: Whether to include the offset in the output batches. Defaults to `false`. + +# Returns +A vector of tuples, where each tuple contains an offset and a subvector, or just a vector containing subvectors, depending on the value of `provide_offset`. +""" function batch(group::Vector, bsize::Int; provide_offset::Bool = false) vtype = provide_offset ? Vector{Tuple{Int, typeof(group)}} :