Skip to content

Commit

Permalink
Add Pyserini onnx encoder support (#2113)
Browse files Browse the repository at this point in the history
Updated SimpleImpactSearcher so that Pyserini can search on the fly with an ONNX encoder by setting the "--onnx-encoder" flag
  • Loading branch information
ArthurChen189 authored May 9, 2023
1 parent 4aeb3ef commit 0009387
Show file tree
Hide file tree
Showing 10 changed files with 125 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ protected TokenStreamComponents createComponents(String fieldName) {

/**
* Tokenizes a String Object
* @param reader
* @return
* @throws IOException
* @param reader reader over the String to tokenize
* @return reader over the tokenized String
* @throws IOException if an I/O error occurs while reading the input
*/
public Reader tokenizeReader(Reader reader) throws IOException {
String targetString = IOUtils.toString(reader);
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/io/anserini/eval/RelevanceJudgments.java
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ private static String getCacheDir() {
*
* @param qrelsPath path to qrels file
* @return qrels file as a string
* @throws IOException
* @throws IOException if qrels file is not found
*/
public static String getQrelsResource(Path qrelsPath) throws IOException {
Path resultPath = qrelsPath;
Expand Down Expand Up @@ -181,7 +181,7 @@ public static Path getNewQrelAbsPath(Path qrelsPath) {
*
* @param qrelsPath path to qrels file
* @return path to qrels file
* @throws IOException
* @throws IOException if qrels file is not found
*/
public static Path downloadQrels(Path qrelsPath) throws IOException {
String qrelsURL = CLOUD_PATH + qrelsPath.getFileName().toString().toString();
Expand Down
43 changes: 40 additions & 3 deletions src/main/java/io/anserini/search/SimpleImpactSearcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import io.anserini.rerank.ScoredDocuments;
import io.anserini.rerank.lib.ScoreTiesAdjusterReranker;
import io.anserini.search.query.BagOfWordsQueryGenerator;
import io.anserini.search.query.QueryEncoder;
import io.anserini.search.similarity.ImpactSimilarity;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
Expand All @@ -38,11 +39,14 @@
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.store.FSDirectory;

import ai.onnxruntime.OrtException;

import java.io.Closeable;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletionException;
Expand All @@ -66,7 +70,7 @@ public class SimpleImpactSearcher implements Closeable {
protected RerankerCascade cascade;
protected IndexSearcher searcher = null;
protected boolean backwardsCompatibilityLucene8;

private QueryEncoder queryEncoder = null;
/**
* This class is meant to serve as the bridge between Anserini and Pyserini.
* Note that we are adopting Python naming conventions here on purpose.
Expand Down Expand Up @@ -102,7 +106,7 @@ public SimpleImpactSearcher(String indexDir) throws IOException {
Path indexPath = Paths.get(indexDir);

if (!Files.exists(indexPath) || !Files.isDirectory(indexPath) || !Files.isReadable(indexPath)) {
throw new IllegalArgumentException(indexDir + " does not exist or is not a directory.");
throw new IOException(indexDir + " does not exist or is not a directory.");
}

this.reader = DirectoryReader.open(FSDirectory.open(indexPath));
Expand All @@ -119,6 +123,27 @@ public SimpleImpactSearcher(String indexDir) throws IOException {
cascade.add(new ScoreTiesAdjusterReranker());
}

/**
 * Sets the ONNX query encoder, loading it reflectively by name.
 * Has no effect if an encoder has already been set.
 *
 * @param encoder short encoder name, e.g., "SpladePlusPlusEnsembleDistil";
 *                resolved to io.anserini.search.query.&lt;encoder&gt;QueryEncoder
 * @throws RuntimeException if the encoder class cannot be found or instantiated
 */
public void set_onnx_query_encoder(String encoder) {
  if (empty_encoder()) {
    try {
      this.queryEncoder = (QueryEncoder) Class.forName("io.anserini.search.query." + encoder + "QueryEncoder")
          .getConstructor().newInstance();
    } catch (Exception e) {
      // Include the requested encoder name so misconfiguration is diagnosable.
      throw new RuntimeException("Unable to load ONNX query encoder: " + encoder, e);
    }
  }
}

// Returns true if no ONNX query encoder has been set yet.
private boolean empty_encoder(){
return this.queryEncoder == null;
}


/**
* Returns the number of documents in the index.
*
Expand Down Expand Up @@ -204,6 +229,19 @@ public Map<String, Result[]> batch_search(List<Map<String, Float>> queries,
return results;
}

/**
 * Encodes the query using the ONNX encoder.
 *
 * @param queryString query string
 * @return encoded query as a map from token to weight
 * @throws OrtException if errors encountered during encoding
 * @throws IllegalStateException if no ONNX encoder has been set
 */
public Map<String, Float> encode_with_onnx(String queryString) throws OrtException {
  // Guard against a clear misuse: encoding before an encoder was configured
  // previously failed with an uninformative NullPointerException.
  if (this.queryEncoder == null) {
    throw new IllegalStateException("No ONNX encoder set; call set_onnx_query_encoder() first.");
  }
  return this.queryEncoder.getTokenWeightMap(queryString);
}


/**
* Searches the collection, returning 10 hits by default.
*
Expand All @@ -225,7 +263,6 @@ public Result[] search(Map<String, Float> q) throws IOException {
*/
/**
 * Searches the collection with an already-encoded query, returning the top k hits.
 *
 * @param q encoded query, a map from token to weight
 * @param k number of hits to return
 * @return array of search results
 * @throws IOException if error encountered during search
 */
public Result[] search(Map<String, Float> q, int k) throws IOException {
  return _search(generator.buildQuery(Constants.CONTENTS, q), k);
}

Expand Down
2 changes: 2 additions & 0 deletions src/main/java/io/anserini/search/SimpleSearcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ public void set_rocchio(String collectionClass) {
* @param beta weight to assign to the relevant document vectors
* @param gamma weight to assign to the nonrelevant document vectors
* @param outputQuery flag to print original and expanded queries
* @param useNegative flag to use negative feedback
*/
public void set_rocchio(String collectionClass, int topFbTerms, int topFbDocs, int bottomFbTerms, int bottomFbDocs, float alpha, float beta, float gamma, boolean outputQuery, boolean useNegative) {
Class clazz = null;
Expand Down Expand Up @@ -797,6 +798,7 @@ public Document doc(String docid) {
* Batch version of {@link #doc(String)}.
*
* @param docids list of docids
* @param threads number of threads to use
* @return a map of docid to corresponding Lucene {@link Document}
*/
public Map<String, Document> batch_get_docs(List<String> docids, int threads) {
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/io/anserini/search/query/QueryEncoder.java
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,6 @@ static Map<String, Float> getTokenWeightMap(long[] indexes, float[] computedWeig
return tokenWeightMap;
}

public abstract Map<String, Float> getTokenWeightMap(String query) throws OrtException;

}
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,12 @@ public SpladePlusPlusEnsembleDistilQueryEncoder() throws IOException, OrtExcepti

@Override
public String encode(String query) throws OrtException {
  // Compute per-token weights, then render them in the encoded-query format.
  return generateEncodedQuery(getTokenWeightMap(query));
}

@Override
public Map<String, Float> getTokenWeightMap(String query) throws OrtException {
List<String> queryTokens = new ArrayList<>();
queryTokens.add("[CLS]");
queryTokens.addAll(tokenizer.tokenize(query));
Expand All @@ -56,19 +61,20 @@ public String encode(String query) throws OrtException {
long[][] attentionMask = new long[1][queryTokenIds.length];
long[][] tokenTypeIds = new long[1][queryTokenIds.length];

// initialize attention mask with all 1s
// initialize attention mask with all 1s
Arrays.fill(attentionMask[0], 1);
inputs.put("input_ids", OnnxTensor.createTensor(environment, inputTokenIds));
inputs.put("token_type_ids", OnnxTensor.createTensor(environment, tokenTypeIds));
inputs.put("attention_mask", OnnxTensor.createTensor(environment, attentionMask));

Map<String, Float> tokenWeightMap = null;
try (OrtSession.Result results = session.run(inputs)) {
long[] indexes = (long[]) results.get("output_idx").get().getValue();
float[] weights = (float[]) results.get("output_weights").get().getValue();
Map<String, Float> tokenWeightMap = getTokenWeightMap(indexes, weights, vocab);
encodedQuery = generateEncodedQuery(tokenWeightMap);
tokenWeightMap = getTokenWeightMap(indexes, weights, vocab);
} catch (OrtException e) {
e.printStackTrace();
}
return encodedQuery;
return tokenWeightMap;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,12 @@ public SpladePlusPlusSelfDistilQueryEncoder() throws IOException, OrtException {

@Override
public String encode(String query) throws OrtException {
  // Compute per-token weights, then render them in the encoded-query format.
  return generateEncodedQuery(getTokenWeightMap(query));
}

@Override
public Map<String, Float> getTokenWeightMap(String query) throws OrtException {
List<String> queryTokens = new ArrayList<>();
queryTokens.add("[CLS]");
queryTokens.addAll(tokenizer.tokenize(query));
Expand All @@ -60,14 +65,15 @@ public String encode(String query) throws OrtException {
inputs.put("input_ids", OnnxTensor.createTensor(environment, inputTokenIds));
inputs.put("token_type_ids", OnnxTensor.createTensor(environment, tokenTypeIds));
inputs.put("attention_mask", OnnxTensor.createTensor(environment, attentionMask));

Map<String, Float> tokenWeightMap = null;
try (OrtSession.Result results = session.run(inputs)) {
long[] indexes = (long[]) results.get("output_idx").get().getValue();
float[] weights = (float[]) results.get("output_weights").get().getValue();
Map<String, Float> tokenWeightMap = getTokenWeightMap(indexes, weights, vocab);
encodedQuery = generateEncodedQuery(tokenWeightMap);
tokenWeightMap = getTokenWeightMap(indexes, weights, vocab);
} catch (OrtException e) {
e.printStackTrace();
}
return encodedQuery;
return tokenWeightMap;
}

}
43 changes: 26 additions & 17 deletions src/main/java/io/anserini/search/query/UniCoilQueryEncoder.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,22 +41,8 @@ public UniCoilQueryEncoder() throws IOException, OrtException {
@Override
public String encode(String query) throws OrtException {
String encodedQuery = "";
List<String> queryTokens = new ArrayList<>();
queryTokens.add("[CLS]");
queryTokens.addAll(tokenizer.tokenize(query));
queryTokens.add("[SEP]");

Map<String, OnnxTensor> inputs = new HashMap<>();
long[] queryTokenIds = convertTokensToIds(tokenizer, queryTokens, vocab);
long[][] inputTokenIds = new long[1][queryTokenIds.length];
inputTokenIds[0] = queryTokenIds;
inputs.put("inputIds", OnnxTensor.createTensor(environment, inputTokenIds));

try (OrtSession.Result results = session.run(inputs)) {
float[] computedWeights = flatten(results.get(0).getValue());
Map<String, Float> tokenWeightMap = getTokenWeightMap(queryTokens, computedWeights);
encodedQuery = generateEncodedQuery(tokenWeightMap);
}
Map<String, Float> tokenWeightMap = getTokenWeightMap(query);
encodedQuery = generateEncodedQuery(tokenWeightMap);
return encodedQuery;
}

Expand All @@ -81,7 +67,7 @@ private float[] toArray(List<Float> input) {
return output;
}

Map<String, Float> getTokenWeightMap(List<String> tokens, float[] computedWeights) {
private Map<String, Float> getTokenWeightMap(List<String> tokens, float[] computedWeights) {
Map<String, Float> tokenWeightMap = new LinkedHashMap<>();
for (int i = 0; i < tokens.size(); ++i) {
String token = tokens.get(i);
Expand All @@ -100,4 +86,27 @@ Map<String, Float> getTokenWeightMap(List<String> tokens, float[] computedWeight
return tokenWeightMap;
}

/**
 * Computes the token weight map for a query by running the ONNX model.
 *
 * @param query query string
 * @return map from token to model-assigned weight
 * @throws OrtException if errors encountered during inference
 */
@Override
public Map<String, Float> getTokenWeightMap(String query) throws OrtException {
  // Build [CLS] tokens... [SEP], per BERT input convention.
  List<String> queryTokens = new ArrayList<>();
  queryTokens.add("[CLS]");
  queryTokens.addAll(tokenizer.tokenize(query));
  queryTokens.add("[SEP]");

  long[] queryTokenIds = convertTokensToIds(tokenizer, queryTokens, vocab);
  long[][] inputTokenIds = new long[1][queryTokenIds.length];
  inputTokenIds[0] = queryTokenIds;

  Map<String, OnnxTensor> inputs = new HashMap<>();
  inputs.put("inputIds", OnnxTensor.createTensor(environment, inputTokenIds));

  // Let OrtException propagate as declared rather than swallowing it and
  // returning null, which would fail later with an opaque NullPointerException.
  try (OrtSession.Result results = session.run(inputs)) {
    float[] computedWeights = flatten(results.get(0).getValue());
    return getTokenWeightMap(queryTokens, computedWeights);
  } finally {
    // Release native tensor memory backing the inputs.
    for (OnnxTensor tensor : inputs.values()) {
      tensor.close();
    }
  }
}

}
13 changes: 7 additions & 6 deletions src/main/java/io/anserini/search/topicreader/TopicReader.java
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ public SortedMap<K, Map<String, String>> read(String str) throws IOException {
*
* @param topics topics
* @param <K> type of topic id
* @throws IOException if error encountered reading topics
* @return evaluation topics
*/
@SuppressWarnings("unchecked")
Expand Down Expand Up @@ -171,7 +172,7 @@ public static <K> SortedMap<K, Map<String, String>> getTopicsByFile(String file)
*
* @param topics topics
* @return evaluation topics, with strings as topic ids
* @throws IOException
* @throws IOException if error encountered reading topics
*/
public static Map<String, Map<String, String>> getTopicsWithStringIds(Topics topics) throws IOException {
SortedMap<?, Map<String, String>> originalTopics = getTopics(topics);
Expand Down Expand Up @@ -226,9 +227,9 @@ private static String getCacheDir() {

/**
* Downloads the topics from the cloud and returns the path to the local copy
* @param topicPath
* @param topicPath Path to the topics
* @return Path to the local copy of the topics
* @throws IOException
* @throws IOException if error encountered downloading topics
*/
public static Path getTopicsFromCloud(Path topicPath) throws IOException{
String topicURL = CLOUD_PATH + topicPath.getFileName().toString();
Expand All @@ -249,9 +250,9 @@ public static Path getNewTopicsAbsPath(Path topicPath){

/**
* Returns the path to the topic file. If the topic file is not in the list of known topics, we assume it is a local file.
* @param topicPath
* @return
* @throws IOException
* @param topicPath Path to the topic file
* @return Path to the topic file
* @throws IOException if error encountered reading topics
*/
public static Path getTopicPath(Path topicPath) throws IOException{
if (Files.exists(topicPath)) {
Expand Down
20 changes: 20 additions & 0 deletions src/test/java/io/anserini/search/SimpleImpactSearcherTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@
import java.util.Map;

public class SimpleImpactSearcherTest extends IndexerTestBase {

// Expected token weights produced by the SpladePlusPlusEnsembleDistil ONNX
// encoder for the query "here is a test".
private static Map<String, Float> EXPECTED_ENCODED_QUERY = new HashMap<>();

static {
EXPECTED_ENCODED_QUERY.put("here", 3.05345f);
EXPECTED_ENCODED_QUERY.put("a", 0.59636426f);
EXPECTED_ENCODED_QUERY.put("test", 2.9012794f);
}

@Test
public void testGetDoc() throws Exception {
SimpleImpactSearcher searcher = new SimpleImpactSearcher(super.tempDir1.toString());
Expand Down Expand Up @@ -195,4 +204,15 @@ public void testTotalNumDocuments() throws Exception {
SimpleImpactSearcher searcher = new SimpleImpactSearcher(super.tempDir1.toString());
assertEquals(3 ,searcher.get_total_num_docs());
}

// Verifies that a query encoded on the fly with the ONNX encoder matches
// the precomputed token weights.
@Test
public void testOnnxEncoder() throws Exception {
  SimpleImpactSearcher searcher = new SimpleImpactSearcher();
  searcher.set_onnx_query_encoder("SpladePlusPlusEnsembleDistil");

  Map<String, Float> encodedQuery = searcher.encode_with_onnx("here is a test");
  for (Map.Entry<String, Float> expected : EXPECTED_ENCODED_QUERY.entrySet()) {
    assertEquals(encodedQuery.get(expected.getKey()), expected.getValue());
  }
}
}

0 comments on commit 0009387

Please sign in to comment.