diff --git a/src/main/java/io/anserini/search/SimpleImpactSearcher.java b/src/main/java/io/anserini/search/SimpleImpactSearcher.java new file mode 100644 index 0000000000..2616571580 --- /dev/null +++ b/src/main/java/io/anserini/search/SimpleImpactSearcher.java @@ -0,0 +1,373 @@ +/* + * Anserini: A Lucene toolkit for reproducible information retrieval research + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.anserini.search; + +import io.anserini.index.IndexArgs; +import io.anserini.index.IndexReaderUtils; +import io.anserini.rerank.RerankerCascade; +import io.anserini.rerank.RerankerContext; +import io.anserini.rerank.ScoredDocuments; +import io.anserini.rerank.lib.ScoreTiesAdjusterReranker; +import io.anserini.search.query.BagOfWordsQueryGenerator; +import io.anserini.search.similarity.ImpactSimilarity; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.document.Document; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexableField; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.SortField; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.similarities.Similarity; +import org.apache.lucene.store.FSDirectory; + +import java.io.Closeable; +import java.io.IOException; 
+import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletionException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; + +/** + * Class that exposes basic search functionality, designed specifically to provide the bridge between Java and Python + * via pyjnius. + */ +public class SimpleImpactSearcher implements Closeable { + public static final Sort BREAK_SCORE_TIES_BY_DOCID = + new Sort(SortField.FIELD_SCORE, new SortField(IndexArgs.ID, SortField.Type.STRING_VAL)); + private static final Logger LOG = LogManager.getLogger(SimpleImpactSearcher.class); + + protected IndexReader reader; + protected Similarity similarity; + protected BagOfWordsQueryGenerator generator; + protected RerankerCascade cascade; + protected IndexSearcher searcher = null; + + /** + * This class is meant to serve as the bridge between Anserini and Pyserini. + * Note that we are adopting Python naming conventions here on purpose. + */ + public class Result { + public String docid; + public int lucene_docid; + public float score; + public String contents; + public String raw; + public Document lucene_document; // Since this is for Python access, we're using Python naming conventions. + + public Result(String docid, int lucene_docid, float score, String contents, String raw, Document lucene_document) { + this.docid = docid; + this.lucene_docid = lucene_docid; + this.score = score; + this.contents = contents; + this.raw = raw; + this.lucene_document = lucene_document; + } + } + + protected SimpleImpactSearcher() { + } + + /** + * Creates a {@code SimpleImpactSearcher}. 
+ * + * @param indexDir index directory + * @throws IOException if errors encountered during initialization + */ + public SimpleImpactSearcher(String indexDir) throws IOException { + Path indexPath = Paths.get(indexDir); + + if (!Files.exists(indexPath) || !Files.isDirectory(indexPath) || !Files.isReadable(indexPath)) { + throw new IllegalArgumentException(indexDir + " does not exist or is not a directory."); + } + + this.reader = DirectoryReader.open(FSDirectory.open(indexPath)); + // Default to using ImpactSimilarity. + this.similarity = new ImpactSimilarity(); + this.generator = new BagOfWordsQueryGenerator(); + cascade = new RerankerCascade(); + cascade.add(new ScoreTiesAdjusterReranker()); + } + + + /** + * Returns the number of documents in the index. + * + * @return the number of documents in the index + */ + public int getTotalNumDocuments() { + // Create an IndexSearcher only once. Note that the object is thread safe. + if (searcher == null) { + searcher = new IndexSearcher(reader); + searcher.setSimilarity(similarity); + } + + return searcher.getIndexReader().maxDoc(); + } + + /** + * Closes this searcher. + */ + @Override + public void close() throws IOException { + try { + reader.close(); + } catch (Exception e) { + // Eat any exceptions. + return; + } + } + + /** + * Searches in batch. + * + * @param queries list of queries + * @param qids list of unique query ids + * @param k number of hits + * @param threads number of threads + * @return a map of query id to search results + */ + public Map<String, Result[]> batchSearch(List<Map<String, Float>> queries, List<String> qids, int k, int threads) { + // Create the IndexSearcher here, if needed. We do it here because if we leave the creation to the search + // method, we might end up with a race condition as multiple threads try to concurrently create the IndexSearcher. 
+ if (searcher == null) { + searcher = new IndexSearcher(reader); + searcher.setSimilarity(similarity); + } + + ThreadPoolExecutor executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(threads); + ConcurrentHashMap<String, Result[]> results = new ConcurrentHashMap<>(); + + long startTime = System.nanoTime(); + AtomicLong index = new AtomicLong(); + int queryCnt = queries.size(); + for (int q = 0; q < queryCnt; ++q) { + Map<String, Float> query = queries.get(q); + String qid = qids.get(q); + executor.execute(() -> { + try { + results.put(qid, search(query, k)); + } catch (IOException e) { + throw new CompletionException(e); + } + // logging for speed + long lineNumber = index.incrementAndGet(); + if (lineNumber % 100 == 0) { + double timePerQuery = (double) (System.nanoTime() - startTime) / lineNumber / 1e9; + LOG.info(String.format("Retrieving query %d (%.3f s/query)", lineNumber, timePerQuery)); + } + }); + } + + executor.shutdown(); + + try { + // Wait for existing tasks to terminate + while (!executor.awaitTermination(1, TimeUnit.MINUTES)) { + LOG.info(String.format("%.2f percent completed", + (double) executor.getCompletedTaskCount() / queries.size() * 100.0d)); + } + } catch (InterruptedException ie) { + // (Re-)Cancel if current thread also interrupted + executor.shutdownNow(); + // Preserve interrupt status + Thread.currentThread().interrupt(); + } + + if (queryCnt != executor.getCompletedTaskCount()) { + throw new RuntimeException("queryCount = " + queryCnt + + " is not equal to completedTaskCount = " + executor.getCompletedTaskCount()); + } + + return results; + } + + /** + * Searches the collection, returning 10 hits by default. + * + * @param q query + * @return array of search results + * @throws IOException if error encountered during search + */ + public Result[] search(Map<String, Float> q) throws IOException { + return search(q, 10); + } + + /** + * Searches the collection. 
+ * + * @param q query + * @param k number of hits + * @return array of search results + * @throws IOException if error encountered during search + */ + public Result[] search(Map<String, Float> q, int k) throws IOException { + Query query = generator.buildQuery(IndexArgs.CONTENTS, q); + + return _search(query, k); + } + + // internal implementation + protected Result[] _search(Query query, int k) throws IOException { + // Create an IndexSearcher only once. Note that the object is thread safe. + if (searcher == null) { + searcher = new IndexSearcher(reader); + searcher.setSimilarity(similarity); + } + + SearchArgs searchArgs = new SearchArgs(); + searchArgs.arbitraryScoreTieBreak = false; + searchArgs.hits = k; + + TopDocs rs; + RerankerContext context; + rs = searcher.search(query, k, BREAK_SCORE_TIES_BY_DOCID, true); + context = new RerankerContext<>(searcher, null, query, null, + null, null, null, searchArgs); + + ScoredDocuments hits = cascade.run(ScoredDocuments.fromTopDocs(rs, searcher), context); + + Result[] results = new Result[hits.ids.length]; + for (int i = 0; i < hits.ids.length; i++) { + Document doc = hits.documents[i]; + String docid = doc.getField(IndexArgs.ID).stringValue(); + + IndexableField field; + field = doc.getField(IndexArgs.CONTENTS); + String contents = field == null ? null : field.stringValue(); + + field = doc.getField(IndexArgs.RAW); + String raw = field == null ? null : field.stringValue(); + + results[i] = new Result(docid, hits.ids[i], hits.scores[i], contents, raw, doc); + } + + return results; + } + + /** + * Fetches the Lucene {@link Document} based on an internal Lucene docid. + * The method is named to be consistent with Lucene's {@link IndexReader#document(int)}, contra Java's standard + * method naming conventions. 
+ * + * @param ldocid internal Lucene docid + * @return corresponding Lucene {@link Document} + */ + public Document document(int ldocid) { + try { + return reader.document(ldocid); + } catch (Exception e) { + // Eat any exceptions and just return null. + return null; + } + } + + /** + * Returns the Lucene {@link Document} based on a collection docid. + * The method is named to be consistent with Lucene's {@link IndexReader#document(int)}, contra Java's standard + * method naming conventions. + * + * @param docid collection docid + * @return corresponding Lucene {@link Document} + */ + public Document document(String docid) { + return IndexReaderUtils.document(reader, docid); + } + + /** + * Fetches the Lucene {@link Document} based on some field other than its unique collection docid. + * For example, scientific articles might have DOIs. + * The method is named to be consistent with Lucene's {@link IndexReader#document(int)}, contra Java's standard + * method naming conventions. + * + * @param field field + * @param id unique id + * @return corresponding Lucene {@link Document} based on the value of a specific field + */ + public Document documentByField(String field, String id) { + return IndexReaderUtils.documentByField(reader, field, id); + } + + /** + * Returns the "contents" field of a document based on an internal Lucene docid. + * The method is named to be consistent with Lucene's {@link IndexReader#document(int)}, contra Java's standard + * method naming conventions. + * + * @param ldocid internal Lucene docid + * @return the "contents" field the document + */ + public String documentContents(int ldocid) { + try { + return reader.document(ldocid).get(IndexArgs.CONTENTS); + } catch (Exception e) { + // Eat any exceptions and just return null. + return null; + } + } + + /** + * Returns the "contents" field of a document based on a collection docid. 
+ * The method is named to be consistent with Lucene's {@link IndexReader#document(int)}, contra Java's standard + * method naming conventions. + * + * @param docid collection docid + * @return the "contents" field the document + */ + public String documentContents(String docid) { + return IndexReaderUtils.documentContents(reader, docid); + } + + /** + * Returns the "raw" field of a document based on an internal Lucene docid. + * The method is named to be consistent with Lucene's {@link IndexReader#document(int)}, contra Java's standard + * method naming conventions. + * + * @param ldocid internal Lucene docid + * @return the "raw" field the document + */ + public String documentRaw(int ldocid) { + try { + return reader.document(ldocid).get(IndexArgs.RAW); + } catch (Exception e) { + // Eat any exceptions and just return null. + return null; + } + } + + /** + * Returns the "raw" field of a document based on a collection docid. + * The method is named to be consistent with Lucene's {@link IndexReader#document(int)}, contra Java's standard + * method naming conventions. 
+ * + * @param docid collection docid + * @return the "raw" field the document + */ + public String documentRaw(String docid) { + return IndexReaderUtils.documentRaw(reader, docid); + } + } + \ No newline at end of file diff --git a/src/main/java/io/anserini/search/query/BagOfWordsQueryGenerator.java b/src/main/java/io/anserini/search/query/BagOfWordsQueryGenerator.java index 610ff56a1d..e1bf6e16b1 100644 --- a/src/main/java/io/anserini/search/query/BagOfWordsQueryGenerator.java +++ b/src/main/java/io/anserini/search/query/BagOfWordsQueryGenerator.java @@ -37,10 +37,12 @@ public class BagOfWordsQueryGenerator extends QueryGenerator { @Override public Query buildQuery(String field, Analyzer analyzer, String queryText) { List tokens = AnalyzerUtils.analyze(analyzer, queryText); - Map collect = tokens.stream().collect(Collectors.groupingBy(Function.identity(), Collectors.counting())); + Map collect = tokens.stream() + .collect(Collectors.groupingBy(Function.identity(), Collectors.counting())); BooleanQuery.Builder builder = new BooleanQuery.Builder(); - for (String t : collect.keySet()){ - builder.add(new BoostQuery(new TermQuery(new Term(field, t)), (float)collect.get(t)), BooleanClause.Occur.SHOULD); + for (String t : collect.keySet()) { + builder.add(new BoostQuery(new TermQuery(new Term(field, t)), (float) collect.get(t)), + BooleanClause.Occur.SHOULD); } return builder.build(); } @@ -57,4 +59,26 @@ public Query buildQuery(Map fields, Analyzer analyzer, String que } return builder.build(); } + + public Query buildQuery(String field, Map queryTokenWeights) { + BooleanQuery.Builder builder = new BooleanQuery.Builder(); + for (String t : queryTokenWeights.keySet()) { + builder.add(new BoostQuery(new TermQuery(new Term(field, t)), queryTokenWeights.get(t)), + BooleanClause.Occur.SHOULD); + } + return builder.build(); + } + + public Query buildQuery(Map fields, Map queryTokenWeights) { + BooleanQuery.Builder builder = new BooleanQuery.Builder(); + for (Map.Entry entry 
: fields.entrySet()) { + String field = entry.getKey(); + float boost = entry.getValue(); + + Query clause = buildQuery(field, queryTokenWeights); + builder.add(new BoostQuery(clause, boost), BooleanClause.Occur.SHOULD); + } + return builder.build(); + } + } diff --git a/src/test/java/io/anserini/search/SimpleImpactSearcherTest.java b/src/test/java/io/anserini/search/SimpleImpactSearcherTest.java new file mode 100644 index 0000000000..aa4b2c3822 --- /dev/null +++ b/src/test/java/io/anserini/search/SimpleImpactSearcherTest.java @@ -0,0 +1,205 @@ +/* + * Anserini: A Lucene toolkit for reproducible information retrieval research + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.anserini.search; + +import io.anserini.IndexerTestBase; +import io.anserini.index.IndexArgs; +import org.junit.Test; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class SimpleImpactSearcherTest extends IndexerTestBase { + + @Test + public void testGetDoc() throws Exception { + SimpleImpactSearcher searcher = new SimpleImpactSearcher(super.tempDir1.toString()); + + assertEquals("here is some text here is some more text. city.", + searcher.document(0).get("contents")); + assertEquals("more texts", searcher.document(1).get("contents")); + assertEquals("here is a test", searcher.document(2).get("contents")); + assertNull(searcher.document(3)); + + assertEquals("here is some text here is some more text. 
city.", + searcher.document("doc1").get("contents")); + assertEquals("more texts", searcher.document("doc2").get("contents")); + assertEquals("here is a test", searcher.document("doc3").get("contents")); + assertNull(searcher.document(3)); + + searcher.close(); + } + + @Test + public void testGetDocByField() throws Exception { + SimpleImpactSearcher searcher = new SimpleImpactSearcher(super.tempDir1.toString()); + + assertEquals("here is some text here is some more text. city.", + searcher.documentByField("id", "doc1").get("contents")); + assertEquals("more texts", searcher.documentByField("id", "doc2").get("contents")); + assertEquals("here is a test", searcher.documentByField("id", "doc3").get("contents")); + + searcher.close(); + } + + @Test + public void testGetContents() throws Exception { + SimpleImpactSearcher searcher = new SimpleImpactSearcher(super.tempDir1.toString()); + + assertEquals("here is some text here is some more text. city.", + searcher.documentContents(0)); + assertEquals("more texts", + searcher.documentContents(1)); + assertEquals("here is a test", + searcher.documentContents(2)); + assertNull(searcher.document(3)); + + assertEquals("here is some text here is some more text. city.", + searcher.documentContents("doc1")); + assertEquals("more texts", + searcher.documentContents("doc2")); + assertEquals("here is a test", + searcher.documentContents("doc3")); + assertNull(searcher.documentContents("doc42")); + + searcher.close(); + } + + @Test + public void testGetRaw() throws Exception { + SimpleImpactSearcher searcher = new SimpleImpactSearcher(super.tempDir1.toString()); + + assertEquals("{\"contents\": \"here is some text here is some more text. 
city.\"}", + searcher.documentRaw(0)); + assertEquals("{\"contents\": \"more texts\"}", + searcher.documentRaw(1)); + assertEquals("{\"contents\": \"here is a test\"}", + searcher.documentRaw(2)); + assertNull(searcher.document(3)); + + assertEquals("{\"contents\": \"here is some text here is some more text. city.\"}", + searcher.documentRaw("doc1")); + assertEquals("{\"contents\": \"more texts\"}", + searcher.documentRaw("doc2")); + assertEquals("{\"contents\": \"here is a test\"}", + searcher.documentRaw("doc3")); + assertNull(searcher.documentContents("doc42")); + + searcher.close(); + } + + @Test + public void testSearch1() throws Exception { + SimpleImpactSearcher searcher = new SimpleImpactSearcher(super.tempDir1.toString()); + + Map testQuery = new HashMap<>(); + testQuery.put("test", 1.2f); + + SimpleImpactSearcher.Result[] hits = searcher.search(testQuery, 10); + assertEquals(1, hits.length); + assertEquals("doc3", hits[0].docid); + assertEquals(2, hits[0].lucene_docid); + assertEquals(1.2f, hits[0].score, 10e-6); + assertEquals("here is a test", hits[0].contents); + assertEquals("{\"contents\": \"here is a test\"}", hits[0].raw); + + // We can fetch the exact same information from the raw Lucene document also. 
+ assertEquals("doc3", + hits[0].lucene_document.getField(IndexArgs.ID).stringValue()); + assertEquals("here is a test", + hits[0].lucene_document.getField(IndexArgs.CONTENTS).stringValue()); + assertEquals("{\"contents\": \"here is a test\"}", + hits[0].lucene_document.getField(IndexArgs.RAW).stringValue()); + + searcher.close(); + } + + @Test + public void testSearch2() throws Exception { + SimpleImpactSearcher searcher = new SimpleImpactSearcher(super.tempDir1.toString()); + + Map testQuery = new HashMap<>(); + testQuery.put("text", 1.2f); + + SimpleImpactSearcher.Result[] results; + + results = searcher.search(testQuery, 1); + assertEquals(1, results.length); + assertEquals("doc1", results[0].docid); + assertEquals(0, results[0].lucene_docid); + assertEquals(2.4f, results[0].score, 10e-6); + assertEquals("here is some text here is some more text. city.", results[0].contents); + assertEquals("{\"contents\": \"here is some text here is some more text. city.\"}", results[0].raw); + + results = searcher.search(testQuery); + assertEquals(2, results.length); + assertEquals("doc1", results[0].docid); + assertEquals(0, results[0].lucene_docid); + assertEquals("doc2", results[1].docid); + assertEquals(1, results[1].lucene_docid); + assertEquals(2.4f, results[0].score, 10e-6); + assertEquals(1.2f, results[1].score, 10e-6); + + Map testQuery2 = new HashMap<>(); + testQuery2.put("test", 0.125f); + + results = searcher.search(testQuery2); + assertEquals(1, results.length); + assertEquals("doc3", results[0].docid); + assertEquals(2, results[0].lucene_docid); + assertEquals(0.125f, results[0].score, 10e-6); + + searcher.close(); + } + + @Test + public void testBatchSearch() throws Exception { + SimpleImpactSearcher searcher = new SimpleImpactSearcher(super.tempDir1.toString()); + Map testQuery1 = new HashMap<>(); + testQuery1.put("tests", 0.1f); + testQuery1.put("test", 0.1f); + Map testQuery2 = new HashMap<>(); + testQuery2.put("more", 1.5f); + + List> queries = new 
ArrayList<>(); + queries.add(testQuery1); + queries.add(testQuery2); + + List<String> qids = new ArrayList<>(); + qids.add("query_test"); + qids.add("query_more"); + + Map<String, SimpleImpactSearcher.Result[]> hits = searcher.batchSearch(queries, qids, 10, 2); + assertEquals(2, hits.size()); + + assertEquals(1, hits.get("query_test").length); + assertEquals("doc3", hits.get("query_test")[0].docid); + + assertEquals(2, hits.get("query_more").length); + assertEquals("doc1", hits.get("query_more")[0].docid); + assertEquals("doc2", hits.get("query_more")[1].docid); + + searcher.close(); + } + + @Test + public void testTotalNumDocuments() throws Exception { + SimpleImpactSearcher searcher = new SimpleImpactSearcher(super.tempDir1.toString()); + assertEquals(3, searcher.getTotalNumDocuments()); + searcher.close(); + } +}