# main.py
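"""Combine per-model probability outputs into a weighted ensemble and report perplexities.

For each dataset directory (PennTreeBank, WikiText-2, WikiText-103) the script reads
per-model probability files from the 'valid' and 'test' subdirectories (via
combine_prob_text), fits ensemble weights on the validation split, evaluates the
ensemble on the test split, and appends a perplexity line chart and an
ensemble-weight pie chart to index.html. Each dataset directory is expected to
contain 'valid' and 'test' subdirectories with one .txt probability file per model.
"""
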
import os

import numpy as np
import pandas as pd
import plotly.express as px

from utils import combine_prob_text, calculate_sequence_loss, optimise_ensemble_weights


def plot_cumulative_state(df_line: pd.DataFrame, df_pie: pd.DataFrame, outfile: str, dataset_name: str):
    """Append a perplexity line chart and an ensemble-weight pie chart to the HTML report."""
    fig_line_chart = px.line(
        df_line,
        x="Model",
        y="Perplexity",
        title="Perplexity of best models on " + dataset_name + " dataset",
    )
    fig_pie_chart = px.pie(
        df_pie,
        values='Weight',
        names='Model',
        title='Weights of the models in the ensemble on ' + dataset_name + ' dataset',
    )
    with open(outfile, 'a') as f:
        f.write(fig_line_chart.to_html(full_html=False, include_plotlyjs='cdn', default_height="70%", default_width="70%"))
        f.write(fig_pie_chart.to_html(full_html=False, include_plotlyjs='cdn', default_height="70%", default_width="70%"))


if __name__ == "__main__":
    # Start the HTML report from scratch on each run.
    if os.path.exists("index.html"):
        os.remove("index.html")
    np.set_printoptions(formatter={'float': lambda x: "{0:0.3f}".format(x)})

    rel_paths = ['PennTreeBank', 'WikiText-2', 'WikiText-103']
    for path in rel_paths:
        print('\n' + '-' * 50)
        print('\n' + path)
        path = os.path.relpath(path)

        # The validation and test directories must contain files for the same set of models.
        val_files = sorted(os.listdir(os.path.join(path, 'valid')))
        test_files = sorted(os.listdir(os.path.join(path, 'test')))
        val_files_parsed = [f.replace('.txt', '') for f in val_files]
        test_files_parsed = [f.replace('.txt', '') for f in test_files]
        assert val_files_parsed == test_files_parsed, 'Different names for validation and test files'

        # Stack per-model probabilities into one array per split (one row per model).
        val_probabilities = np.vstack(
            [combine_prob_text(os.path.join(path, 'valid', file_name)) for file_name in val_files])
        test_probabilities = np.vstack(
            [combine_prob_text(os.path.join(path, 'test', file_name)) for file_name in test_files])

        print("\nIndividual valid ppl of models")
        for name, probs in zip(val_files, val_probabilities):
            # skip unigram cache
            if "unigram" in name:
                continue
            print(name + ": " + str(round(calculate_sequence_loss(probs)[1], 2)))

        # Individual test perplexities; names are collected alongside so that the
        # line-chart labels stay aligned when unigram-cache models are skipped.
        test_ppl_individual = []
        test_names_individual = []
        print("\nIndividual test ppl of models")
        for name, probs in zip(test_files, test_probabilities):
            # skip unigram cache
            if "unigram" in name:
                continue
            ppl = calculate_sequence_loss(probs)[1]
            test_ppl_individual.append(ppl)
            test_names_individual.append(name.replace('.txt', ''))
            print(name + ": " + str(round(ppl, 2)))

        # Fit ensemble weights on the validation split, then apply them to both splits.
        weights = optimise_ensemble_weights(val_probabilities)
        val_file_prob = (weights[:, np.newaxis] * val_probabilities).sum(axis=0)
        test_file_prob = (weights[:, np.newaxis] * test_probabilities).sum(axis=0)
        val_loss, val_ppl = calculate_sequence_loss(val_file_prob)
        test_loss, test_ppl = calculate_sequence_loss(test_file_prob)
        test_ppl_individual.append(test_ppl)
        test_names_individual.append('Ensemble of All')

        print('\nValidation Perplexity: ', val_ppl)
        print('Test Perplexity: ', test_ppl)

        print("\nName of files with weights")
        for name, w in zip(test_files_parsed, weights):
            print(name + ': ' + str(round(w, 2)))

        df_line = pd.DataFrame(list(zip(test_names_individual, test_ppl_individual)),
                               columns=['Model', 'Perplexity'])
        df_line = df_line.sort_values(by=['Perplexity'], ascending=False)
        df_pie = pd.DataFrame(list(zip(test_files_parsed, weights)),
                              columns=['Model', 'Weight'])
        plot_cumulative_state(df_line, df_pie, "index.html", path)