@@ -16,44 +16,28 @@
 import model.plot as plot
 from architecture.single_model import img_alexnet_layers
 from evaluation import MAPs
-from .util import Dataset
 
 
-class PruneHash(object):
-    def __init__(self, config, stage):
+class DCH(object):
+    def __init__(self, config):
         ### Initialize setting
         print("initializing")
         np.set_printoptions(precision=4)
-        self.stage = stage
-        self.device = config['device']
-        self.output_dim = config['output_dim']
-        self.n_class = config['label_dim']
-        self.cq_lambda = config['cq_lambda']
-        self.alpha = config['alpha']
-        self.bias = config['bias']
-        self.gamma = config['gamma']
-
-        self.batch_size = config['batch_size'] if self.stage == "train" else config['val_batch_size']
-        self.max_iter = config['max_iter']
-        self.img_model = config['img_model']
-        self.loss_type = config['loss_type']
-        self.learning_rate = config['learning_rate']
-        self.learning_rate_decay_factor = config['learning_rate_decay_factor']
-        self.decay_step = config['decay_step']
-
-        self.finetune_all = config['finetune_all']
 
+        with tf.name_scope('stage'):
+            # 0 for training, 1 for validation
+            self.stage = tf.placeholder_with_default(tf.constant(0), [])
+        for k, v in vars(config).items():
+            setattr(self, k, v)
         self.file_name = 'loss_{}_lr_{}_cqlambda_{}_alpha_{}_bias_{}_gamma_{}_dataset_{}'.format(
             self.loss_type,
-            self.learning_rate,
-            self.cq_lambda,
+            self.lr,
+            self.q_lambda,
             self.alpha,
             self.bias,
             self.gamma,
-            config['dataset'])
-        self.save_dir = config['save_dir']
-        self.save_file = os.path.join(config['save_dir'], self.file_name + '.npy')
-        self.log_dir = config['log_dir']
+            self.dataset)
+        self.save_file = os.path.join(self.save_dir, self.file_name + '.npy')
 
         ### Setup session
         print("launching session")
@@ -63,27 +47,25 @@ def __init__(self, config, stage):
         self.sess = tf.Session(config=configProto)
 
         ### Create variables and placeholders
+        self.img = tf.placeholder(tf.float32, [None, 256, 256, 3])
+        self.img_label = tf.placeholder(tf.float32, [None, self.label_dim])
+        self.img_last_layer, self.deep_param_img, self.train_layers, self.train_last_layer = self.load_model()
 
-        with tf.device(self.device):
-            self.img = tf.placeholder(tf.float32, [self.batch_size, 256, 256, 3])
-            self.img_label = tf.placeholder(tf.float32, [self.batch_size, self.n_class])
-
-            if self.stage == 'train':
-                self.model_weights = config['model_weights']
-            else:
-                self.model_weights = self.save_file
-            self.img_last_layer, self.deep_param_img, self.train_layers, self.train_last_layer = self.load_model()
-
-            self.global_step = tf.Variable(0, trainable=False)
-            self.train_op = self.apply_loss_function(self.global_step)
-            self.sess.run(tf.global_variables_initializer())
+        self.global_step = tf.Variable(0, trainable=False)
+        self.train_op = self.apply_loss_function(self.global_step)
+        self.sess.run(tf.global_variables_initializer())
         return
 
     def load_model(self):
         if self.img_model == 'alexnet':
             img_output = img_alexnet_layers(
-                self.img, self.batch_size, self.output_dim,
-                self.stage, self.model_weights)
+                self.img,
+                self.batch_size,
+                self.output_dim,
+                self.stage,
+                self.model_weights,
+                self.with_tanh,
+                self.val_batch_size)
         else:
             raise Exception('cannot use such CNN model as ' + self.img_model)
         return img_output
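Switching the placeholder shapes from `[self.batch_size, ...]` to `[None, ...]` is what lets a single graph serve both phases: a `None` leading dimension accepts any batch size at run time. A self-contained sketch, independent of the model code:

```python
import numpy as np
import tensorflow as tf

img = tf.placeholder(tf.float32, [None, 256, 256, 3])
mean = tf.reduce_mean(img)
with tf.Session() as sess:
    # The same placeholder accepts a small training batch and a larger
    # validation batch; only the leading dimension varies.
    sess.run(mean, {img: np.zeros((16, 256, 256, 3), np.float32)})
    sess.run(mean, {img: np.zeros((50, 256, 256, 3), np.float32)})
```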
@@ -139,7 +121,7 @@ def reduce_shaper(t):
             r = tf.reshape(r, [-1, 1])
             ip = r - 2 * tf.matmul(u, tf.transpose(u)) + tf.transpose(r)
 
-            ip = tf.constant(self.gamma) / (ip + tf.constant(self.gamma) * tf.constant(self.gamma))
+            ip = self.gamma / (ip + self.gamma ** 2)
         else:
             ip = tf.clip_by_value(tf.matmul(u, tf.transpose(u)), -1.5e1, 1.5e1)
             ones = tf.ones([tf.shape(u)[0], tf.shape(u)[0]])
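The replaced line is the Cauchy similarity γ / (d + γ²), where `ip` holds the pairwise squared Euclidean distances assembled on the lines above; the rewrite only drops the redundant `tf.constant` wrappers, since Python scalars broadcast in TensorFlow arithmetic. A NumPy sketch checking that the two forms agree:

```python
import numpy as np

gamma = 20.0               # illustrative value; the real one comes from config
u = np.random.randn(4, 8)  # stand-in for a batch of hash-layer outputs
r = (u * u).sum(axis=1).reshape(-1, 1)
d = r - 2 * u @ u.T + r.T  # pairwise squared Euclidean distances
old = gamma / (d + gamma * gamma)  # what the tf.constant version computed
new = gamma / (d + gamma ** 2)
assert np.allclose(old, new)
```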
@@ -158,13 +140,12 @@ def apply_loss_function(self, global_step):
         self.cos_loss = self.cross_entropy(self.img_last_layer, self.img_label, self.alpha, True, True, self.bias)
 
         self.q_loss_img = tf.reduce_mean(tf.square(tf.subtract(tf.abs(self.img_last_layer), tf.constant(1.0))))
-        self.q_lambda = tf.Variable(self.cq_lambda, name='cq_lambda')
-        self.q_loss = tf.multiply(self.q_lambda, self.q_loss_img)
+        self.q_loss = self.q_lambda * self.q_loss_img
         self.loss = self.cos_loss + self.q_loss
 
         ### Last layer has a 10 times learning rate
-        self.lr = tf.train.exponential_decay(self.learning_rate, global_step, self.decay_step, self.learning_rate_decay_factor, staircase=True)
-        opt = tf.train.MomentumOptimizer(learning_rate=self.lr, momentum=0.9)
+        lr = tf.train.exponential_decay(self.lr, global_step, self.decay_step, self.lr, staircase=True)
+        opt = tf.train.MomentumOptimizer(learning_rate=lr, momentum=0.9)
         grads_and_vars = opt.compute_gradients(self.loss, self.train_layers + self.train_last_layer)
         fcgrad, _ = grads_and_vars[-2]
         fbgrad, _ = grads_and_vars[-1]
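With `staircase=True`, `tf.train.exponential_decay` yields `base_lr * decay_rate ** (global_step // decay_step)`. Note that the rewritten call passes `self.lr` as the `decay_rate` argument as well, where the old code used a separate `learning_rate_decay_factor`; presumably a decay factor below 1 is still intended here. A sketch of the schedule it computes, with hypothetical settings:

```python
def staircase_decay(base_lr, global_step, decay_step, decay_rate):
    # Mirrors tf.train.exponential_decay with staircase=True.
    return base_lr * decay_rate ** (global_step // decay_step)

# Hypothetical schedule: halve the rate every 2000 steps.
print(staircase_decay(0.005, 0, 2000, 0.5))     # 0.005
print(staircase_decay(0.005, 2500, 2000, 0.5))  # 0.0025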
@@ -174,11 +155,11 @@ def apply_loss_function(self, global_step):
         tf.summary.scalar('loss', self.loss)
         tf.summary.scalar('cos_loss', self.cos_loss)
         tf.summary.scalar('q_loss', self.q_loss)
-        tf.summary.scalar('lr', self.lr)
+        tf.summary.scalar('lr', lr)
         self.merged = tf.summary.merge_all()
 
 
-        if self.stage == "train" and self.finetune_all:
+        if self.finetune_all:
             return opt.apply_gradients([(grads_and_vars[0][0], self.train_layers[0]),
                                         (grads_and_vars[1][0]*2, self.train_layers[1]),
                                         (grads_and_vars[2][0], self.train_layers[2]),
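The `*2` here (and the `*10`/`*20` applied to `fcgrad`/`fbgrad` just below this hunk) is how the "10 times learning rate" comment is realized: under momentum SGD, scaling a variable's gradient by c before `apply_gradients` is equivalent to giving it a learning rate of c·lr. A self-contained sketch of the pattern, with made-up variables:

```python
import tensorflow as tf

w_backbone = tf.Variable(1.0)  # pretrained layer: base learning rate
w_hash = tf.Variable(1.0)      # new hash layer: 10x learning rate
loss = tf.square(w_backbone) + tf.square(w_hash)
opt = tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.9)
gvs = opt.compute_gradients(loss, [w_backbone, w_hash])
# Scaling a gradient by c before apply_gradients gives that variable an
# effective learning rate of c * lr under momentum SGD.
scaled = [(g * (10.0 if v is w_hash else 1.0), v) for g, v in gvs]
train_op = opt.apply_gradients(scaled)
```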
@@ -208,13 +189,10 @@ def train(self, img_dataset):
             shutil.rmtree(tflog_path)
         train_writer = tf.summary.FileWriter(tflog_path, self.sess.graph)
 
-        for train_iter in range(self.max_iter):
+        for train_iter in range(self.iter_num):
             images, labels = img_dataset.next_batch(self.batch_size)
             start_time = time.time()
 
-            assign_lambda = self.q_lambda.assign(self.cq_lambda)
-            self.sess.run([assign_lambda])
-
             _, loss, cos_loss, output, summary = self.sess.run([self.train_op, self.loss, self.cos_loss, self.img_last_layer, self.merged],
                                                                feed_dict={self.img: images,
                                                                           self.img_label: labels})
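Dropping the per-step `assign` follows from the earlier hunk: `q_lambda` is now a plain float baked into the graph, so there is nothing to reassign each iteration. If the weight ever needs to vary during training, a `placeholder_with_default` is a lighter alternative than a `tf.Variable` plus `assign`; a sketch with illustrative values:

```python
import tensorflow as tf

# 0.05 and 0.3 are illustrative stand-ins for q_lambda and the penalty term.
q_lambda = tf.placeholder_with_default(tf.constant(0.05), [])
q_loss = q_lambda * tf.constant(0.3)
with tf.Session() as sess:
    print(sess.run(q_loss))                   # default weight
    print(sess.run(q_loss, {q_lambda: 0.5}))  # overridden for one step
```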
@@ -224,7 +202,7 @@ def train(self, img_dataset):
             img_dataset.feed_batch_output(self.batch_size, output)
             duration = time.time() - start_time
 
-            if train_iter % 1 == 0:
+            if train_iter % 100 == 0:
                 print("%s #train# step %4d, loss = %.4f, cross_entropy loss = %.4f, %.1f sec/batch"
                       % (datetime.now(), train_iter + 1, loss, cos_loss, duration))
 
@@ -236,24 +214,29 @@ def train(self, img_dataset):
 
     def validation(self, img_query, img_database, R=100):
         print("%s #validation# start validation" % (datetime.now()))
-        query_batch = int(ceil(img_query.n_samples / self.batch_size))
+        query_batch = int(ceil(img_query.n_samples / float(self.val_batch_size)))
+        img_query.finish_epoch()
         print("%s #validation# totally %d query in %d batches" % (datetime.now(), img_query.n_samples, query_batch))
         for i in range(query_batch):
-            images, labels = img_query.next_batch(self.batch_size)
+            images, labels = img_query.next_batch(self.val_batch_size)
             output, loss = self.sess.run([self.img_last_layer, self.cos_loss],
-                                         feed_dict={self.img: images, self.img_label: labels})
-            img_query.feed_batch_output(self.batch_size, output)
+                                         feed_dict={self.img: images,
+                                                    self.img_label: labels,
+                                                    self.stage: 1})
+            img_query.feed_batch_output(self.val_batch_size, output)
             print('Cosine Loss: %s' % loss)
 
-        database_batch = int(ceil(img_database.n_samples / self.batch_size))
+        database_batch = int(ceil(img_database.n_samples / float(self.val_batch_size)))
+        img_database.finish_epoch()
         print("%s #validation# totally %d database in %d batches" % (datetime.now(), img_database.n_samples, database_batch))
         for i in range(database_batch):
-            images, labels = img_database.next_batch(self.batch_size)
+            images, labels = img_database.next_batch(self.val_batch_size)
 
             output, loss = self.sess.run([self.img_last_layer, self.cos_loss],
-                                         feed_dict={self.img: images, self.img_label: labels})
-            img_database.feed_batch_output(self.batch_size, output)
-            #print output[:10, :10]
+                                         feed_dict={self.img: images,
+                                                    self.img_label: labels,
+                                                    self.stage: 1})
+            img_database.feed_batch_output(self.val_batch_size, output)
             if i % 100 == 0:
                 print('Cosine Loss[%d/%d]: %s' % (i, database_batch, loss))
 
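Two details in this hunk: the `float()` cast keeps the `ceil` batch count correct under Python 2 integer division, and feeding `{self.stage: 1}` overrides the constructor's `placeholder_with_default` so the same graph runs in validation mode. A minimal sketch of the override:

```python
import tensorflow as tf

stage = tf.placeholder_with_default(tf.constant(0), [])  # 0 = train by default
is_training = tf.equal(stage, 0)
with tf.Session() as sess:
    print(sess.run(is_training))              # True: default (training) path
    print(sess.run(is_training, {stage: 1}))  # False: validation override
```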
@@ -283,15 +266,3 @@ def validation(self, img_query, img_database, R=100):
             'i2i_map_radius_2': mmap,
         }
 
-def train(train_img, config):
-    model = PruneHash(config, 'train')
-    img_dataset = Dataset(train_img, config['output_dim'])
-    model.train(img_dataset)
-    return model.save_file
-
-def validation(database_img, query_img, config):
-    model = PruneHash(config, 'val')
-    img_database = Dataset(database_img, config['output_dim'])
-    img_query = Dataset(query_img, config['output_dim'])
-    return model.validation(img_query, img_database, config['R'])
-