-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathfaiss_embedding_writer.rs
131 lines (97 loc) · 3.97 KB
/
faiss_embedding_writer.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
use rustserini::encode::auto::AutoDocumentEncoder;
use rustserini::encode::base::{DocumentEncoder, RepresentationWriter};
use rustserini::encode::vector_writer::{JsonlCollectionIterator, FaissRepresentationWriter};
use std::collections::HashMap;
use std::time::Instant;
use clap::{ArgAction, Parser};
/// Simple program to encode a corpus and store the embeddings in a Faiss index
/// Download the msmarco passage dataset using the below command:
/// mkdir corpus/msmarco-passage
/// wget https://huggingface.co/datasets/Tevatron/msmarco-passage-corpus/resolve/main/corpus.jsonl.gz -P corpus/msmarco-passage
/// cargo run --example faiss_embedding_writer -- --corpus corpus/msmarco-passage/corpus.jsonl.gz --embeddings-dir corpus/msmarco-passage --encoder bert-base-uncased --tokenizer bert-base-uncased
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    // NOTE(review): `main` in this file reads only `corpus`, `fields`, `delimiter`,
    // `batch_size`, `embeddings_dir`, `embedding_dim`, `encoder` and `revision`; the
    // remaining options are parsed but not used there.
    /// Directory that contains corpus files to be encoded, in jsonl format.
    #[arg(short, long)]
    corpus: String,
    /// Fields that contents in jsonl has (in order) separated by comma.
    #[arg(short, long, default_value = "text")]
    fields: String,
    /// delimiter for the fields
    #[arg(short, long, default_value = "\n")]
    delimiter: String,
    /// shard-id 0-based
    #[arg(short, long, default_value_t = 0)]
    shard_id: u8,
    /// number of shards
    #[arg(long, default_value_t = 1)]
    shard_num: u8,
    /// directory to store encoded corpus
    #[arg(short, long, required = true)]
    embeddings_dir: String,
    /// Whether to store the embeddings in a faiss index or in a jsonl file
    // `ArgAction::SetFalse` makes this field default to `true`; passing `--to-faiss`
    // on the command line sets it to `false`.
    // NOTE(review): that reads inverted for a flag with this name - confirm the intent.
    #[arg(long, action=ArgAction::SetFalse)]
    to_faiss: bool,
    /// Encoder name or path
    #[arg(long)]
    encoder: String,
    /// Encoder Revision
    #[arg(long, default_value = "main")]
    revision: String,
    /// Tokenizer name or path
    #[arg(long)]
    tokenizer: String,
    /// Batch size for encoding
    #[arg(short, long, default_value_t = 4)]
    batch_size: usize,
    /// GPU Device ==> cpu or cuda:0
    #[arg(long, default_value = "cpu")]
    device: String,
    /// Whether to use fp16
    #[arg(long, action=ArgAction::SetTrue)]
    fp16: bool,
    /// max length of the input
    #[arg(short, long, default_value_t = 512)]
    max_length: u16,
    /// Embedding dimension
    #[arg(long, default_value_t = 768)]
    embedding_dim: u32,
}
/// Returns a copy of `s` with every double-quote (`"`) and backslash (`\`)
/// character removed; all other characters are kept in their original order.
fn sanitize_string(s: &str) -> String {
    // One pass over the characters instead of two chained `replace` calls.
    s.chars()
        .filter(|&ch| ch != '"' && ch != '\\')
        .collect()
}
/// Entry point: encodes the corpus batch by batch and stores the embeddings
/// through a `FaissRepresentationWriter`.
///
/// Steps:
/// 1. Parse the command-line arguments.
/// 2. Load the corpus into a `JsonlCollectionIterator` (document id field `"id"`).
/// 3. Encode the `"text"` entries of every batch with `AutoDocumentEncoder`,
///    flatten the result into a `Vec<f32>` and hand it to the writer together
///    with the sanitized texts and ids.
/// 4. Save the index and the document ids.
///
/// # Errors
/// Returns an error as soon as loading the corpus, opening the output,
/// encoding a batch, writing a batch, or saving the index / doc ids fails.
fn main() -> anyhow::Result<()> {
    let start = Instant::now();
    let args = Args::parse();

    let fields: Vec<String> = args.fields.split(',').map(|s| s.to_string()).collect();
    let mut iterator: JsonlCollectionIterator =
        JsonlCollectionIterator::new(fields, "id".to_string(), args.delimiter, args.batch_size);
    // A corpus that cannot be loaded must abort the run instead of silently
    // producing an empty index.
    iterator
        .load(args.corpus)
        .map_err(|e| anyhow::anyhow!("failed to load the corpus: {:?}", e))?;

    println!("Initialize a representation writer and open a file to store the embeddings");
    let mut writer = FaissRepresentationWriter::new(&args.embeddings_dir, args.embedding_dim);
    writer
        .open_file()
        .map_err(|e| anyhow::anyhow!("failed to open the embeddings output: {:?}", e))?;

    let encoder = AutoDocumentEncoder::new(&args.encoder, &args.revision);

    let mut counter: usize = 0;
    // Every batch is processed; the loop previously stopped after the first
    // batch because of an unconditional `break`.
    for batch in iterator.iter() {
        // NOTE(review): the "text" key is hard-coded, which matches the default
        // of `--fields`; other `--fields` values are not honoured here - confirm.
        let batch_text: Vec<String> = batch["text"].iter().map(|x| sanitize_string(x)).collect();
        let batch_id: Vec<String> = batch["id"].iter().map(|x| sanitize_string(x)).collect();

        // "cls" is presumably the pooling mode - verify against the encoder API.
        let embeddings = &encoder.encode(&batch_text, None, "cls")?;
        let mut embeddings: Vec<f32> = embeddings.flatten_all()?.to_vec1::<f32>()?;

        let mut batch_info = HashMap::new();
        batch_info.insert("text", batch_text);
        batch_info.insert("id", batch_id);
        // A failed write would leave the index incomplete, so it is reported
        // instead of being ignored.
        writer
            .write(&batch_info, &mut embeddings)
            .map_err(|e| anyhow::anyhow!("failed to write batch {}: {:?}", counter + 1, e))?;

        counter += 1;
        println!("Batch {} encoded", counter);
    }

    writer.save_index()?;
    writer.save_docids()?;

    let duration = start.elapsed();
    println!("Time elapsed in expensive_function() is: {:?}", duration);
    Ok(())
}