-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain.py
131 lines (95 loc) · 5.27 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# -*- coding: utf-8 -*-
"""
Created on Wed Jun 16 11:47:55 2021
@author: Administrator
"""
import torch
from torch.utils.data import DataLoader
import numpy as np
from dataset import Split
from setting import Setting
from dataloader import PoiDataLoader
from trainer import FlashbackTrainer
from network import create_h0_strategy
from evaluation import Evaluation
# Build the experiment configuration: parse command-line/config options
# into a single Setting object and echo the resolved values for the log.
setting = Setting()
setting.parse()
print(setting)
def normlize_column(array):
    """Normalize each column of a 2-D array so that it sums to 1.

    Parameters
    ----------
    array : np.ndarray
        2-D array of shape (rows, cols).

    Returns
    -------
    np.ndarray
        New array of the same shape and dtype in which every column has
        been divided by that column's sum.  As in the original loop
        implementation, the result is cast back to ``array.dtype`` (so an
        integer dtype truncates), and an all-zero column produces the
        usual numpy divide warning with NaN entries.
    """
    # Single vectorized division over all columns replaces the original
    # per-column Python loop (which also shadowed the builtin `len`).
    # keepdims=True keeps the (1, cols) row of sums broadcastable.
    normalized = array / array.sum(axis=0, keepdims=True)
    # Mirror the original behaviour of assigning into a buffer created
    # with the input's dtype.
    return normalized.astype(array.dtype)
