From 0009387359e9c9c29b7fee4658f7f5d8c107968e Mon Sep 17 00:00:00 2001 From: Arthur Chen <36494787+ArthurChen189@users.noreply.github.com> Date: Tue, 9 May 2023 08:54:38 -0400 Subject: [PATCH] Add Pyserini onnx encoder support (#2113) Updated SimpleImpactSearcher so that Pyserini can search on the fly with an ONNX encoder by setting the "--onnx-encoder" flag --- .../HuggingFaceTokenizerAnalyzer.java | 6 +-- .../io/anserini/eval/RelevanceJudgments.java | 4 +- .../anserini/search/SimpleImpactSearcher.java | 43 +++++++++++++++++-- .../io/anserini/search/SimpleSearcher.java | 2 + .../anserini/search/query/QueryEncoder.java | 2 + ...adePlusPlusEnsembleDistilQueryEncoder.java | 18 +++++--- .../SpladePlusPlusSelfDistilQueryEncoder.java | 16 ++++--- .../search/query/UniCoilQueryEncoder.java | 43 +++++++++++-------- .../search/topicreader/TopicReader.java | 13 +++--- .../search/SimpleImpactSearcherTest.java | 20 +++++++++ 10 files changed, 125 insertions(+), 42 deletions(-) diff --git a/src/main/java/io/anserini/analysis/HuggingFaceTokenizerAnalyzer.java b/src/main/java/io/anserini/analysis/HuggingFaceTokenizerAnalyzer.java index 9a09b2b499..2d7da42eda 100644 --- a/src/main/java/io/anserini/analysis/HuggingFaceTokenizerAnalyzer.java +++ b/src/main/java/io/anserini/analysis/HuggingFaceTokenizerAnalyzer.java @@ -68,9 +68,9 @@ protected TokenStreamComponents createComponents(String fieldName) { /** * Tokenizes a String Object - * @param reader - * @return - * @throws IOException + * @param reader String Object + * @return Reader + * @throws IOException IOException */ public Reader tokenizeReader(Reader reader) throws IOException { String targetString = IOUtils.toString(reader); diff --git a/src/main/java/io/anserini/eval/RelevanceJudgments.java b/src/main/java/io/anserini/eval/RelevanceJudgments.java index 4928a38010..1dce504d31 100644 --- a/src/main/java/io/anserini/eval/RelevanceJudgments.java +++ b/src/main/java/io/anserini/eval/RelevanceJudgments.java @@ -129,7 +129,7 @@ private static String getCacheDir() { * * @param qrelsPath path to qrels file * @return qrels file as a string - * @throws IOException + * @throws IOException if qrels file is not found */ public static String getQrelsResource(Path qrelsPath) throws IOException { Path resultPath = qrelsPath; @@ -181,7 +181,7 @@ public static Path getNewQrelAbsPath(Path qrelsPath) { * * @param qrelsPath path to qrels file * @return path to qrels file - * @throws IOException + * @throws IOException if qrels file is not found */ public static Path downloadQrels(Path qrelsPath) throws IOException { String qrelsURL = CLOUD_PATH + qrelsPath.getFileName().toString().toString(); diff --git a/src/main/java/io/anserini/search/SimpleImpactSearcher.java b/src/main/java/io/anserini/search/SimpleImpactSearcher.java index d3c136c515..9f353ccfd7 100644 --- a/src/main/java/io/anserini/search/SimpleImpactSearcher.java +++ b/src/main/java/io/anserini/search/SimpleImpactSearcher.java @@ -23,6 +23,7 @@ import io.anserini.rerank.ScoredDocuments; import io.anserini.rerank.lib.ScoreTiesAdjusterReranker; import io.anserini.search.query.BagOfWordsQueryGenerator; +import io.anserini.search.query.QueryEncoder; import io.anserini.search.similarity.ImpactSimilarity; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -38,11 +39,14 @@ import org.apache.lucene.search.similarities.Similarity; import org.apache.lucene.store.FSDirectory; +import ai.onnxruntime.OrtException; + import java.io.Closeable; import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.CompletionException; @@ -66,7 +70,7 @@ public class SimpleImpactSearcher implements Closeable { protected RerankerCascade cascade; protected IndexSearcher searcher = null; protected boolean backwardsCompatibilityLucene8; - + private QueryEncoder queryEncoder = null; /** * This class is meant to serve as the bridge between Anserini and Pyserini. * Note that we are adopting Python naming conventions here on purpose. @@ -102,7 +106,7 @@ public SimpleImpactSearcher(String indexDir) throws IOException { Path indexPath = Paths.get(indexDir); if (!Files.exists(indexPath) || !Files.isDirectory(indexPath) || !Files.isReadable(indexPath)) { - throw new IllegalArgumentException(indexDir + " does not exist or is not a directory."); + throw new IOException(indexDir + " does not exist or is not a directory."); } this.reader = DirectoryReader.open(FSDirectory.open(indexPath)); @@ -119,6 +123,27 @@ public SimpleImpactSearcher(String indexDir) throws IOException { cascade.add(new ScoreTiesAdjusterReranker()); } + /** + * Sets the query encoder + * + * @param encoder the query encoder + */ + public void set_onnx_query_encoder(String encoder) { + if (empty_encoder()) { + try { + this.queryEncoder = (QueryEncoder) Class.forName("io.anserini.search.query." + encoder + "QueryEncoder") + .getConstructor().newInstance(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + } + + private boolean empty_encoder(){ + return this.queryEncoder == null; + } + + /** * Returns the number of documents in the index. * @@ -204,6 +229,19 @@ public Map batch_search(List> queries, return results; } + /** + * Encodes the query using the onnx encoder + * + * @param queryString query string + * @throws OrtException if errors encountered during encoding + * @return encoded query + */ + public Map encode_with_onnx(String queryString) throws OrtException { + Map encodedQ = this.queryEncoder.getTokenWeightMap(queryString); + return encodedQ; + } + + /** * Searches the collection, returning 10 hits by default. * @@ -225,7 +263,6 @@ public Result[] search(Map q) throws IOException { */ public Result[] search(Map q, int k) throws IOException { Query query = generator.buildQuery(Constants.CONTENTS, q); - return _search(query, k); } diff --git a/src/main/java/io/anserini/search/SimpleSearcher.java b/src/main/java/io/anserini/search/SimpleSearcher.java index ddbc4dfacf..fc81d4c641 100644 --- a/src/main/java/io/anserini/search/SimpleSearcher.java +++ b/src/main/java/io/anserini/search/SimpleSearcher.java @@ -390,6 +390,7 @@ public void set_rocchio(String collectionClass) { * @param beta weight to assign to the relevant document vectors * @param gamma weight to assign to the nonrelevant document vectors * @param outputQuery flag to print original and expanded queries + * @param useNegative flag to use negative feedback */ public void set_rocchio(String collectionClass, int topFbTerms, int topFbDocs, int bottomFbTerms, int bottomFbDocs, float alpha, float beta, float gamma, boolean outputQuery, boolean useNegative) { Class clazz = null; @@ -797,6 +798,7 @@ public Document doc(String docid) { * Batch version of {@link #doc(String)}. * * @param docids list of docids + * @param threads number of threads to use * @return a map of docid to corresponding Lucene {@link Document} */ public Map batch_get_docs(List docids, int threads) { diff --git a/src/main/java/io/anserini/search/query/QueryEncoder.java b/src/main/java/io/anserini/search/query/QueryEncoder.java index fabf84e099..8b91c60aae 100644 --- a/src/main/java/io/anserini/search/query/QueryEncoder.java +++ b/src/main/java/io/anserini/search/query/QueryEncoder.java @@ -87,4 +87,6 @@ static Map getTokenWeightMap(long[] indexes, float[] computedWeig return tokenWeightMap; } + public abstract Map getTokenWeightMap(String query) throws OrtException; + } \ No newline at end of file diff --git a/src/main/java/io/anserini/search/query/SpladePlusPlusEnsembleDistilQueryEncoder.java b/src/main/java/io/anserini/search/query/SpladePlusPlusEnsembleDistilQueryEncoder.java index 4a5641490d..d4afd70bff 100644 --- a/src/main/java/io/anserini/search/query/SpladePlusPlusEnsembleDistilQueryEncoder.java +++ b/src/main/java/io/anserini/search/query/SpladePlusPlusEnsembleDistilQueryEncoder.java @@ -42,7 +42,12 @@ public SpladePlusPlusEnsembleDistilQueryEncoder() throws IOException, OrtExcepti @Override public String encode(String query) throws OrtException { - String encodedQuery = ""; + Map tokenWeightMap = getTokenWeightMap(query); + return generateEncodedQuery(tokenWeightMap); + } + + @Override + public Map getTokenWeightMap(String query) throws OrtException { List queryTokens = new ArrayList<>(); queryTokens.add("[CLS]"); queryTokens.addAll(tokenizer.tokenize(query)); @@ -56,19 +61,20 @@ public String encode(String query) throws OrtException { long[][] attentionMask = new long[1][queryTokenIds.length]; long[][] tokenTypeIds = new long[1][queryTokenIds.length]; - // initialize attention mask with all 1s + // initialize attention mask with all 1s Arrays.fill(attentionMask[0], 1); inputs.put("input_ids", OnnxTensor.createTensor(environment, inputTokenIds)); inputs.put("token_type_ids", OnnxTensor.createTensor(environment, tokenTypeIds)); inputs.put("attention_mask", OnnxTensor.createTensor(environment, attentionMask)); - + Map tokenWeightMap = null; try (OrtSession.Result results = session.run(inputs)) { long[] indexes = (long[]) results.get("output_idx").get().getValue(); float[] weights = (float[]) results.get("output_weights").get().getValue(); - Map tokenWeightMap = getTokenWeightMap(indexes, weights, vocab); - encodedQuery = generateEncodedQuery(tokenWeightMap); + tokenWeightMap = getTokenWeightMap(indexes, weights, vocab); + } catch (OrtException e) { + e.printStackTrace(); } - return encodedQuery; + return tokenWeightMap; } } \ No newline at end of file diff --git a/src/main/java/io/anserini/search/query/SpladePlusPlusSelfDistilQueryEncoder.java b/src/main/java/io/anserini/search/query/SpladePlusPlusSelfDistilQueryEncoder.java index 0bf5a233ab..d8d8ba48aa 100644 --- a/src/main/java/io/anserini/search/query/SpladePlusPlusSelfDistilQueryEncoder.java +++ b/src/main/java/io/anserini/search/query/SpladePlusPlusSelfDistilQueryEncoder.java @@ -41,7 +41,12 @@ public SpladePlusPlusSelfDistilQueryEncoder() throws IOException, OrtException { @Override public String encode(String query) throws OrtException { - String encodedQuery = ""; + Map tokenWeightMap = getTokenWeightMap(query); + return generateEncodedQuery(tokenWeightMap); + } + + @Override + public Map getTokenWeightMap(String query) throws OrtException { List queryTokens = new ArrayList<>(); queryTokens.add("[CLS]"); queryTokens.addAll(tokenizer.tokenize(query)); @@ -60,14 +65,15 @@ public String encode(String query) throws OrtException { inputs.put("input_ids", OnnxTensor.createTensor(environment, inputTokenIds)); inputs.put("token_type_ids", OnnxTensor.createTensor(environment, tokenTypeIds)); inputs.put("attention_mask", OnnxTensor.createTensor(environment, attentionMask)); - + Map tokenWeightMap = null; try (OrtSession.Result results = session.run(inputs)) { long[] indexes = (long[]) results.get("output_idx").get().getValue(); float[] weights = (float[]) results.get("output_weights").get().getValue(); - Map tokenWeightMap = getTokenWeightMap(indexes, weights, vocab); - encodedQuery = generateEncodedQuery(tokenWeightMap); + tokenWeightMap = getTokenWeightMap(indexes, weights, vocab); + } catch (OrtException e) { + e.printStackTrace(); } - return encodedQuery; + return tokenWeightMap; } } \ No newline at end of file diff --git a/src/main/java/io/anserini/search/query/UniCoilQueryEncoder.java b/src/main/java/io/anserini/search/query/UniCoilQueryEncoder.java index 78a7fe67aa..95bf44d874 100644 --- a/src/main/java/io/anserini/search/query/UniCoilQueryEncoder.java +++ b/src/main/java/io/anserini/search/query/UniCoilQueryEncoder.java @@ -41,22 +41,8 @@ public UniCoilQueryEncoder() throws IOException, OrtException { @Override public String encode(String query) throws OrtException { String encodedQuery = ""; - List queryTokens = new ArrayList<>(); - queryTokens.add("[CLS]"); - queryTokens.addAll(tokenizer.tokenize(query)); - queryTokens.add("[SEP]"); - - Map inputs = new HashMap<>(); - long[] queryTokenIds = convertTokensToIds(tokenizer, queryTokens, vocab); - long[][] inputTokenIds = new long[1][queryTokenIds.length]; - inputTokenIds[0] = queryTokenIds; - inputs.put("inputIds", OnnxTensor.createTensor(environment, inputTokenIds)); - - try (OrtSession.Result results = session.run(inputs)) { - float[] computedWeights = flatten(results.get(0).getValue()); - Map tokenWeightMap = getTokenWeightMap(queryTokens, computedWeights); - encodedQuery = generateEncodedQuery(tokenWeightMap); - } + Map tokenWeightMap = getTokenWeightMap(query); + encodedQuery = generateEncodedQuery(tokenWeightMap); return encodedQuery; } @@ -81,7 +67,7 @@ private float[] toArray(List input) { return output; } - Map getTokenWeightMap(List tokens, float[] computedWeights) { + private Map getTokenWeightMap(List tokens, float[] computedWeights) { Map tokenWeightMap = new LinkedHashMap<>(); for (int i = 0; i < tokens.size(); ++i) { String token = tokens.get(i); @@ -100,4 +86,27 @@ Map getTokenWeightMap(List tokens, float[] computedWeight return tokenWeightMap; } + @Override + public Map getTokenWeightMap(String query) throws OrtException { + List queryTokens = new ArrayList<>(); + queryTokens.add("[CLS]"); + queryTokens.addAll(tokenizer.tokenize(query)); + queryTokens.add("[SEP]"); + + Map inputs = new HashMap<>(); + long[] queryTokenIds = convertTokensToIds(tokenizer, queryTokens, vocab); + long[][] inputTokenIds = new long[1][queryTokenIds.length]; + inputTokenIds[0] = queryTokenIds; + inputs.put("inputIds", OnnxTensor.createTensor(environment, inputTokenIds)); + + Map tokenWeightMap = null; + try (OrtSession.Result results = session.run(inputs)) { + float[] computedWeights = flatten(results.get(0).getValue()); + tokenWeightMap = getTokenWeightMap(queryTokens, computedWeights); + } catch (OrtException e) { + e.printStackTrace(); + } + return tokenWeightMap; + } + } \ No newline at end of file diff --git a/src/main/java/io/anserini/search/topicreader/TopicReader.java b/src/main/java/io/anserini/search/topicreader/TopicReader.java index 072ebc099c..49359092dd 100755 --- a/src/main/java/io/anserini/search/topicreader/TopicReader.java +++ b/src/main/java/io/anserini/search/topicreader/TopicReader.java @@ -116,6 +116,7 @@ public SortedMap> read(String str) throws IOException { * * @param topics topics * @param type of topic id + * @throws IOException if error encountered reading topics * @return evaluation topics */ @SuppressWarnings("unchecked") @@ -171,7 +172,7 @@ public static SortedMap> getTopicsByFile(String file) * * @param topics topics * @return evaluation topics, with strings as topic ids - * @throws IOException + * @throws IOException if error encountered reading topics */ public static Map> getTopicsWithStringIds(Topics topics) throws IOException { SortedMap> originalTopics = getTopics(topics); @@ -226,9 +227,9 @@ private static String getCacheDir() { /** * Downloads the topics from the cloud and returns the path to the local copy - * @param topicPath + * @param topicPath Path to the topics * @return Path to the local copy of the topics - * @throws IOException + * @throws IOException if error encountered downloading topics */ public static Path getTopicsFromCloud(Path topicPath) throws IOException{ String topicURL = CLOUD_PATH + topicPath.getFileName().toString(); @@ -249,9 +250,9 @@ public static Path getNewTopicsAbsPath(Path topicPath){ /** * Returns the path to the topic file. If the topic file is not in the list of known topics, we assume it is a local file. - * @param topicPath - * @return - * @throws IOException + * @param topicPath Path to the topic file + * @return Path to the topic file + * @throws IOException if error encountered reading topics */ public static Path getTopicPath(Path topicPath) throws IOException{ if (Files.exists(topicPath)) { diff --git a/src/test/java/io/anserini/search/SimpleImpactSearcherTest.java b/src/test/java/io/anserini/search/SimpleImpactSearcherTest.java index dbe8497931..451ab1aaee 100644 --- a/src/test/java/io/anserini/search/SimpleImpactSearcherTest.java +++ b/src/test/java/io/anserini/search/SimpleImpactSearcherTest.java @@ -25,6 +25,15 @@ import java.util.Map; public class SimpleImpactSearcherTest extends IndexerTestBase { + + private static Map EXPECTED_ENCODED_QUERY = new HashMap<>(); + + static { + EXPECTED_ENCODED_QUERY.put("here", 3.05345f); + EXPECTED_ENCODED_QUERY.put("a", 0.59636426f); + EXPECTED_ENCODED_QUERY.put("test", 2.9012794f); + } + @Test public void testGetDoc() throws Exception { SimpleImpactSearcher searcher = new SimpleImpactSearcher(super.tempDir1.toString()); @@ -195,4 +204,15 @@ public void testTotalNumDocuments() throws Exception { SimpleImpactSearcher searcher = new SimpleImpactSearcher(super.tempDir1.toString()); assertEquals(3 ,searcher.get_total_num_docs()); } + + @Test + public void testOnnxEncoder() throws Exception{ + SimpleImpactSearcher searcher = new SimpleImpactSearcher(); + searcher.set_onnx_query_encoder("SpladePlusPlusEnsembleDistil"); + + Map encoded_query = searcher.encode_with_onnx("here is a test"); + assertEquals(encoded_query.get("here"), EXPECTED_ENCODED_QUERY.get("here")); + assertEquals(encoded_query.get("a"), EXPECTED_ENCODED_QUERY.get("a")); + assertEquals(encoded_query.get("test"), EXPECTED_ENCODED_QUERY.get("test")); + } }