Massive LTR code drop! (#1454)
+ add command-line interface (see the usage sketch below)
+ add more feature implementations in LTR (pre-retrieval, IBM, entity)
+ refactor document field context and query field context
stephaniewhoo authored Jan 18, 2021
1 parent 9d885d3 commit 1691d09
Showing 105 changed files with 6,729 additions and 2,783 deletions.
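
For orientation, here is a minimal sketch of how the pieces added in this commit fit together, using only calls that appear in this diff (FeatureExtractorUtils, FeatureExtractorCli.addFeature, lazyExtract, getResult, close); the index path, thread count, input file, and field names below are placeholders, not values from the commit:

import java.io.BufferedReader;
import java.io.FileReader;

import io.anserini.ltr.FeatureExtractorCli;
import io.anserini.ltr.FeatureExtractorUtils;

public class LtrSketch {
  public static void main(String[] args) throws Exception {
    // Placeholder index path and thread count.
    FeatureExtractorUtils utils = new FeatureExtractorUtils("indexes/msmarco-passage", 4);
    // Register the standard feature set for one (queryField, docField) pair,
    // mirroring the call in FeatureExtractorCli.main.
    FeatureExtractorCli.addFeature(utils, "analyzed", "contents");
    try (BufferedReader reader = new BufferedReader(new FileReader("queries.jsonl"))) {
      String line;
      while ((line = reader.readLine()) != null) {
        String qid = utils.lazyExtract(line);  // queue one query's JSON line
        String result = utils.getResult(qid);  // wait for its feature-vector JSON
        System.out.println(result);
      }
    }
    utils.close();
  }
}

The CLI itself batches lazyExtract calls before draining results; fetching each result immediately, as above, is just the simplest correct usage.
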
4 changes: 4 additions & 0 deletions src/main/java/io/anserini/index/IndexArgs.java
@@ -29,6 +29,10 @@ public class IndexArgs {
// This is the name of the field in the Lucene document where the raw document is stored.
public static final String RAW = "raw";

// This is the name of the field in the Lucene document where the entity document is stored.
public static final String ENTITY = "entity";


private static final int TIMEOUT = 600 * 1000;

// required arguments
4 changes: 4 additions & 0 deletions src/main/java/io/anserini/index/IndexCollection.java
@@ -53,6 +53,7 @@
import org.apache.lucene.analysis.ar.ArabicAnalyzer;
import org.apache.lucene.analysis.bn.BengaliAnalyzer;
import org.apache.lucene.analysis.cjk.CJKAnalyzer;
import org.apache.lucene.analysis.core.WhitespaceAnalyzer;
import org.apache.lucene.analysis.de.GermanAnalyzer;
import org.apache.lucene.analysis.es.SpanishAnalyzer;
import org.apache.lucene.analysis.fr.FrenchAnalyzer;
@@ -749,6 +750,7 @@ public Counters run() throws IOException {
final BengaliAnalyzer bengaliAnalyzer = new BengaliAnalyzer();
final GermanAnalyzer germanAnalyzer = new GermanAnalyzer();
final SpanishAnalyzer spanishAnalyzer = new SpanishAnalyzer();
final WhitespaceAnalyzer whitespaceAnalyzer = new WhitespaceAnalyzer();
final DefaultEnglishAnalyzer analyzer;
if (args.keepStopwords) {
analyzer = DefaultEnglishAnalyzer.newStemmingInstance(args.stemmer, CharArraySet.EMPTY_SET);
@@ -778,6 +780,8 @@ public Counters run() throws IOException {
config = new IndexWriterConfig(germanAnalyzer);
} else if (args.language.equals("es")) {
config = new IndexWriterConfig(spanishAnalyzer);
} else if (args.language.equals("en_ws")) {
config = new IndexWriterConfig(whitespaceAnalyzer);
} else {
config = new IndexWriterConfig(analyzer);
}
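
Aside: the new en_ws option selects Lucene's WhitespaceAnalyzer, which splits on whitespace only, with no lowercasing, stemming, or stopword removal; that suits pre-tokenized input such as the text_bert_tok field referenced later in this commit. A minimal standalone demo of that behavior, with the field name and sample text invented for illustration:

import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.core.WhitespaceAnalyzer;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;

public class WhitespaceDemo {
  public static void main(String[] args) throws Exception {
    try (WhitespaceAnalyzer analyzer = new WhitespaceAnalyzer()) {
      TokenStream ts = analyzer.tokenStream("contents", "Fast ##er re ##trie ##val");
      CharTermAttribute term = ts.addAttribute(CharTermAttribute.class);
      ts.reset();
      while (ts.incrementToken()) {
        System.out.println(term.toString());  // each whitespace-delimited token, unchanged
      }
      ts.end();
      ts.close();
    }
  }
}
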
@@ -99,7 +99,11 @@ public Document createDocument(T src) throws GeneratorException {
// Currently we just use all the settings of the main "content" field.
if (src instanceof MultifieldSourceDocument) {
((MultifieldSourceDocument) src).fields().forEach((k, v) -> {
document.add(new Field(k, v, fieldType));
if (k.equals(IndexArgs.ENTITY)) {
document.add(new StoredField(IndexArgs.ENTITY, v));
} else {
document.add(new Field(k, v, fieldType));
}
});
}

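
Design note on the branch above: a StoredField is kept in the index for retrieval but never analyzed or inverted, so the entity payload travels with the document without any of its text becoming searchable. A minimal sketch of that standard Lucene behavior (the directory, field value, and class name here are illustrative):

import org.apache.lucene.analysis.core.WhitespaceAnalyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.store.ByteBuffersDirectory;

public class StoredFieldDemo {
  public static void main(String[] args) throws Exception {
    ByteBuffersDirectory dir = new ByteBuffersDirectory();
    try (IndexWriter writer = new IndexWriter(dir, new IndexWriterConfig(new WhitespaceAnalyzer()))) {
      Document doc = new Document();
      // Stored for later retrieval, never analyzed or indexed: the JSON payload
      // rides along with the document but cannot be matched by any query term.
      doc.add(new StoredField("entity", "{\"PERSON\":[\"Ada Lovelace\"]}"));
      writer.addDocument(doc);
    }
    try (DirectoryReader reader = DirectoryReader.open(dir)) {
      System.out.println(reader.document(0).get("entity"));  // prints the stored JSON
    }
  }
}
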
220 changes: 179 additions & 41 deletions src/main/java/io/anserini/ltr/FeatureExtractorCli.java
@@ -1,10 +1,24 @@
/*
* Anserini: A Lucene toolkit for replicable information retrieval research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.anserini.ltr;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.anserini.ltr.feature.OrderedSequentialPairsFeatureExtractor;
import io.anserini.ltr.feature.UnorderedSequentialPairsFeatureExtractor;
import io.anserini.ltr.feature.base.*;
import io.anserini.ltr.feature.*;
import org.kohsuke.args4j.CmdLineException;
import org.kohsuke.args4j.CmdLineParser;
import org.kohsuke.args4j.Option;
@@ -30,6 +44,106 @@ static class DebugArgs {
public int threads = 1;

}
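/**
 * Registers the full feature battery for one query-field / document-field pair:
 * weighting models (BM25, LM Dirichlet, LM Jelinek-Mercer, DFR, DPH), proximity
 * features, and pooled term statistics (tf, tf-idf, scq, normalized tf, idf, ictf).
 */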
public static void addFeature(FeatureExtractorUtils utils, String queryField, String docField) throws IOException {
utils.add(new BM25(0.9,0.4, docField, queryField));
utils.add(new BM25(1.2,0.75, docField, queryField));
utils.add(new BM25(2.0,0.75, docField, queryField));

utils.add(new LMDir(1000, docField, queryField));
utils.add(new LMDir(1500, docField, queryField));
utils.add(new LMDir(2500, docField, queryField));

utils.add(new LMJM(0.1, docField, queryField));
utils.add(new LMJM(0.4, docField, queryField));
utils.add(new LMJM(0.7, docField, queryField));

utils.add(new NTFIDF(docField, queryField));
utils.add(new ProbalitySum(docField, queryField));

utils.add(new DFR_GL2(docField, queryField));
utils.add(new DFR_In_expB2(docField, queryField));
utils.add(new DPH(docField, queryField));

utils.add(new Proximity(docField, queryField));
utils.add(new TPscore(docField, queryField));
utils.add(new tpDist(docField, queryField));

utils.add(new DocSize(docField));
utils.add(new Entropy(docField));

utils.add(new QueryLength(queryField));
utils.add(new QueryCoverageRatio(docField, queryField));

utils.add(new UniqueTermCount(queryField));
utils.add(new MatchingTermCount(docField, queryField));
utils.add(new SCS(docField, queryField));

utils.add(new tfStat(new AvgPooler(), docField, queryField));
utils.add(new tfStat(new MedianPooler(), docField, queryField));
utils.add(new tfStat(new SumPooler(), docField, queryField));
utils.add(new tfStat(new MinPooler(), docField, queryField));
utils.add(new tfStat(new MaxPooler(), docField, queryField));
utils.add(new tfStat(new VarPooler(), docField, queryField));
utils.add(new tfStat(new MaxMinRatioPooler(), docField, queryField));
utils.add(new tfStat(new ConfidencePooler(), docField, queryField));
utils.add(new tfIdfStat(new AvgPooler(), docField, queryField));
utils.add(new tfIdfStat(new MedianPooler(), docField, queryField));
utils.add(new tfIdfStat(new SumPooler(), docField, queryField));
utils.add(new tfIdfStat(new MinPooler(), docField, queryField));
utils.add(new tfIdfStat(new MaxPooler(), docField, queryField));
utils.add(new tfIdfStat(new VarPooler(), docField, queryField));
utils.add(new tfIdfStat(new MaxMinRatioPooler(), docField, queryField));
utils.add(new tfIdfStat(new ConfidencePooler(), docField, queryField));
utils.add(new scqStat(new AvgPooler(), docField, queryField));
utils.add(new scqStat(new MedianPooler(), docField, queryField));
utils.add(new scqStat(new SumPooler(), docField, queryField));
utils.add(new scqStat(new MinPooler(), docField, queryField));
utils.add(new scqStat(new MaxPooler(), docField, queryField));
utils.add(new scqStat(new VarPooler(), docField, queryField));
utils.add(new scqStat(new MaxMinRatioPooler(), docField, queryField));
utils.add(new scqStat(new ConfidencePooler(), docField, queryField));
utils.add(new normalizedTfStat(new AvgPooler(), docField, queryField));
utils.add(new normalizedTfStat(new MedianPooler(), docField, queryField));
utils.add(new normalizedTfStat(new SumPooler(), docField, queryField));
utils.add(new normalizedTfStat(new MinPooler(), docField, queryField));
utils.add(new normalizedTfStat(new MaxPooler(), docField, queryField));
utils.add(new normalizedTfStat(new VarPooler(), docField, queryField));
utils.add(new normalizedTfStat(new MaxMinRatioPooler(), docField, queryField));
utils.add(new normalizedTfStat(new ConfidencePooler(), docField, queryField));

utils.add(new idfStat(new AvgPooler(), docField, queryField));
utils.add(new idfStat(new MedianPooler(), docField, queryField));
utils.add(new idfStat(new SumPooler(), docField, queryField));
utils.add(new idfStat(new MinPooler(), docField, queryField));
utils.add(new idfStat(new MaxPooler(), docField, queryField));
utils.add(new idfStat(new VarPooler(), docField, queryField));
utils.add(new idfStat(new MaxMinRatioPooler(), docField, queryField));
utils.add(new idfStat(new ConfidencePooler(), docField, queryField));
utils.add(new ictfStat(new AvgPooler(), docField, queryField));
utils.add(new ictfStat(new MedianPooler(), docField, queryField));
utils.add(new ictfStat(new SumPooler(), docField, queryField));
utils.add(new ictfStat(new MinPooler(), docField, queryField));
utils.add(new ictfStat(new MaxPooler(), docField, queryField));
utils.add(new ictfStat(new VarPooler(), docField, queryField));
utils.add(new ictfStat(new MaxMinRatioPooler(), docField, queryField));
utils.add(new ictfStat(new ConfidencePooler(), docField, queryField));

utils.add(new UnorderedSequentialPairs(3, docField, queryField));
utils.add(new UnorderedSequentialPairs(8, docField, queryField));
utils.add(new UnorderedSequentialPairs(15, docField, queryField));
utils.add(new OrderedSequentialPairs(3, docField, queryField));
utils.add(new OrderedSequentialPairs(8, docField, queryField));
utils.add(new OrderedSequentialPairs(15, docField, queryField));
utils.add(new UnorderedQueryPairs(3, docField, queryField));
utils.add(new UnorderedQueryPairs(8, docField, queryField));
utils.add(new UnorderedQueryPairs(15, docField, queryField));
utils.add(new OrderedQueryPairs(3, docField, queryField));
utils.add(new OrderedQueryPairs(8, docField, queryField));
utils.add(new OrderedQueryPairs(15, docField, queryField));


}

public static void main(String[] args) throws IOException, ExecutionException, InterruptedException {
long start = System.nanoTime();
DebugArgs cmdArgs = new DebugArgs();
@@ -44,30 +158,34 @@ public static void main(String[] args) throws IOException, ExecutionException, InterruptedException {
}

FeatureExtractorUtils utils = new FeatureExtractorUtils(cmdArgs.indexDir, cmdArgs.threads);
addFeature(utils,"analyzed","contents");
//addFeature(utils,"text","text");
//addFeature(utils,"text_unlemm","text_unlemm");
//addFeature(utils,"text_bert_tok","text_bert_tok");
// utils.add(new IBMModel1("../pyserini/collections/msmarco-passage/text_bert_tok","Bert","BERT","text_bert_tok"));

utils.add(new EntityHowMany());
utils.add(new EntityHowMuch());
utils.add(new EntityHowLong());

utils.add(new EntityWho());
utils.add(new EntityWhen());
utils.add(new EntityWhere());

utils.add(new EntityWhoMatch());
utils.add(new EntityWhereMatch());

utils.add(new EntityQueryCount("PERSON"));
utils.add(new EntityDocCount("PERSON"));

utils.add(new QueryRegex("^[0-9.+_ ]*what.*$"));

utils.add(new AvgICTFFeatureExtractor());
utils.add(new AvgIDFFeatureExtractor());
utils.add(new BM25FeatureExtractor());
utils.add(new DocSizeFeatureExtractor());
utils.add(new MatchingTermCount());
utils.add(new PMIFeatureExtractor());
utils.add(new QueryLength());
utils.add(new SCQFeatureExtractor());
utils.add(new SCSFeatureExtractor());
utils.add(new SumMatchingTF());
utils.add(new TFIDFFeatureExtractor());
utils.add(new UniqueTermCount());
utils.add(new UnorderedSequentialPairsFeatureExtractor(3));
utils.add(new UnorderedSequentialPairsFeatureExtractor(5));
utils.add(new UnorderedSequentialPairsFeatureExtractor(8));
utils.add(new OrderedSequentialPairsFeatureExtractor(3));
utils.add(new OrderedSequentialPairsFeatureExtractor(5));
utils.add(new OrderedSequentialPairsFeatureExtractor(8));

File file = new File(cmdArgs.jsonFile);
BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(file)));
String line;
List<String> qids = new ArrayList<>();
int lineNum = 0;
int offset = 0;
String lastQid = null;
ObjectMapper mapper = new ObjectMapper();
@@ -77,42 +195,62 @@ public static void main(String[] args) throws IOException, ExecutionException, InterruptedException {
time[i] = 0;
}
long executionStart = System.nanoTime();
while((line=reader.readLine())!=null){
qids.add(utils.lazyExtract(line));
if(qids.size()>=100){
while((line=reader.readLine())!=null&&offset<10000){
lineNum++;
// if(lineNum<=760) continue;
qids.add(utils.debugExtract(line));
if(qids.size()>=10){
try{
while(qids.size()>0) {
lastQid = qids.remove(0);
String allResult = utils.getResult(lastQid);
TypeReference<ArrayList<output>> typeref = new TypeReference<>() {};
List<output> outputArray = mapper.readValue(allResult, typeref);
for(output res:outputArray){
List<debugOutput> outputArray = utils.getDebugResult(lastQid);
// System.out.println(String.format("Qid:%s\tLine:%d",lastQid,offset));
for(debugOutput res:outputArray){
for(int i = 0; i < names.size(); i++){
time[i] += res.time.get(i);
}
}
offset++;
}
} catch (Exception e) {
System.out.println("the offset is:"+offset+"at qid:"+lastQid);
System.out.println("the offset is:"+offset+" at qid:"+lastQid);
throw e;
}
}

}
long executionEnd = System.nanoTime();
long sumtime = 0;
for(int i = 0; i < names.size(); i++){
sumtime += time[i];
}
for(int i = 0; i < names.size(); i++){
System.out.println(names.get(i)+" takes "+String.format("%.2f",time[i]/1000000000.0) + "s, accounts for "+ String.format("%.2f", time[i]*100.0/sumtime) + "%");
if(qids.size()>=0){
try{
while(qids.size()>0) {
lastQid = qids.remove(0);
List<debugOutput> outputArray = utils.getDebugResult(lastQid);
// System.out.println(String.format("Qid:%s\tLine:%d",lastQid,offset));
for(debugOutput res:outputArray){
for(int i = 0; i < names.size(); i++){
time[i] += res.time.get(i);
}
}
offset++;
}
} catch (Exception e) {
System.out.println("the offset is:"+offset+"at qid:"+lastQid);
throw e;
}
}
// long executionEnd = System.nanoTime();
// long sumtime = 0;
// for(int i = 0; i < names.size(); i++){
// sumtime += time[i];
// }
// for(int i = 0; i < names.size(); i++){
// System.out.println(names.get(i)+" takes "+String.format("%.2f",time[i]/1000000000.0) + "s, accounts for "+ String.format("%.2f", time[i]*100.0/sumtime) + "%");
// }
utils.close();
reader.close();

long end = System.nanoTime();
long overallTime = end - start;
long overhead = overallTime-(executionEnd - executionStart);
System.out.println("The program takes "+String.format("%.2f",overallTime/1000000000.0) + "s, where the overhead takes " + String.format("%.2f",overhead/1000000000.0) +"s");
//
// long end = System.nanoTime();
// long overallTime = end - start;
// long overhead = overallTime-(executionEnd - executionStart);
// System.out.println("The program takes "+String.format("%.2f",overallTime/1000000000.0) + "s, where the overhead takes " + String.format("%.2f",overhead/1000000000.0) +"s");
}
}