glove.py
from NLP4FinTools.glove.utils import load_glove_embeddings
from nltk.corpus import stopwords


class Glove:

    def __init__(self, glovepath, cachepath, limit=100000):
        """ Initialize Glove class

        Args:
            glovepath (string): Path to pretrained GloVe embeddings
            cachepath (string): Path to store embedding pickle cache
            limit (int): Maximum size of vocabulary
        """
        # Load embeddings and build a term -> row index lookup
        self.terms, self.embeds = load_glove_embeddings(glovepath, cachepath, limit=limit)
        self.tdict = {self.terms[idx]: idx for idx in range(self.terms.shape[0])}

        # Load stopwords and add capitalized versions of stopwords
        stops = stopwords.words('english')
        self.stops = set(stops + [stop[0].upper() + stop[1:] for stop in stops])

    def __contains__(self, term):
        """ Check if term is in the word embedding vocabulary

        Args:
            term (string): Term to check

        Returns:
            (bool): Whether term exists in vocabulary
        """
        return term in self.tdict

    def __getitem__(self, term):
        """ Return embedding for term if it's in the vocabulary

        Args:
            term (string): Term whose embedding to return

        Returns:
            (numpy.ndarray): Numpy array of word embedding
        """
        return self.embeds[self.tdict[term]]
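

# Usage sketch (illustrative only): the embedding and cache paths below are
# hypothetical placeholders, and it assumes the GloVe vectors and the NLTK
# stopword corpus are already available locally.
if __name__ == "__main__":
    glove = Glove("data/glove.6B.300d.txt", "data/glove_cache.pkl", limit=50000)
    if "market" in glove:
        vector = glove["market"]  # numpy array for the term "market"
        print(vector.shape)       # e.g. (300,) for 300-dimensional GloVe vectors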