-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnearestneighbors.py
130 lines (112 loc) · 4.3 KB
/
nearestneighbors.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import numpy as np
from annoy import AnnoyIndex
import os
import sys
import numpy as np
from tqdm import tqdm
import pickle
import ngtpy ##Pybind11
#from ngt import base as ngt ##Undefined Symbols/Ctypes
from tqdm import tqdm
import time
import progressbar
# Usage:
# In a terminal first run: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib
# Build the index once by calling annoy_build(word_vectors, speech); running
# this file directly does not currently build anything.
# Then call query(vector, factor, speech) to search for the neighbours of a
# vector:
#   vector: 1D array-like object of length 300.
#   factor: integer, the number of neighbours to return.
#   output: the <factor> words (and their stored vectors) whose vectors are
#   closest to <vector> in the embedding, by Euclidean distance.
def _resolve(name, value):
    """Return ``value``, or the module-level global ``name`` when ``value`` is None.

    Keeps the old calling convention working: callers that assigned
    ``index`` / ``words`` / ``test_vectors`` as attributes of this module
    before calling the printing helpers still get those objects.

    Raises:
        NameError: if ``value`` is None and no such module global exists.
    """
    if value is not None:
        return value
    try:
        return globals()[name]
    except KeyError:
        raise NameError(
            f"name '{name}' is not defined; pass it explicitly as a keyword argument"
        ) from None


def print_nearest(word, *, index=None, vectors=None, words=None, number=10):
    """Print the words nearest to ``word``'s embedding vector.

    Args:
        word: key into ``vectors`` whose vector is used as the query.
        index: object with ``get_nns_by_vector(vector, n)`` (e.g. an
            AnnoyIndex). Defaults to the module global ``index``.
        vectors: mapping word -> vector. Defaults to the module global
            ``test_vectors``.
        words: sequence mapping index item ids to words. Defaults to the
            module global ``words``.
        number: how many neighbours to print (previously fixed at 10).
    """
    vectors = _resolve("test_vectors", vectors)
    print_neighbors(vectors[word], number, index=index, words=words)


def print_neighbors(vector, number, *, index=None, words=None):
    """Print the ``number`` words whose vectors are nearest to ``vector``.

    Previously this ignored both ``vector`` and ``number`` (it looked up an
    undefined ``word`` and always asked for 10 neighbours).

    Args:
        vector: query vector passed straight to ``index.get_nns_by_vector``.
        number: how many neighbours to print.
        index: object with ``get_nns_by_vector(vector, n)``. Defaults to the
            module global ``index``.
        words: sequence mapping index item ids to words. Defaults to the
            module global ``words``.
    """
    index = _resolve("index", index)
    words = _resolve("words", words)
    for idx in index.get_nns_by_vector(vector, number):
        print(words[idx])
def query(vector, factor, speech=True, search_k=100):
    """Find the ``factor`` nearest neighbours of ``vector`` in a saved Annoy index.

    Args:
        vector: 1D array-like object of length 300 corresponding to the word
            to be changed.
        factor: how many different neighbouring words to return.
        speech: if true, search ``data/Speech.ann`` using the word list
            ``data/speechwords.txt``; otherwise ``data/840b.ann`` with
            ``data/words.txt``.
        search_k: Annoy ``search_k`` value (nodes inspected at query time);
            larger is more accurate but slower. Defaults to the previously
            hard-coded 100.

    Returns:
        ``(words, vectors)``: two parallel lists. Each word is the raw line
        read from the word-list file, *including* its trailing newline; each
        vector is the item vector stored in the index.
    """
    if speech:
        ann_path, words_path = 'data/Speech.ann', 'data/speechwords.txt'
    else:
        ann_path, words_path = 'data/840b.ann', 'data/words.txt'

    # Dimension and metric must match the settings the index was built with
    # (annoy_build uses 300 / 'euclidean' for both indexes).
    index = AnnoyIndex(300, 'euclidean')
    index.load(ann_path)

    # annoy_build writes the word list as UTF-8, so read it back as UTF-8
    # explicitly: the platform default encoding may fail or, worse, silently
    # mis-decode, and the `with` guarantees the handle is closed.
    with open(words_path, 'r', encoding='utf-8') as f:
        words = f.readlines()

    outputw = []
    outputv = []
    for idx in index.get_nns_by_vector(vector, factor, search_k=search_k):
        # Line idx of the word list names Annoy item idx (see annoy_build).
        outputw.append(words[idx])
        outputv.append(index.get_item_vector(idx))
    return outputw, outputv
def main():
    """Script entry point; intentionally does nothing for now."""
def annoy_build(word_vectors, speech=True, n_trees=100):
    """Build and save an Annoy index from a word -> vector mapping.

    ``word_vectors`` looks like
    ``{"dog": [1.00, 4.22, 3.95, ...], "cat": [4.40, 2.22, 7.33, ...], ...}``
    and is split into ``words = ("dog", "cat", ...)`` and the matching
    ``vectors``. The words are written one per line (UTF-8) to the word-list
    file, and the i-th vector is added as Annoy item ``i``. That means line
    ``i`` of the word list names item ``i``, which is what ``query`` relies on.

    Args:
        word_vectors: mapping from word to a 300-element array-like vector.
        speech: if true, write ``data/speechwords.txt`` and
            ``data/Speech.ann``; otherwise ``data/words.txt`` and
            ``data/840b.ann``.
        n_trees: number of trees passed to ``AnnoyIndex.build`` (more trees
            give better precision and a larger index). Defaults to 100.

    Returns:
        Tuple of the words, in index order.

    Raises:
        ValueError: if ``word_vectors`` is empty.
    """
    if not word_vectors:
        raise ValueError("word_vectors is empty; nothing to index")

    if speech:
        words_path, ann_path = 'data/speechwords.txt', 'data/Speech.ann'
    else:
        words_path, ann_path = 'data/words.txt', 'data/840b.ann'

    index = AnnoyIndex(300, 'euclidean')
    words, vectors = zip(*word_vectors.items())

    # Size the progress bar by what is being written. Counting lines of the
    # existing file failed with FileNotFoundError on a first build and
    # leaked that file handle.
    with open(words_path, 'w', encoding='utf-8') as out, tqdm(total=len(words)) as pbar:
        pbar.set_description("Writing index")
        for item in words:
            out.write(f"{item}\n")
            pbar.update(1)

    # Item ids follow the same order as the lines just written.
    for idx, vector in enumerate(vectors):
        index.add_item(idx, vector)
    index.build(n_trees)
    index.save(ann_path)
    print("index saved")
    return words
if __name__ == "__main__":
    # Delegate to main(), which is currently a no-op, so running this file
    # as a script does nothing yet.
    main()