Adding documentation for the indexing components. #3

Merged · 23 commits · Jul 12, 2024
a6b2de2
Adding docstrings for `Collection`s and their functions.
codetalker7 Jun 25, 2024
06ec6f4
Adding docstrings for objects in `residual.jl`.
codetalker7 Jul 2, 2024
2045f80
Adding documentation for config settings. Skipping `SearchSettings` for
codetalker7 Jul 3, 2024
5b8e201
Adding docstring for the `ColBERTConfig`.
codetalker7 Jul 4, 2024
59f7406
Adding some docstrings for `checkpoint.jl`.
codetalker7 Jul 4, 2024
8acd7ca
Adding docstrings for objects in `doc_tokenization.jl`.
codetalker7 Jul 7, 2024
e0dccde
Adding docstring for `utils.jl`.
codetalker7 Jul 7, 2024
645eb3b
Adding docstrings for the `CollectionEncoder` and `CollectionIndexer`
codetalker7 Jul 10, 2024
1e24cf4
Adding a bunch of docstrings related to the `setup` function. Also
codetalker7 Jul 11, 2024
885bc5b
Some minor fixes.
codetalker7 Jul 11, 2024
beb2b38
Adding docstrings for some functions related to `train`, and adding
codetalker7 Jul 11, 2024
de029a2
Adding shape of the centroids matrix in the docstring for `ResidualCo…
codetalker7 Jul 11, 2024
123f022
Adding an assert in `compress_into_codes`.
codetalker7 Jul 11, 2024
df1d475
Adding docstrings for functions in `index_saver.jl`.
codetalker7 Jul 11, 2024
b5c3f7c
Adding docstrings for `collection_encoder.jl`.
codetalker7 Jul 11, 2024
da88bc8
Adding more documentation for `tensorize`, and adding a simple assert.
codetalker7 Jul 12, 2024
f33eedf
More documentation for `BaseColBERT` and `Checkpoint`.
codetalker7 Jul 12, 2024
ddc2358
More documentation for `mask_skiplist`.
codetalker7 Jul 12, 2024
1d3d9ee
Added a sanity check for `tensorize`.
codetalker7 Jul 12, 2024
5d8c70d
Adding an example in the docstring for `ColBERTConfig`.
codetalker7 Jul 12, 2024
b2a07aa
Adding a reference to the config used in the docstring of `BaseColBERT`.
codetalker7 Jul 12, 2024
3f72132
Adding docstrings for `doc` and `docFromText`.
codetalker7 Jul 12, 2024
78436bc
Adding documentation for `index` and `finalize`.
codetalker7 Jul 12, 2024
79 changes: 77 additions & 2 deletions src/data/collection.jl
@@ -1,5 +1,34 @@
# for now, we load collections in memory.
# will be good to implement on-disk data structures too.
# TODO: implement on-disk collections, and the case where pids are not necessarily sorted and can be arbitrary
"""
Collection(path::String)

A wrapper around a collection of documents, which stores the underlying collection as a `Vector{String}`.

# Arguments

- `path::String`: A path to the document dataset. It is assumed that `path` refers to a TSV file. Each line of the file should be of the form `pid \\t document`, where `pid` is the integer index of the document. `pid`s should be in the range ``[1, N]``, where ``N`` is the number of documents, and should be sorted.

# Examples

Here's an example which loads a small subset of the LoTTe dataset defined in `short_collection.tsv` (see the `examples` folder in the package).

```julia-repl
julia> using ColBERT;

julia> dataroot = "downloads/lotte";

julia> dataset = "lifestyle";

julia> datasplit = "dev";

julia> path = joinpath(dataroot, dataset, datasplit, "short_collection.tsv")
"downloads/lotte/lifestyle/dev/short_collection.tsv"

julia> collection = Collection(path)
Collection at downloads/lotte/lifestyle/dev/short_collection.tsv with 10 passages.
```
"""
struct Collection
path::String
data::Vector{String}
@@ -12,10 +41,52 @@ function Collection(path::String)
Collection(path, file.text)
end

"""
get_chunksize(collection::Collection, nranks::Int)

Determine the size of chunks used to store the index, based on the size of the `collection` and the number of available GPUs.

# Arguments

- `collection::Collection`: The underlying collection of documents.
- `nranks::Int`: Number of available GPUs to compute the index. At this point, the package only supports `nranks = 1`.

# Examples

Continuing from the example from the [`Collection`](@ref) constructor:

```julia-repl
julia> get_chunksize(collection, 1)
11
```
"""
function get_chunksize(collection::Collection, nranks::Int)
Int(min(25000, 1 + floor(length(collection.data) / nranks)))
end

"""
enumerate_batches(collection::Collection; [chunksize, nranks])

Batch the `collection` into chunks containing tuples of the form `(chunk_idx, offset, passages)`, where `chunk_idx` is the index of the chunk, `offset` is the index of the first passage in the chunk, and `passages` is a `Vector{String}` containing the passages in the chunk.

# Arguments

- `collection::Collection`: The collection to batch.
- `chunksize::Union{Int, Missing}`: The chunksize to use to batch the collection. Default `missing`. If this is `missing`, then `chunksize` is determined using [`get_chunksize`](@ref) based on the `collection` and `nranks`.
- `nranks::Union{Int, Missing}`: The number of available GPUs. Default `missing`. Currently the package only supports `nranks = 1`.

The `collection` is batched into chunks of uniform size (with the last chunk potentially having a smaller size).

# Examples

Continuing from the example in the [`Collection`](@ref) constructor.

```julia-repl
julia> enumerate_batches(collection; nranks = 1);

julia> enumerate_batches(collection; chunksize = 3);
```
"""
function enumerate_batches(
collection::Collection; chunksize::Union{Int, Missing} = missing,
nranks::Union{Int, Missing} = missing)
@@ -42,3 +113,7 @@ function enumerate_batches(
end
batches
end
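The chunking scheme described in the docstring (uniform chunks, with a smaller final chunk) can be sketched as a minimal standalone illustration; this is not the package's actual implementation, just the idea behind it:

```julia
# Split a vector of passages into uniform chunks, recording for each chunk
# a tuple (chunk_idx, offset, passages), as the docstring above describes.
function sketch_enumerate_batches(data::Vector{String}, chunksize::Int)
    batches = Tuple{Int, Int, Vector{String}}[]
    for (chunk_idx, offset) in enumerate(1:chunksize:length(data))
        stop = min(offset + chunksize - 1, length(data))
        push!(batches, (chunk_idx, offset, data[offset:stop]))
    end
    batches
end

docs = ["doc$(i)" for i in 1:10]
batches = sketch_enumerate_batches(docs, 3)
# 4 chunks: three of size 3, and a final chunk holding the single leftover passage
```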

function Base.show(io::IO, collection::Collection)
print(io, "Collection at $(collection.path) with $(length(collection.data)) passages.")
end
80 changes: 80 additions & 0 deletions src/indexing/codecs/residual.jl
@@ -1,5 +1,24 @@
using .ColBERT: ColBERTConfig

"""
ResidualCodec(config::ColBERTConfig, centroids::Matrix{Float64}, avg_residual::Float64, bucket_cutoffs::Vector{Float64}, bucket_weights::Vector{Float64})

A struct that represents a compressor for ColBERT embeddings.

It stores information about the configuration of the model, the centroids used to quantize the residuals, the average residual value, and the cutoffs and weights used to determine which buckets each residual belongs to.

# Arguments

- `config`: A [`ColBERTConfig`](@ref), representing all configuration parameters related to various ColBERT components.
- `centroids`: A matrix of centroids used to quantize the residuals. Has shape `(D, N)`, where `D` is the embedding dimension and `N` is the number of clusters.
- `avg_residual`: The average residual value.
- `bucket_cutoffs`: A vector of cutoff values used to determine which buckets each residual belongs to.
- `bucket_weights`: A vector of weights used to determine the importance of each bucket.

# Returns

A `ResidualCodec` object.
"""
mutable struct ResidualCodec
config::ColBERTConfig
centroids::Matrix{Float64}
@@ -8,6 +27,21 @@ mutable struct ResidualCodec
bucket_weights::Vector{Float64}
end

"""
compress_into_codes(codec::ResidualCodec, embs::Matrix{Float64})

Compresses a matrix of embeddings into a vector of codes using the given [`ResidualCodec`](@ref), where the code for each embedding is its nearest centroid ID.

# Arguments

- `codec`: The [`ResidualCodec`](@ref) used to compress the embeddings.
- `embs`: The matrix of embeddings to be compressed.

# Returns

A vector of codes, where each code corresponds to the nearest centroid ID for the embedding.
"""
function compress_into_codes(codec::ResidualCodec, embs::Matrix{Float64})
codes = []

@@ -20,9 +54,24 @@ function compress_into_codes(codec::ResidualCodec, embs::Matrix{Float64})
offset += bsize
end

@assert length(codes) == size(embs)[2]
codes
end
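The nearest-centroid assignment performed by `compress_into_codes` can be sketched independently of the codec machinery. This is a minimal illustration assuming L2-normalized embeddings and centroids (so the inner product is cosine similarity); the batched implementation in the package is more involved:

```julia
# Sketch of nearest-centroid assignment: for each embedding column, pick the
# centroid with the largest inner product against it.
function sketch_compress_into_codes(centroids::Matrix{Float64}, embs::Matrix{Float64})
    scores = centroids' * embs                       # shape (N, num_embeddings)
    [argmax(scores[:, i]) for i in 1:size(embs, 2)]
end

centroids = [1.0 0.0; 0.0 1.0]    # two unit centroids, one along each axis
embs = [0.9 0.1; 0.1 0.9]         # two embeddings, one close to each centroid
codes = sketch_compress_into_codes(centroids, embs)
# codes == [1, 2]; note length(codes) == size(embs)[2], matching the assert above
```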

"""
binarize(codec::ResidualCodec, residuals::Matrix{Float64})

Convert a matrix of residual vectors into a matrix of integer residual vectors, using `nbits` bits (specified by the underlying `config`).

# Arguments

- `codec`: A [`ResidualCodec`](@ref) object containing the compression information.
- `residuals`: The matrix of residuals to be converted.

# Returns

A matrix of compressed integer residual vectors.
"""
function binarize(codec::ResidualCodec, residuals::Matrix{Float64})
dim = codec.config.doc_settings.dim
nbits = codec.config.indexing_settings.nbits
@@ -46,6 +95,22 @@ function binarize(codec::ResidualCodec, residuals::Matrix{Float64})
residuals_packed = reshape(residuals_packed, (Int(dim / 8) * nbits, num_embeddings)) # reshape back to get compressions for each embedding
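The first step behind `binarize` is mapping each residual value to a bucket index via the codec's `bucket_cutoffs`; the bucket ids are then packed `8 ÷ nbits` to a byte. The bucket-assignment step can be sketched as follows (whether the package uses `searchsortedfirst` for this is an assumption; the bucketing idea is what matters):

```julia
# Map each residual value to a bucket id in 0:(2^nbits - 1), where
# nbits = log2(length(bucket_cutoffs) + 1): values below the first cutoff go
# to bucket 0, values above the last cutoff go to the highest bucket.
function sketch_bucket_indices(residuals::Vector{Float64}, bucket_cutoffs::Vector{Float64})
    [searchsortedfirst(bucket_cutoffs, r) - 1 for r in residuals]
end

cutoffs = [-0.1, 0.0, 0.1]    # 3 cutoffs → 4 buckets → nbits = 2
ids = sketch_bucket_indices([-0.5, -0.05, 0.05, 0.5], cutoffs)
# ids == [0, 1, 2, 3]
```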
end

"""
compress(codec::ResidualCodec, embs::Matrix{Float64})

Compress a matrix of embeddings into a compact representation using the specified [`ResidualCodec`](@ref).

All embeddings are compressed to their nearest centroid IDs and their quantized residual vectors (where the quantization is done in `nbits` bits, specified by the `config` of `codec`). If `emb` denotes an embedding and `centroid` is its nearest centroid, the residual vector is defined to be `emb - centroid`.

# Arguments

- `codec`: A [`ResidualCodec`](@ref) object containing the centroids and other parameters for the compression algorithm.
- `embs`: The input embeddings to be compressed.

# Returns

A tuple containing a vector of codes and the compressed residuals matrix.
"""
function compress(codec::ResidualCodec, embs::Matrix{Float64})
codes, residuals = Vector{Int}(), Vector{Matrix{UInt8}}()

@@ -65,6 +130,21 @@ function compress(codec::ResidualCodec, embs::Matrix{Float64})
codes, residuals
end
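The residual computation described in the docstring can be shown in a tiny standalone sketch (values chosen to be exact in binary floating point; not the package's implementation):

```julia
# The code for an embedding is its nearest centroid's id, and the residual is
# the difference `emb - centroid`; `compress` then quantizes this residual.
centroids = [1.0 0.0; 0.0 1.0]         # two unit centroids along the axes
emb = [0.75, 0.25]
code = argmax(centroids' * emb)        # nearest centroid id (1 here)
residual = emb - centroids[:, code]
# residual == [-0.25, 0.25]
```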

"""
load_codes(codec::ResidualCodec, chunk_idx::Int)

Load the codes from disk for a given chunk index. The codes are stored in the file `<chunk_idx>.codes.jld2` located inside the
`index_path` provided by the configuration.

# Arguments

- `codec`: The [`ResidualCodec`](@ref) object containing the compression information.
- `chunk_idx`: The chunk index for which the codes should be loaded.

# Returns

A vector of codes for the specified chunk.
"""
function load_codes(codec::ResidualCodec, chunk_idx::Int)
codes_path = joinpath(codec.config.indexing_settings.index_path, "$(chunk_idx).codes.jld2")
codes = load(codes_path, "codes")
34 changes: 34 additions & 0 deletions src/indexing/collection_encoder.jl
@@ -1,10 +1,44 @@
using ..ColBERT: ColBERTConfig

"""
CollectionEncoder(config::ColBERTConfig, checkpoint::Checkpoint)

Structure to represent an encoder used to encode document passages to their corresponding embeddings.

# Arguments

- `config`: The underlying [`ColBERTConfig`](@ref).
- `checkpoint`: The [`Checkpoint`](@ref) used by the model.

# Returns

A [`CollectionEncoder`](@ref).

"""
struct CollectionEncoder
config::ColBERTConfig
checkpoint::Checkpoint
end

"""
encode_passages(encoder::CollectionEncoder, passages::Vector{String})

Encode a list of passages using `encoder`.

The given `passages` are run through the underlying BERT model and the linear layer to generate the embeddings, after doing relevant document-specific preprocessing. See [`docFromText`](@ref) for more details.

# Arguments

- `encoder`: The encoder used to encode the passages.
- `passages`: A list of strings representing the passages to be encoded.

# Returns

A tuple `embs, doclens` where:

- `embs::Matrix{Float64}`: The full embedding matrix. Of shape `(D, N)`, where `D` is the embedding dimension and `N` is the total number of embeddings across all the passages.
- `doclens::Vector{Int}`: A vector of document lengths for each passage, i.e. the total number of attended tokens for each document passage.
"""
function encode_passages(encoder::CollectionEncoder, passages::Vector{String})
@info "Encoding $(length(passages)) passages."

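The `(embs, doclens)` return layout of `encode_passages` can be illustrated with a small standalone sketch (the dimensions below are hypothetical, not taken from the package):

```julia
# Per-passage embedding matrices are concatenated column-wise into `embs`,
# while `doclens` records how many columns (attended tokens) each passage got.
per_passage = [rand(128, 5), rand(128, 3)]   # hypothetical: 2 passages, dim 128
embs = hcat(per_passage...)
doclens = [size(E, 2) for E in per_passage]
# size(embs) == (128, 8); doclens == [5, 3]
```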