#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Reads a sense inventory (chinese-whispers format: word<TAB>sense_id<TAB>cluster, where cluster = word:weight,word:weight)
and creates a sense vector for each cluster.
The result is a sensegram model with each sense in the form word#sense_id.
If an -inventory path is set, also creates a new sense inventory for this sense vector model.
"""
import argparse
import codecs
from collections import defaultdict

import numpy as np
from pandas import read_csv
from gensim.models import KeyedVectors

import sensegram
import pbar

CHUNK_LINES = 500000
SPLIT_MWE = True
debug = False
sen_delimiter = "#"  # python#0, python#1, etc.
inventory_header = "word\tsense_id\trel_terms\n"
def file_len(fname):
    """Count the number of lines in a file."""
    i = -1  # handles an empty file, for which the loop body never runs
    with open(fname) as f:
        for i, _ in enumerate(f):
            pass
    return i + 1
def pool_vectors(vectors, similarities, method):
    if method == 'mean':
        return np.mean(vectors, axis=0)
    elif method == 'weighted_mean':
        # weight each cluster word vector by its similarity to the target word
        s = sum(similarities)
        sim_weights = [sim / s for sim in similarities]
        return np.average(vectors, axis=0, weights=sim_weights)
    elif method == 'ranked':
        # weight the i-th cluster word vector by 1/i (clusters are assumed sorted by similarity)
        cluster_size = len(vectors)
        rank_weights = [1.0 / rank for rank in range(1, cluster_size + 1)]
        return np.average(vectors, axis=0, weights=rank_weights)
    else:
        raise ValueError("Unknown pooling method '%s'" % method)
class Dummysink(object):
    """Dummy object that imitates a writable file but discards all data."""

    def write(self, data):
        pass  # ignore the data

    def __enter__(self):
        return self

    def __exit__(self, *x):
        return False
def write_inventory(filename):
    """If an inventory filename is set, open it for writing. Otherwise return a no-op sink."""
    if filename:
        return codecs.open(filename, "w", encoding='utf-8')
    else:
        return Dummysink()
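# Design note: returning Dummysink() lets the caller use a single `with` block
# whether or not an inventory path was given, e.g. (hypothetical variable name):
#
#   with write_inventory(maybe_path) as out:
#       out.write(inventory_header)  # silently discarded when maybe_path is None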
def initialize(clusters_file, has_header, vector_size):
    """Initialize the sense model with one zero row per cluster."""
    nclusters = file_len(clusters_file)
    if has_header:
        nclusters = nclusters - 1
    senvec = sensegram.SenseGram(size=vector_size, sorted_vocab=0)
    senvec.wv.syn0 = np.zeros((nclusters, vector_size), dtype=np.float32)
    if debug:
        print("Matrix shape: (%i, %i)" % (nclusters, vector_size))
    return senvec
def read_clusters_file(clusters, has_header):
    # na_values=[""], keep_default_na=False keeps strings like 'NaN', 'nan' and 'na'
    # as literal strings instead of replacing them with float NaN.
    # doublequote=False, quotechar="\u0000" changes the quote character from the default '"' to NUL;
    # otherwise any delimiter inside quotes would be ignored.
    if has_header:
        reader = read_csv(clusters, encoding="utf-8", delimiter="\t", error_bad_lines=False, iterator=True,
                          chunksize=CHUNK_LINES, na_values=[""], keep_default_na=False,
                          doublequote=False, quotechar="\u0000", index_col=False)
    else:
        reader = read_csv(clusters, encoding="utf-8", delimiter="\t", error_bad_lines=False, iterator=True,
                          chunksize=CHUNK_LINES, na_values=[""], keep_default_na=False,
                          doublequote=False, quotechar="\u0000",
                          header=None, names=["word", "cid", "cluster"])
    return reader
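# A minimal sketch of consuming the chunked reader (hypothetical file name);
# pandas yields DataFrames of at most CHUNK_LINES rows each, so large cluster
# files never have to fit in memory at once:
#
#   for chunk in read_clusters_file("clusters.tsv", has_header=True):
#       for _, row in chunk.iterrows():
#           print(row.word, row.cid, row.cluster)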
def parse_cluster(row_cluster, wordvec):
    # Only pool cluster words that are present in the word vector model;
    # skip cluster items that cannot be split correctly.
    cluster = []
    for item in row_cluster.split(','):
        try:
            word, sim = item.strip().rsplit(':', 1)
            float(sim)  # assert that the sim string represents a float
            if word in wordvec.vocab:
                cluster.append((word, sim))
            if SPLIT_MWE:
                words = word.split(" ")
                if len(words) == 1:
                    continue
                # also add the individual tokens of a multiword expression
                for w in words:
                    if w in wordvec.vocab:
                        cluster.append((w, sim))
        except ValueError:
            print("Warning: wrong cluster word", item)
    return cluster
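# A hypothetical example, assuming "boston", "new", "york" and "city" are in
# wordvec.vocab but the multiword entry "new york city" is not:
#
#   parse_cluster("boston:0.71,new york city:0.64", wordvec)
#   # -> [("boston", "0.71"), ("new", "0.64"), ("york", "0.64"), ("city", "0.64")]
#
# With SPLIT_MWE enabled, the parts of the multiword expression are pooled in its place.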
def run(clusters, model, output, method='weighted_mean', lowercase=False, inventory=None, has_header=True):
    small_clusters = 0
    sen_count = defaultdict(int)    # number of senses per word
    cluster_sum = defaultdict(int)  # number of cluster words per word

    print("Loading original word model...")
    wordvec = KeyedVectors.load_word2vec_format(model, binary=True)
    print("Initializing sense model...")
    senvec = initialize(clusters, has_header, wordvec.syn0.shape[1])

    print("Pooling cluster vectors (%s method)..." % method)
    reader = read_clusters_file(clusters, has_header)
    pb = pbar.Pbar(senvec.wv.syn0.shape[0], 100)
    pb.start()

    with write_inventory(inventory) as inv_output:
        inv_output.write(inventory_header)
        i = 0
        for chunk in reader:
            if debug:
                print("Column types: %s" % chunk.dtypes)
            for j, row in chunk.iterrows():
                row_word = row.word
                row_cluster = row.cluster
                if lowercase:
                    row_cluster = row_cluster.lower()
                # enumerate word senses from 0
                sen_word = str(row_word) + sen_delimiter + str(sen_count[row_word])

                # process a new sense
                sen_cluster = parse_cluster(row_cluster, wordvec)
                if len(sen_cluster) >= 5:
                    vectors = np.array([wordvec[word] for word, sim in sen_cluster])
                    sims = np.array([float(sim) for word, sim in sen_cluster])
                    sen_vector = pool_vectors(vectors, sims, method)
                    if sen_word not in senvec.wv.vocab:
                        senvec.add_word(sen_word, sen_vector)
                    senvec.wv.probs[sen_word] = len(sen_cluster)  # number of cluster words per sense
                    sen_count[row_word] += 1                      # number of senses per word
                    cluster_sum[row_word] += len(sen_cluster)     # number of cluster words per word
                    # write the new sense to the sense inventory
                    if inventory:
                        # join back the cluster words (only those actually used for the sense vector)
                        cluster = ",".join([word + ":" + sim for word, sim in sen_cluster])
                        inv_output.write("%s\t%s\t%s\n" % (sen_word.split(sen_delimiter)[0],
                                                           sen_word.split(sen_delimiter)[1], cluster))
                else:
                    small_clusters += 1
                    if debug:
                        print(row_word, "\t", row.cid)
                        print(sen_cluster)
                pb.update(i)
                i += 1

    senvec.__normalize_probs__(cluster_sum)
    pb.finish()

    ##### Validation #####
    if senvec.wv.syn0.shape[0] != len(senvec.wv.vocab):
        print("Shrinking matrix size from %i to %i" % (senvec.wv.syn0.shape[0], len(senvec.wv.vocab)))
        senvec.wv.syn0 = np.ascontiguousarray(senvec.wv.syn0[:len(senvec.wv.vocab)])

    senvec.save_word2vec_format(fname=output, binary=False)
    print("Sense vectors saved to: " + output)
def main():
    parser = argparse.ArgumentParser(description='Create sense vectors based on sense clusters and word vectors.')
    parser.add_argument('clusters',
                        help='A path to an input file with postprocessed clusters and a header. '
                             'Format: "word<TAB>cid<TAB>cluster" where <cluster> is "word:sim,word:sim,..."')
    parser.add_argument('model', help='A path to an initial word vector model')
    parser.add_argument('output', help='A path to the output sense vector model')
    parser.add_argument('-method',
                        help="A method used to pool word vectors into a sense vector "
                             "('mean', 'weighted_mean' or 'ranked'). Default: 'weighted_mean'",
                        default='weighted_mean')
    parser.add_argument("-lowercase",
                        help="Lowercase all words in clusters (necessary if the word model contains only lowercased words). Default: False",
                        action="store_true")
    parser.add_argument("-inventory",
                        help='A path to the output inventory file of the computed sense vector model, with a header. '
                             'Format: "word<TAB>sense_id<TAB>rel_terms" where <rel_terms> is "word:sim,word:sim,...". '
                             'If not given, no inventory is written. Default: None',
                        default=None)
    parser.add_argument('-no_header', action='store_true', help='Set if the cluster file has no header. Default: False.')
    args = parser.parse_args()

    run(args.clusters, args.model, args.output, args.method, args.lowercase, args.inventory,
        has_header=(not args.no_header))
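# An example invocation (hypothetical file names):
#
#   python pooling.py clusters.tsv wordvectors.bin sense_vectors.vec \
#       -method weighted_mean -inventory inventory.tsv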
if __name__ == '__main__':
    main()