-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathword2vec_utils.cpp
95 lines (92 loc) · 3.01 KB
/
word2vec_utils.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
// Create binary Huffman tree using the word counts
// Frequent words will have short uniqe binary codes
// 构建 Huffman 树
void CreateBinaryTree() {
long long a, b, i, min1i, min2i, pos1, pos2, point[MAX_CODE_LENGTH];
char code[MAX_CODE_LENGTH];
long long *count = (long long *)calloc(vocab_size * 2 + 1, sizeof(long long));
long long *binary = (long long *)calloc(vocab_size * 2 + 1, sizeof(long long));
long long *parent_node = (long long *)calloc(vocab_size * 2 + 1, sizeof(long long));
// count 数组 放 元素的频数,从高到低 排列
for (a = 0; a < vocab_size; a++) count[a] = vocab[a].cn;
// count 数组两倍于原数组,后面的部分放生成的非叶子节点
for (a = vocab_size; a < vocab_size * 2; a++) count[a] = 1e15;
pos1 = vocab_size - 1;
pos2 = vocab_size;
// Following algorithm constructs the Huffman tree by adding one node at a time
for (a = 0; a < vocab_size - 1; a++) {
// First, find two smallest nodes 'min1, min2'
if (pos1 >= 0) {
if (count[pos1] < count[pos2]) {
min1i = pos1;
pos1--;
} else {
min1i = pos2;
pos2++;
}
} else {
min1i = pos2;
pos2++;
}
if (pos1 >= 0) {
if (count[pos1] < count[pos2]) {
min2i = pos1;
pos1--;
} else {
min2i = pos2;
pos2++;
}
} else {
min2i = pos2;
pos2++;
}
count[vocab_size + a] = count[min1i] + count[min2i];
// 并查集
parent_node[min1i] = vocab_size + a;
parent_node[min2i] = vocab_size + a;
binary[min2i] = 1;
}
// Now assign binary code to each vocabulary word
for (a = 0; a < vocab_size; a++) {
b = a;
i = 0;
while (1) {
code[i] = binary[b];
point[i] = b;
i++;
b = parent_node[b];
if (b == vocab_size * 2 - 2) break;
}
// 上面while循环是倒序查Huffman编码(并查集找父节点),所以下面需要恢复正序
vocab[a].codelen = i;
vocab[a].point[0] = vocab_size - 2;
for (b = 0; b < i; b++) {
vocab[a].code[i - b - 1] = code[b]; // 编码的反转
// 记录的是从根结点到叶子节点的路径: point[i] + vocab_size 就是根节点到叶子节点路径
// 注意,point代表的是节点,而code代表的是路径
vocab[a].point[i - b] = point[b] - vocab_size;
}
}
free(count);
free(binary);
free(parent_node);
}
// 根据词汇频率做的负采样表
void InitUnigramTable() {
int a, i;
double train_words_pow = 0;
double d1, power = 0.75;
table = (int *)malloc(table_size * sizeof(int));
for (a = 0; a < vocab_size; a++) train_words_pow += pow(vocab[a].cn, power);
i = 0;
d1 = pow(vocab[i].cn, power) / train_words_pow;
for (a = 0; a < table_size; a++) {
table[a] = i;
if (a / (double)table_size > d1) {
i++;
d1 += pow(vocab[i].cn, power) / train_words_pow;
}
// 如果遍历完所有词汇,负采样表剩余的空间一直拿最后一个单词的索引填充?
if (i >= vocab_size) i = vocab_size - 1;
}
}