# ---------------------------------------------------------------------------
# Data loading.  Read the raw check-in file (min_checkins is presumably a
# per-user filter threshold -- confirm in PoiDataLoader), then build the
# POI/category matrix and the three dataset splits.
# ---------------------------------------------------------------------------
# for training set
poi_loader = PoiDataLoader(setting.min_checkins)
poi_loader.read(setting.dataset_file)
# Column-normalize the POI-category matrix so each column sums to 1, then
# move it to the training device as a torch tensor.
poicatg_matrix = poi_loader.zero_matrix
poicatg_matrix = normlize_column(poicatg_matrix)
poicatg_matrix = torch.from_numpy(poicatg_matrix).to(setting.device)
dataset = poi_loader.create_dataset(setting.sequence_length, Split.TRAIN)
# Only the training loader shuffles; validation/test keep dataset order.
dataloader = DataLoader(dataset, batch_size=setting.batch_size, shuffle=True)
# for validate set
dataset_validate = poi_loader.create_dataset(setting.sequence_length, Split.VALIDATE)
dataloader_validata = DataLoader(dataset_validate, batch_size=setting.batch_size, shuffle=False)
# for test settp
dataset_test = poi_loader.create_dataset(setting.sequence_length, Split.TEST)
dataloader_test = DataLoader(dataset_test, batch_size=setting.batch_size, shuffle=False)
# ---------------------------------------------------------------------------
# Model construction: Flashback trainer with temporal/spatial decay weights
# and the normalized category matrix; h0_strategy supplies initial hidden
# states (LSTM vs. plain RNN chosen by setting.is_lstm).
# ---------------------------------------------------------------------------
# for training
trainer = FlashbackTrainer(setting.lambda_t, setting.lambda_s, poicatg_matrix)
h0_strategy = create_h0_strategy(setting.hidden_dim, setting.is_lstm)
trainer.prepare(poi_loader.poi_count(), poi_loader.user_count(), poi_loader.catg_count(), poi_loader.catgLyaer_count(), poi_loader.timeslot_count(), poi_loader.poi2coord, setting.hidden_dim, setting.rnn_factory, setting.device)
# for Evaluation
evaluation_vaid = Evaluation(dataset_validate, dataloader_validata, poi_loader.user_count(), h0_strategy, trainer, setting)
evaluation_test = Evaluation(dataset_test, dataloader_test, poi_loader.user_count(), h0_strategy, trainer, setting)
# for Optimization
# Adam with L2 weight decay; LR decays by 0.6 at epochs 40/80/120/160.
# NOTE(review): the training loop below runs range(1) epochs, so these
# milestones are never reached as written -- confirm intended epoch count.
optimizer = torch.optim.Adam(trainer.parameters(), lr=setting.learning_rate, weight_decay=setting.weight_decay)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[40,80,120,160], gamma=0.6)
# POI id -> coordinate lookup on the training device (built but not read in
# this file -- presumably consumed inside trainer/evaluation; verify).
poi2coord = torch.tensor(poi_loader.poi2coord, device=setting.device)
# Echo key hyperparameters for the experiment log.
print('dataset_name',setting.dataset_name)
print('Seqence_len:',setting.sequence_length)
print('lambda_s',setting.lambda_s)
print('lambda_t',setting.lambda_t)
print('hidden', setting.hidden_dim)
# ---------------------------------------------------------------------------
# Training loop.  Indentation below is reconstructed (the pasted source was
# flattened); statement order is unchanged.
# NOTE(review): range(1) trains a single epoch, while the MultiStepLR
# milestones assume 160+ -- confirm the intended epoch count.
# ---------------------------------------------------------------------------
for epoch in range(1):
    # Fresh initial hidden state at the start of each epoch; h is then
    # carried across batches within the epoch (not re-initialized per batch).
    h = h0_strategy.on_init(setting.batch_size, device = setting.device)
    loss_value = []
    print('epoch:',epoch+1)
    for i,(x_user, x_tf, x_tb, x_tsf, x_tsb, x_cof, x_cob, x_poi_f, x_poi_b, x_catg_f, x_catg_b, x_catgLayer_f, x_catgLayer_b, y_tsecond, y_tslot, y_coord, y_poi, y_catg, y_catgLayer) in enumerate(dataloader):
        # Trim the hidden state to the actual batch size (the final batch
        # may be smaller than setting.batch_size).
        length = len(x_user)
        h = h[:,:length,:]
        # Move inputs to the device.  The sequence tensors are transposed,
        # presumably from (batch, seq) to (seq, batch) for the RNN --
        # confirm against the dataset's item layout.
        # NOTE(review): .squeeze() with no dim argument drops every size-1
        # dimension, which would also remove the batch axis if a final
        # batch has a single sample -- verify this cannot occur.
        x_user = x_user.squeeze().to(setting.device)
        x_tf = torch.transpose(x_tf.squeeze(),0,1).to(setting.device)
        x_tb = torch.transpose(x_tb.squeeze(),0,1).to(setting.device)
        x_tsf = torch.transpose(x_tsf.squeeze(),0,1).to(setting.device)
        x_tsb = torch.transpose(x_tsb.squeeze(),0,1).to(setting.device)
        x_cof = torch.transpose(x_cof.squeeze(),0,1).to(setting.device)
        x_cob = torch.transpose(x_cob.squeeze(),0,1).to(setting.device)
        x_poi_f = torch.transpose(x_poi_f.squeeze(),0,1).to(setting.device)
        x_poi_b = torch.transpose(x_poi_b.squeeze(),0,1).to(setting.device)
        x_catg_f = torch.transpose(x_catg_f.squeeze(),0,1).to(setting.device)
        x_catg_b = torch.transpose(x_catg_b.squeeze(),0,1).to(setting.device)
        x_catgLayer_f = torch.transpose(x_catgLayer_f.squeeze(),0,1).to(setting.device)
        x_catgLayer_b = torch.transpose(x_catgLayer_b.squeeze(),0,1).to(setting.device)
        # Targets are not transposed, only squeezed and moved to device.
        y_tsecond = y_tsecond.squeeze().to(setting.device)
        y_tslot = y_tslot.squeeze().to(setting.device)
        y_coord = y_coord.squeeze().to(setting.device)
        y_poi = y_poi.squeeze().to(setting.device)
        y_catg = y_catg.squeeze().to(setting.device)
        y_catgLayer = y_catgLayer.squeeze().to(setting.device)
        # NOTE(review): the predictions returned by loss1 are never used
        # below, so this appears to be a redundant extra forward pass --
        # confirm loss1 has no required side effects before removing it.
        y_pred_poi, y_pred_catgLayer = trainer.loss1(h, x_user, x_tf, x_tb, x_tsf, x_tsb, x_cof, x_cob, x_poi_f, x_poi_b, x_catg_f, x_catg_b, x_catgLayer_f, x_catgLayer_b, y_tsecond, y_tslot, y_coord, y_poi, y_catg, y_catgLayer)
        #
        #
        # Standard optimization step: zero grads, compute loss, backprop,
        # record the scalar loss, update parameters.
        optimizer.zero_grad()
        loss = trainer.loss(h, x_user, x_tf, x_tb, x_tsf, x_tsb, x_cof, x_cob, x_poi_f, x_poi_b, x_catg_f, x_catg_b, x_catgLayer_f, x_catgLayer_b, y_tsecond, y_tslot, y_coord, y_poi, y_catg, y_catgLayer)
        loss.backward()
        loss_value.append(loss.item())
        optimizer.step()
    # NOTE(review): indentation reconstructed -- stepping the scheduler
    # once per epoch matches the milestones [40,80,120,160]; confirm the
    # original did not step it per batch instead.
    scheduler.step()
    # print(f'Used learning rate: {scheduler.get_last_lr()[0]}')
    # Mean training loss over this epoch's batches.
    print(np.mean(loss_value))
    # Evaluate on the test split every other epoch (validation is
    # currently commented out).
    if (epoch ) % 2 == 0:
        # evaluation_vaid.evaluate()
        # print('*****************')
        evaluation_test.evaluate()