diff --git a/src/main/java/io/anserini/search/topicreader/Topics.java b/src/main/java/io/anserini/search/topicreader/Topics.java index 598c20c808..65037fc067 100755 --- a/src/main/java/io/anserini/search/topicreader/Topics.java +++ b/src/main/java/io/anserini/search/topicreader/Topics.java @@ -57,11 +57,13 @@ public enum Topics { TREC2019_DL_PASSAGE_UNICOIL(TsvIntTopicReader.class,"topics.dl19-passage.unicoil.0shot.tsv.gz"), TREC2019_DL_PASSAGE_UNICOIL_NOEXP(TsvIntTopicReader.class,"topics.dl19-passage.unicoil-noexp.0shot.tsv.gz"), TREC2019_DL_PASSAGE_SPLADE_DISTILL_COCODENSER_MEDIUM(TsvIntTopicReader.class,"topics.dl19-passage.splade_distil_cocodenser_medium.tsv.gz"), + TREC2019_DL_PASSAGE_COS_DPR_DISTIL(JsonIntVectorTopicReader.class, "topics.dl19-passage.cos-dpr-distil.jsonl.gz"), TREC2020_DL(TsvIntTopicReader.class,"topics.dl20.txt"), TREC2020_DL_WP(TsvIntTopicReader.class,"topics.dl20.wp.tsv.gz"), TREC2020_DL_UNICOIL(TsvIntTopicReader.class,"topics.dl20.unicoil.0shot.tsv.gz"), TREC2020_DL_UNICOIL_NOEXP(TsvIntTopicReader.class,"topics.dl20.unicoil-noexp.0shot.tsv.gz"), TREC2020_DL_SPLADE_DISTILL_COCODENSER_MEDIUM(TsvIntTopicReader.class,"topics.dl20.splade_distil_cocodenser_medium.tsv.gz"), + TREC2020_DL_COS_DPR_DISTIL(JsonIntVectorTopicReader.class, "topics.dl20.cos-dpr-distil.jsonl.gz"), TREC2021_DL(TsvIntTopicReader.class,"topics.dl21.txt"), TREC2021_DL_UNICOIL(TsvIntTopicReader.class,"topics.dl21.unicoil.0shot.tsv.gz"), TREC2021_DL_UNICOIL_NOEXP(TsvIntTopicReader.class,"topics.dl21.unicoil-noexp.0shot.tsv.gz"), @@ -87,6 +89,7 @@ public enum Topics { MSMARCO_PASSAGE_DEV_SUBSET_UNICOIL_TILDE(TsvIntTopicReader.class, "topics.msmarco-passage.dev-subset.unicoil-tilde-expansion.tsv.gz"), MSMARCO_PASSAGE_DEV_SUBSET_DISTILL_SPLADE_MAX(TsvIntTopicReader.class, "topics.msmarco-passage.dev-subset.distill-splade-max.tsv.gz"), MSMARCO_PASSAGE_DEV_SUBSET_SPLADE_DISTILL_COCODENSER_MEDIUM(TsvIntTopicReader.class, "topics.msmarco-passage.dev-subset.splade_distil_cocodenser_medium.tsv.gz"), + MSMARCO_PASSAGE_DEV_SUBSET_COS_DPR_DISTIL(JsonIntVectorTopicReader.class, "topics.msmarco-passage.dev-subset.cos-dpr-distil.jsonl.gz"), MSMARCO_PASSAGE_TEST_SUBSET(TsvIntTopicReader.class, "topics.msmarco-passage.test-subset.txt"), // MS MARCO V2 topics diff --git a/src/test/java/io/anserini/search/topicreader/TopicReaderTest.java b/src/test/java/io/anserini/search/topicreader/TopicReaderTest.java index 664e88ee67..73a778ee81 100755 --- a/src/test/java/io/anserini/search/topicreader/TopicReaderTest.java +++ b/src/test/java/io/anserini/search/topicreader/TopicReaderTest.java @@ -38,7 +38,7 @@ public void testIterateThroughAllEnums() { String path = topic.path; assertEquals(topic.readerClass, TopicReader.getTopicReaderClassByFile(path)); } - assertEquals(356, cnt); + assertEquals(359, cnt); } @Test @@ -713,6 +713,14 @@ public void testTREC19DL() throws IOException { assertEquals(1890, topics.get(topics.firstKey()).get("title").split(" ").length); assertEquals(1133167, (int) topics.lastKey()); assertEquals(1382, topics.get(topics.lastKey()).get("title").split(" ").length); + + topics = TopicReader.getTopics(Topics.TREC2019_DL_PASSAGE_COS_DPR_DISTIL); + assertNotNull(topics); + assertEquals(43, topics.size()); + assertEquals(19335, (int) topics.firstKey()); + assertEquals("[0.013790097087621689", topics.get(topics.firstKey()).get("vector").split(",")[0]); + assertEquals(1133167, (int) topics.lastKey()); + assertEquals("[-0.024115752428770065", topics.get(topics.lastKey()).get("vector").split(",")[0]); } @Test @@ -760,6 +768,14 @@ public void testTREC20DL() throws IOException { assertEquals(2168, topics.get(topics.firstKey()).get("title").split(" ").length); assertEquals(1136962, (int) topics.lastKey()); assertEquals(2075, topics.get(topics.lastKey()).get("title").split(" ").length); + + topics = TopicReader.getTopics(Topics.TREC2020_DL_COS_DPR_DISTIL); + assertNotNull(topics); + assertEquals(200, topics.size()); + assertEquals(3505, (int) topics.firstKey()); + assertEquals("[0.0012954670237377286", topics.get(topics.firstKey()).get("vector").split(",")[0]); + assertEquals(1136962, (int) topics.lastKey()); + assertEquals("[0.06602190434932709", topics.get(topics.lastKey()).get("vector").split(",")[0]); } @Test @@ -964,6 +980,14 @@ public void testMSMARCO_V1() throws IOException { assertEquals("term service agreement definition", topics.get(topics.firstKey()).get("title")); assertEquals(1136966, (int) topics.lastKey()); assertEquals("#ffffff color code", topics.get(topics.lastKey()).get("title")); + + topics = TopicReader.getTopics(Topics.MSMARCO_PASSAGE_DEV_SUBSET_COS_DPR_DISTIL); + assertNotNull(topics); + assertEquals(6980, topics.size()); + assertEquals(2, (int) topics.firstKey()); + assertEquals("[-0.007401271723210812", topics.get(topics.firstKey()).get("vector").split(",")[0]); + assertEquals(1102400, (int) topics.lastKey()); + assertEquals("[0.05193052813410759", topics.get(topics.lastKey()).get("vector").split(",")[0]); } @Test diff --git a/tools b/tools index 10f1c7b1c6..95d06f6004 160000 --- a/tools +++ b/tools @@ -1 +1 @@ -Subproject commit 10f1c7b1c6c456a3d906bd978b1956c4d6806cf9 +Subproject commit 95d06f60043837a309331ffdbee7560dd1676313