# model.py -- DKVMN-based knowledge tracing model with a hop-LSTM predictor head
import torch
import torch.nn as nn
import numpy as np
from memory import DKVMN
import warnings
warnings.filterwarnings("ignore")
class MODEL(nn.Module):
def __init__(self, n_question, batch_size, q_embed_dim, qa_embed_dim,
memory_size, memory_key_state_dim, memory_value_state_dim, final_fc_dim, first_k, gpu, student_num=None):
super(MODEL, self).__init__()
self.n_question = n_question
self.batch_size = batch_size
self.q_embed_dim = q_embed_dim
self.qa_embed_dim = qa_embed_dim
self.memory_size = memory_size
self.memory_key_state_dim = memory_key_state_dim
self.memory_value_state_dim = memory_value_state_dim
self.final_fc_dim = final_fc_dim
self.student_num = student_num
self.first_k = first_k
self.read_embed_linear = nn.Linear(self.memory_value_state_dim + self.q_embed_dim, self.final_fc_dim, bias=True)
# self.predict_linear = nn.Linear(self.memory_value_state_dim + self.q_embed_dim, 1, bias=True)
self.init_memory_key = nn.Parameter(torch.randn(self.memory_size, self.memory_key_state_dim))
nn.init.kaiming_normal_(self.init_memory_key)
self.init_memory_value = nn.Parameter(torch.randn(self.memory_size, self.memory_value_state_dim))
nn.init.kaiming_normal_(self.init_memory_value)
        # Hop-LSTM over per-identity subsequences; hidden size 64
        self.hop_lstm = nn.LSTM(input_size=self.memory_value_state_dim + self.q_embed_dim, hidden_size=64, num_layers=1, batch_first=True)
        self.predict_linear = nn.Linear(64, 1, bias=True)
self.mem = DKVMN(memory_size=self.memory_size,
memory_key_state_dim=self.memory_key_state_dim,
memory_value_state_dim=self.memory_value_state_dim, init_memory_key=self.init_memory_key)
memory_value = nn.Parameter(torch.cat([self.init_memory_value.unsqueeze(0) for _ in range(batch_size)], 0).data)
self.mem.init_value_memory(memory_value)
        # Question IDs start from 1; index 0 is the padding index
        # nn.Embedding maps a tensor of indices to the corresponding embedding vectors
        self.q_embed = nn.Embedding(self.n_question + 1, self.q_embed_dim, padding_idx=0)
        # a_embed input is [one-hot y_t (n_question + 1 dims), f_t (final_fc_dim dims)],
        # so 2 * n_question + 1 implicitly assumes final_fc_dim == n_question
        self.a_embed = nn.Linear(2 * self.n_question + 1, self.qa_embed_dim, bias=True)
# self.a_embed = nn.Linear(self.final_fc_dim + 1, self.qa_embed_dim, bias=True)
# self.correlation_weight_list = []
if gpu >= 0:
self.device = torch.device('cuda', gpu)
else:
self.device = torch.device('cpu')
print("num_layers=1, hidden_size=64, a=0.075, b=0.088, c=1.00, triangular, onehot")
def init_params(self):
nn.init.kaiming_normal_(self.predict_linear.weight)
nn.init.kaiming_normal_(self.read_embed_linear.weight)
nn.init.constant_(self.read_embed_linear.bias, 0)
nn.init.constant_(self.predict_linear.bias, 0)
def init_embeddings(self):
nn.init.kaiming_normal_(self.q_embed.weight)
    # Method 2: one-hot identity vector -- set the top-k entries of each correlation
    # weight vector to 1 and the rest to 0
    def identity_layer(self, correlation_weight, seqlen, k=1):
        correlation_weight = correlation_weight.view(self.batch_size * seqlen, -1)
        # For each time step in the batch, mark the top-k memory slots with 1
        _, indices = correlation_weight.topk(k, dim=1, largest=True)
        identity_matrix = torch.zeros([self.batch_size * seqlen, self.memory_size]).to(self.device)
        for i, m in enumerate(indices):
            identity_matrix[i, m] = 1
        identity_vector_batch = identity_matrix.view(self.batch_size * seqlen, -1)
unique_iv = torch.unique(identity_vector_batch, sorted=False, dim=0)
self.unique_len = unique_iv.shape[0]
        # Pairwise squared distances ||A - B||^2 = A^2 + B^2 - 2*A*B.T between each
        # identity vector (A) and each unique pattern (B); a zero distance marks the match
        # A^2
        iv_square_norm = torch.sum(torch.pow(identity_vector_batch, 2), dim=1, keepdim=True)
        iv_square_norm = iv_square_norm.repeat((1, self.unique_len))
        # B^2.T
        unique_iv_square_norm = torch.sum(torch.pow(unique_iv, 2), dim=1, keepdim=True)
        unique_iv_square_norm = unique_iv_square_norm.repeat((1, self.batch_size * seqlen)).transpose(1, 0)
        # A * B.T
        iv_matrix_product = identity_vector_batch.mm(unique_iv.transpose(1, 0))
        # A^2 + B^2 - 2A*B.T
        iv_distances = iv_square_norm + unique_iv_square_norm - 2 * iv_matrix_product
        # For each row, the column index of the zero entry is its unique-pattern id
        batch_identity_indices = (iv_distances == 0).nonzero()[:, -1]
        return batch_identity_indices
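    # Example of the zero-distance lookup above (a sketch): for rows
    # A = [[1, 0], [0, 1], [1, 0]] and unique patterns B = [[1, 0], [0, 1]],
    # the squared-distance matrix is [[0, 2], [2, 0], [0, 2]]; its zero entries
    # give pattern indices [0, 1, 0] for the three rows.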
    # Method 1: compute the identity vector with a triangular membership function
    def triangular_layer(self, correlation_weight, seqlen, a=0.075, b=0.088, c=1.00):
        # Triangular membership: w' = max(min((w - a) / (b - a), (c - w) / (c - b)), 0)
        correlation_weight = correlation_weight.reshape(1, -1)
        correlation_weight = torch.cat([(correlation_weight - a) / (b - a), (c - correlation_weight) / (c - b)], 0)
        correlation_weight, _ = torch.min(correlation_weight, 0)
        w0 = torch.zeros(correlation_weight.shape[0]).to(self.device)
        correlation_weight = torch.cat([correlation_weight.unsqueeze(0), w0.unsqueeze(0)], 0)
        correlation_weight, _ = torch.max(correlation_weight, 0)
        # Quantize the membership values: >= 0.6 -> 2, [0.1, 0.6) -> 1, < 0.1 -> 0
        identity_vector_batch = torch.zeros(correlation_weight.shape[0]).to(self.device)
        identity_vector_batch = identity_vector_batch.masked_fill(correlation_weight.lt(0.1), 0)
        identity_vector_batch = identity_vector_batch.masked_fill(correlation_weight.ge(0.1), 1)
        _identity_vector_batch = identity_vector_batch.masked_fill(correlation_weight.ge(0.6), 2)
        # Reshape the quantized identity vectors to (batch_size * seqlen, memory_size)
        identity_vector_batch = _identity_vector_batch.view(self.batch_size * seqlen, -1)
unique_iv = torch.unique(identity_vector_batch, sorted=False, dim=0)
self.unique_len = unique_iv.shape[0]
        # Same pairwise-distance trick as in identity_layer: a zero squared distance
        # identifies which unique pattern each row equals
        # A^2
        iv_square_norm = torch.sum(torch.pow(identity_vector_batch, 2), dim=1, keepdim=True)
        iv_square_norm = iv_square_norm.repeat((1, self.unique_len))
        # B^2.T
        unique_iv_square_norm = torch.sum(torch.pow(unique_iv, 2), dim=1, keepdim=True)
        unique_iv_square_norm = unique_iv_square_norm.repeat((1, self.batch_size * seqlen)).transpose(1, 0)
        # A * B.T
        iv_matrix_product = identity_vector_batch.mm(unique_iv.transpose(1, 0))
        # A^2 + B^2 - 2A*B.T
        iv_distances = iv_square_norm + unique_iv_square_norm - 2 * iv_matrix_product
        batch_identity_indices = (iv_distances == 0).nonzero()[:, -1]
        return batch_identity_indices
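    # Worked example of the triangular quantization above, with the default
    # a = 0.075, b = 0.088, c = 1.00: for a raw attention weight w = 0.08,
    # w' = max(min((0.08 - 0.075) / 0.013, (1.00 - 0.08) / 0.912), 0)
    #    = max(min(0.385, 1.009), 0) = 0.385,
    # which falls in [0.1, 0.6) and is therefore quantized to 1.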
def forward(self, q_data, qa_data, a_data, target, student_id=None):
        batch_size = q_data.shape[0]  # e.g. 32
        seqlen = q_data.shape[1]      # e.g. 200
        ## q_t and (q, a) embedding
        q_embed_data = self.q_embed(q_data)
        # Build the one-hot y_t vector for each answered question (index 0 means padding)
        a_onehot_array = []
        for i in range(a_data.shape[0]):
            for j in range(a_data.shape[1]):
                a_onehot = np.zeros(self.n_question + 1)
                index = a_data[i][j]
                if index > 0:
                    a_onehot[index] = 1
                a_onehot_array.append(a_onehot)
        a_onehot_content = torch.Tensor(np.stack(a_onehot_array)).view(batch_size, seqlen, -1).to(self.device)
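        # Example of the encoding above (a sketch): with n_question = 4 and
        # a_data[i][j] = 3, the one-hot y_t over indices 0..4 is [0, 0, 0, 1, 0]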
        ## Copy the initial value memory batch_size times for the DKVMN
        memory_value = nn.Parameter(torch.cat([self.init_memory_value.unsqueeze(0) for _ in range(batch_size)], 0).data)
        self.mem.init_value_memory(memory_value)
## slice data for seqlen times by axis 1
slice_q_data = torch.chunk(q_data, seqlen, 1)
slice_q_embed_data = torch.chunk(q_embed_data, seqlen, 1)
        slice_a_onehot_content = torch.chunk(a_onehot_content, seqlen, 1)
        # slice_a = torch.chunk(a_data, seqlen, 1)  # used by the alternative [f_t, y_t] write input below
value_read_content_l = []
input_embed_l = []
correlation_weight_list = []
        f_t = []
        # Random initial hidden/cell states for the hop-LSTM: (n_layers, batch_size, hidden_dim)
        init_h = torch.randn(1, self.batch_size, 64).to(self.device)
        init_c = torch.randn(1, self.batch_size, 64).to(self.device)
for i in range(seqlen):
## Attention
q = slice_q_embed_data[i].squeeze(1)
correlation_weight = self.mem.attention(q)
## Read Process
read_content = self.mem.read(correlation_weight)
            correlation_weight_list.append(correlation_weight)
            ## save intermediate data
            value_read_content_l.append(read_content)
            input_embed_l.append(q)
            # f_t: summary of the read content and the question embedding
            batch_predict_input = torch.cat([read_content, q], 1)
            f = self.read_embed_linear(batch_predict_input)
            f_t.append(batch_predict_input)
            ## Write Process
            # The write input to the value matrix is [y_t, f_t]: the one-hot answer
            # vector concatenated with f_t
            onehot = slice_a_onehot_content[i].squeeze(1)
            write_embed = torch.cat([onehot, f], 1)
            # Alternative write input [f_t, y_t]: f_t concatenated with raw 0/1 correctness
            # write_embed = torch.cat([f, slice_a[i].float()], 1)
            write_embed = self.a_embed(write_embed)
            new_memory_value = self.mem.write(correlation_weight, write_embed)
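        # At this point every time step has produced an attention weight over the
        # memory slots, a read vector, and a write that updates the value memory
        # with the embedded [y_t, f_t] input -- the usual DKVMN read/write cycle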
        # Stack the per-step attention weights: (batch_size, seqlen, memory_size)
        correlation_weight_matrix = torch.stack(correlation_weight_list, 1)
        identity_index_list = self.triangular_layer(correlation_weight_matrix, seqlen)
        # identity_index_list = self.identity_layer(correlation_weight_matrix, seqlen)
        identity_index_list = identity_index_list.view(self.batch_size, seqlen)
        # identity_index_list = identity_index_list[:, self.first_k:]  # skip predictions for the first k steps
        f_t = torch.stack(f_t, 1)
        # f_t = f_t[:, self.first_k:]  # skip predictions for the first k steps
        target_seqlayer = target.view(batch_size, seqlen, -1)
        # target_seqlayer = target_seqlayer[:, self.first_k:]  # skip predictions for the first k steps
        target_sequence = []
        pred_sequence = []
        logit_sequence = []
        for idx in range(self.unique_len):
            hop_lstm_input = []
            hop_lstm_target = []
            max_seq = 1
            zero_count = 0
            for i in range(self.batch_size):
                # Positions in sequence i whose identity vector matches the current pattern idx
                index = (identity_index_list[i, :] == idx).nonzero().view(-1)
                max_seq = max(max_seq, len(index))
                if len(index) == 0:
                    hop_lstm_input.append(torch.zeros([1, self.memory_value_state_dim + self.q_embed_dim]))
                    hop_lstm_target.append(torch.full([1, 1], -1.0))
                    zero_count += 1
                    continue
                else:
                    index = index.to(self.device)
                    hop_lstm_target_slice = torch.index_select(target_seqlayer[i, :, :], 0, index)
                    hop_lstm_input_slice = torch.index_select(f_t[i, :, :], 0, index)
                    hop_lstm_input.append(hop_lstm_input_slice)
                    hop_lstm_target.append(hop_lstm_target_slice)
            # Skip identity patterns that occur in no sequence of this batch
            if zero_count == self.batch_size:
                continue
            # Pad every sequence's input and target matrices to max_seq
            for i in range(self.batch_size):
                x = torch.zeros([max_seq, self.memory_value_state_dim + self.q_embed_dim])
                x[:len(hop_lstm_input[i]), :] = hop_lstm_input[i]
                hop_lstm_input[i] = x
                y = torch.full([max_seq, 1], -1.0)
                y[:len(hop_lstm_target[i]), :] = hop_lstm_target[i]
                hop_lstm_target[i] = y
            # Run the hop-LSTM over the padded per-pattern subsequences
            hop_lstm_input = torch.stack(hop_lstm_input, 0).to(self.device)
            hop_lstm_target = torch.stack(hop_lstm_target, 0)
            hop_lstm_output, _ = self.hop_lstm(hop_lstm_input, (init_h, init_c))
            pred = self.predict_linear(hop_lstm_output)
            pred = pred.view(self.batch_size * max_seq, -1)
            hop_lstm_target = hop_lstm_target.view(self.batch_size * max_seq, -1).to(self.device)
            # Mask out padded positions (target -1); keep the raw logits for the loss
            mask = hop_lstm_target.ge(0)
            hop_lstm_target = torch.masked_select(hop_lstm_target, mask)
            logits = torch.masked_select(pred, mask)
            pred = torch.sigmoid(logits)
            target_sequence.append(hop_lstm_target)
            pred_sequence.append(pred)
            logit_sequence.append(logits)
            # During training, backpropagate separately through the LSTM for each
            # identity pattern; the loss takes raw logits, so sigmoid is not applied twice
            if self.training:
                subsequence_loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, hop_lstm_target)
                subsequence_loss.backward(retain_graph=True)
        # Loss over all questions of the batch, again computed on the raw logits
        target_sequence = torch.cat(target_sequence, 0)
        pred_sequence = torch.cat(pred_sequence, 0)
        logit_sequence = torch.cat(logit_sequence, 0)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(logit_sequence, target_sequence)
        return loss, pred_sequence, target_sequence
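# Minimal smoke test -- a hypothetical usage sketch, not part of the original
# training pipeline. The toy dimensions are assumptions: final_fc_dim == n_question
# (required by the a_embed input width above), q_embed_dim == memory_key_state_dim
# and qa_embed_dim == memory_value_state_dim (as the DKVMN module expects), and
# a_data is assumed to hold the question id when answered correctly, 0 otherwise.
if __name__ == "__main__":
    n_question, batch_size, seqlen = 10, 32, 20
    model = MODEL(n_question=n_question, batch_size=batch_size,
                  q_embed_dim=16, qa_embed_dim=32, memory_size=4,
                  memory_key_state_dim=16, memory_value_state_dim=32,
                  final_fc_dim=n_question, first_k=0, gpu=-1)
    q_data = torch.randint(1, n_question + 1, (batch_size, seqlen))
    correct = torch.randint(0, 2, (batch_size, seqlen))
    a_data = q_data * correct                 # question id if correct, else 0
    qa_data = q_data + correct * n_question   # standard DKVMN (q, a) encoding
    target = correct.float()                  # 0/1 correctness labels
    model.eval()  # eval mode skips the per-pattern backward passes
    with torch.no_grad():
        loss, pred, tgt = model(q_data, qa_data, a_data, target)
    print(loss.item(), pred.shape, tgt.shape)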