random coefficient search for rescoring (#520)
Summary: Pull Request resolved: #520

Random trials to find the best weights for mixing rescoring models over a provided tune set (pickled scores exported by rescorer.py).

Reviewed By: akinh

Differential Revision: D15254265

fbshipit-source-id: 7dc6fbdb6afb40a39bdde86bb04b340240e1fc02
1 parent 4fb0560 · commit d4a57cb · 1 changed file with 153 additions and 0 deletions
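For context, here is a minimal sketch of the tune-set format this script consumes, inferred from the fields the code below reads ("scores", "tgt_len", "hypos", "target_tokens"); the exact layout written by rescorer.py may differ, and the file name, token IDs, and score values are invented for illustration.

#!/usr/bin/env python3
import pickle

import numpy as np

# Hypothetical tune-set entry: 2 hypotheses scored by 3 rescoring features.
# "scores" is (num_hypos x num_features); "tgt_len" is assumed to hold the
# per-hypothesis lengths used by the length penalty; token IDs are made up.
example = {
    "scores": np.array([[-1.2, -0.5, -3.0], [-0.9, -1.1, -2.5]]),
    "tgt_len": np.array([5.0, 6.0]),
    "hypos": [[4, 7, 9, 11, 1], [4, 7, 9, 12, 13, 1]],
    "target_tokens": [4, 7, 9, 11, 1],
}

with open("scores_info.pkl", "wb") as f:  # hypothetical path
    pickle.dump([example], f)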
#!/usr/bin/env python3

import argparse
import pickle

import numpy as np
import torch
from fairseq import bleu
from pytorch_translate.dictionary import Dictionary
from pytorch_translate.generate import smoothed_sentence_bleu


def get_arg_parser():
    parser = argparse.ArgumentParser(
        description="Rescore generated hypotheses with extra models"
    )
    parser.add_argument(
        "--scores-info-export-path", type=str, help="Model scores for weights search"
    )
    parser.add_argument(
        "--num-trials",
        type=int,
        default=1000,
        help="Number of iterations of random search",
    )
    parser.add_argument("--evaluate-oracle-bleu", default=False, action="store_true")
    parser.add_argument("--report-oracle-bleu", default=False, action="store_true")
    parser.add_argument(
        "--report-intermediate-results", default=False, action="store_true"
    )
    return parser


class DummyTask:
    """
    Dummy task exposing a target dictionary with default pad, eos, and unk indices.
    """

    def __init__(self):
        self.target_dictionary = Dictionary()


def evaluate_weights(scores_info, feature_weights, length_penalty):
    pad, eos, unk = (0, 1, 2)
    scorer = bleu.Scorer(pad, eos, unk)

    for example in scores_info:
        # Combine the feature scores with the candidate weights, apply the
        # length penalty, and pick the highest-scoring hypothesis.
        weighted_scores = (example["scores"] * feature_weights).sum(axis=1)
        weighted_scores /= (example["tgt_len"] ** length_penalty) + 1e-12
        top_hypo_ind = np.argmax(weighted_scores)
        top_hypo = example["hypos"][top_hypo_ind]
        ref = example["target_tokens"]
        scorer.add(torch.IntTensor(ref), torch.IntTensor(top_hypo))

    return scorer.score()


def identify_nonzero_features(scores_info):
    # A feature counts as nonzero if any hypothesis in any example has a
    # nonzero score for it; only these features receive random weights.
    nonzero_features = np.any(scores_info[0]["scores"] != 0, axis=0)
    for example in scores_info[1:]:
        nonzero_features |= np.any(example["scores"] != 0, axis=0)

    return np.where(nonzero_features)[0]


def random_search(
    scores_info_export_path,
    num_trials,
    report_oracle_bleu=False,
    report_intermediate_results=False,
):
    with open(scores_info_export_path, "rb") as f:
        scores_info = pickle.load(f)

    dummy_task = DummyTask()

    if report_oracle_bleu:
        pad, eos, unk = (0, 1, 2)
        oracle_scorer = bleu.Scorer(pad, eos, unk)

        for example in scores_info:
            smoothed_bleu = []
            for hypo in example["hypos"]:
                eval_score = smoothed_sentence_bleu(
                    dummy_task,
                    torch.IntTensor(example["target_tokens"]),
                    torch.IntTensor(hypo),
                )
                smoothed_bleu.append(eval_score)
            best_hypo_ind = np.argmax(smoothed_bleu)
            example["best_hypo_ind"] = best_hypo_ind

            oracle_scorer.add(
                torch.IntTensor(example["target_tokens"]),
                torch.IntTensor(example["hypos"][best_hypo_ind]),
            )

        print("oracle BLEU", oracle_scorer.score())

    num_features = scores_info[0]["scores"].shape[1]
    assert all(
        example["scores"].shape[1] == num_features for example in scores_info
    ), "All examples must have the same number of scores!"

    # Baseline: all weight on the first feature, no length penalty.
    feature_weights = np.zeros(num_features)
    feature_weights[0] = 1
    score = evaluate_weights(scores_info, feature_weights, length_penalty=0)
    best_score = score
    best_weights = feature_weights
    best_length_penalty = 0

    nonzero_features = identify_nonzero_features(scores_info)

    for i in range(num_trials):
        # Sample weights for the nonzero features from a Dirichlet distribution
        # (they sum to 1) and a length penalty uniformly from [0, 1.5).
        feature_weights = np.zeros(num_features)
        random_weights = np.random.dirichlet(np.ones(nonzero_features.size))
        feature_weights[nonzero_features] = random_weights
        length_penalty = 1.5 * np.random.random()

        score = evaluate_weights(scores_info, feature_weights, length_penalty)
        if score > best_score:
            best_score = score
            best_weights = feature_weights
            best_length_penalty = length_penalty

        if report_intermediate_results:
            print(f"\r[{i}] best: {best_score}", end="", flush=True)

    if report_intermediate_results:
        print()
    print("best BLEU: ", best_score)
    print("best weights: ", best_weights)
    print("best length penalty: ", best_length_penalty)

    return best_weights, best_length_penalty, best_score


def main():
    args = get_arg_parser().parse_args()

    assert (
        args.scores_info_export_path is not None
    ), "--scores-info-export-path is required for weights search"

    random_search(
        args.scores_info_export_path,
        args.num_trials,
        args.report_oracle_bleu,
        args.report_intermediate_results,
    )


if __name__ == "__main__":
    main()
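As a usage sketch, the search can also be driven from Python, assuming the tune set pickled above and that fairseq and pytorch_translate are importable; the import path below is hypothetical, since the file name of the added module is not shown in this diff.

# Hypothetical import path for the module added by this commit.
from pytorch_translate.rescoring.weights_search import random_search

best_weights, best_length_penalty, best_score = random_search(
    "scores_info.pkl",        # tune set pickled in the earlier sketch
    num_trials=2000,          # more trials than the default of 1000
    report_oracle_bleu=True,  # also print the oracle (best-hypothesis) BLEU
    report_intermediate_results=True,
)
print(best_weights, best_length_penalty, best_score)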