This repository was archived by the owner on Nov 16, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 918
/
Copy path senteval.py
67 lines (55 loc) · 2.18 KB
/
senteval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""Utilities for evaluating sentence embeddings."""
class SentEvalConfig:
    """Object to store static properties of senteval experiments.

    Attributes:
        model_params (dict): model parameters that stay consistent across all runs
        senteval_params (dict): senteval parameters that stay consistent across all runs
        transfer_tasks (iterable of str, optional): not set by ``__init__``. When a
            caller attaches it to the instance, ``append_senteval_params`` uses it
            to check that classifier settings are present for classifying tasks.
    """

    # Tasks that require a ``classifier`` entry in ``senteval_params``.
    _CLASSIFYING_TASKS = frozenset(
        {
            "MR",
            "CR",
            "SUBJ",
            "MPQA",
            "SST2",
            "SST5",
            "TREC",
            "SICKEntailment",
            "SNLI",
            "MRPC",
        }
    )

    # Keys that ``senteval_params["classifier"]`` must contain for those tasks.
    _REQUIRED_CLASSIFIER_PARAMS = frozenset(
        {"nhid", "optim", "batch_size", "tenacity", "epoch_size"}
    )

    def __init__(self, model_params, senteval_params):
        """Initialize the config.

        Args:
            model_params (dict): model parameters that stay consistent across all runs
            senteval_params (dict): senteval parameters that stay consistent across all runs
        """
        self.model_params = model_params
        self.senteval_params = senteval_params

    @property
    def model_params(self):
        """dict: model parameters that stay consistent across all runs."""
        return self._model_params

    @model_params.setter
    def model_params(self, model_params):
        self._model_params = model_params

    def append_senteval_params(self, params):
        """Merge extra params into ``senteval_params`` after initialization.

        A new dict is built and assigned to ``senteval_params``. Entries in
        ``params`` override existing entries with the same key. If the instance
        has a ``transfer_tasks`` attribute naming any classifying task, the
        merged params are then checked for classifier settings. Any problem is
        printed, not raised.

        Args:
            params (dict): parameters to add to, or override in, ``senteval_params``
        """
        # ``{**a, **b}`` (unlike ``dict(a, **b)``) also accepts non-string keys.
        self.senteval_params = {**self.senteval_params, **params}
        problem = self._classifier_params_problem()
        if problem is not None:
            print(problem)

    def _classifier_params_problem(self):
        """Return a message describing missing classifier settings, or None."""
        # ``transfer_tasks`` is never set by ``__init__``; skip the check when absent.
        transfer_tasks = getattr(self, "transfer_tasks", None) or ()
        tasks = [t for t in transfer_tasks if t in self._CLASSIFYING_TASKS]
        if not tasks:
            return None
        task_names = ", ".join(tasks)
        classifier = self.senteval_params.get("classifier")
        if classifier is None:
            return "Include param['classifier'] to run task {}".format(task_names)
        # Iterating a dict yields its keys, so this is a real subset check.
        if not self._REQUIRED_CLASSIFIER_PARAMS.issubset(classifier):
            return (
                "Include nhid, optim, batch_size, tenacity, and epoch_size params to "
                "run task {}".format(task_names)
            )
        return None