
Commit 4499aee

update dch
1 parent 5d17857 commit 4499aee

8 files changed: +133 −169 lines

DeepHash/model/dch/__init__.py

+16

@@ -0,0 +1,16 @@
+from .util import Dataset
+from .dch import DCH
+
+def train(train_img, database_img, query_img, config):
+    model = DCH(config)
+    img_database = Dataset(database_img, config.output_dim)
+    img_query = Dataset(query_img, config.output_dim)
+    img_train = Dataset(train_img, config.output_dim)
+    model.train(img_train)
+    return model.save_dir
+
+def validation(database_img, query_img, config):
+    model = DCH(config)
+    img_database = Dataset(database_img, config.output_dim)
+    img_query = Dataset(query_img, config.output_dim)
+    return model.validation(img_query, img_database, config.R)
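Note: the new module-level API reads its configuration by attribute (config.output_dim, config.R) rather than by dict key, so the caller can pass an argparse.Namespace straight through, as the new example script below does. A minimal illustration of the two access styles (the values here are placeholders, not from this commit):

    from argparse import Namespace

    # Placeholder config for illustration; DCH reads many more fields.
    config = Namespace(output_dim=64, R=54000)
    print(config.output_dim)  # attribute access, as in Dataset(..., config.output_dim)
    print(vars(config))       # dict view, as consumed by DCH.__init__'s setattr loop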

DeepHash/model/prunehash/prunehash.py → DeepHash/model/dch/dch.py

+46 −75
@@ -16,44 +16,28 @@
 import model.plot as plot
 from architecture.single_model import img_alexnet_layers
 from evaluation import MAPs
-from .util import Dataset
 
 
-class PruneHash(object):
-    def __init__(self, config, stage):
+class DCH(object):
+    def __init__(self, config):
         ### Initialize setting
         print ("initializing")
         np.set_printoptions(precision=4)
-        self.stage = stage
-        self.device = config['device']
-        self.output_dim = config['output_dim']
-        self.n_class = config['label_dim']
-        self.cq_lambda = config['cq_lambda']
-        self.alpha = config['alpha']
-        self.bias = config['bias']
-        self.gamma = config['gamma']
-
-        self.batch_size = config['batch_size'] if self.stage == "train" else config['val_batch_size']
-        self.max_iter = config['max_iter']
-        self.img_model = config['img_model']
-        self.loss_type = config['loss_type']
-        self.learning_rate = config['learning_rate']
-        self.learning_rate_decay_factor = config['learning_rate_decay_factor']
-        self.decay_step = config['decay_step']
-
-        self.finetune_all = config['finetune_all']
 
+        with tf.name_scope('stage'):
+            # 0 for training, 1 for validation
+            self.stage = tf.placeholder_with_default(tf.constant(0), [])
+        for k, v in vars(config).items():
+            setattr(self, k, v)
         self.file_name = 'loss_{}_lr_{}_cqlambda_{}_alpha_{}_bias_{}_gamma_{}_dataset_{}'.format(
             self.loss_type,
-            self.learning_rate,
-            self.cq_lambda,
+            self.lr,
+            self.q_lambda,
             self.alpha,
             self.bias,
             self.gamma,
-            config['dataset'])
-        self.save_dir = config['save_dir']
-        self.save_file = os.path.join(config['save_dir'], self.file_name + '.npy')
-        self.log_dir = config['log_dir']
+            self.dataset)
+        self.save_file = os.path.join(self.save_dir, self.file_name + '.npy')
 
         ### Setup session
         print ("launching session")
@@ -63,27 +47,25 @@ def __init__(self, config, stage):
         self.sess = tf.Session(config=configProto)
 
         ### Create variables and placeholders
+        self.img = tf.placeholder(tf.float32, [None, 256, 256, 3])
+        self.img_label = tf.placeholder(tf.float32, [None, self.label_dim])
+        self.img_last_layer, self.deep_param_img, self.train_layers, self.train_last_layer = self.load_model()
 
-        with tf.device(self.device):
-            self.img = tf.placeholder(tf.float32, [self.batch_size, 256, 256, 3])
-            self.img_label = tf.placeholder(tf.float32, [self.batch_size, self.n_class])
-
-            if self.stage == 'train':
-                self.model_weights = config['model_weights']
-            else:
-                self.model_weights = self.save_file
-            self.img_last_layer, self.deep_param_img, self.train_layers, self.train_last_layer = self.load_model()
-
-            self.global_step = tf.Variable(0, trainable=False)
-            self.train_op = self.apply_loss_function(self.global_step)
-            self.sess.run(tf.global_variables_initializer())
+        self.global_step = tf.Variable(0, trainable=False)
+        self.train_op = self.apply_loss_function(self.global_step)
+        self.sess.run(tf.global_variables_initializer())
         return
 
     def load_model(self):
         if self.img_model == 'alexnet':
             img_output = img_alexnet_layers(
-                self.img, self.batch_size, self.output_dim,
-                self.stage, self.model_weights)
+                self.img,
+                self.batch_size,
+                self.output_dim,
+                self.stage,
+                self.model_weights,
+                self.with_tanh,
+                self.val_batch_size)
         else:
             raise Exception('cannot use such CNN model as ' + self.img_model)
         return img_output
@@ -139,7 +121,7 @@ def reduce_shaper(t):
             r = tf.reshape(r, [-1, 1])
             ip = r - 2*tf.matmul(u, tf.transpose(u)) + tf.transpose(r)
 
-            ip = tf.constant(self.gamma) / (ip + tf.constant(self.gamma)*tf.constant(self.gamma))
+            ip = self.gamma / (ip + self.gamma ** 2)
         else:
             ip = tf.clip_by_value(tf.matmul(u, tf.transpose(u)), -1.5e1, 1.5e1)
             ones = tf.ones([tf.shape(u)[0], tf.shape(u)[0]])
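The rewritten line computes the same Cauchy-style similarity with plain operator overloading: ip holds pairwise squared Euclidean distances (the r - 2*matmul(u, u^T) + r^T expansion above), and gamma / (ip + gamma**2) maps distance 0 to 1/gamma and large distances toward 0. A NumPy sketch of the computation, for illustration only:

    import numpy as np

    u = np.random.randn(4, 8)                # 4 hash outputs of dimension 8
    r = (u ** 2).sum(axis=1).reshape(-1, 1)  # squared row norms
    ip = r - 2 * np.dot(u, u.T) + r.T        # pairwise squared distances
    gamma = 20.0
    sim = gamma / (ip + gamma ** 2)          # distance 0 -> 1/gamma; large -> ~0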
@@ -158,13 +140,12 @@ def apply_loss_function(self, global_step):
         self.cos_loss = self.cross_entropy(self.img_last_layer, self.img_label, self.alpha, True, True, self.bias)
 
         self.q_loss_img = tf.reduce_mean(tf.square(tf.subtract(tf.abs(self.img_last_layer), tf.constant(1.0))))
-        self.q_lambda = tf.Variable(self.cq_lambda, name='cq_lambda')
-        self.q_loss = tf.multiply(self.q_lambda, self.q_loss_img)
+        self.q_loss = self.q_lambda * self.q_loss_img
         self.loss = self.cos_loss + self.q_loss
 
         ### Last layer has a 10 times learning rate
-        self.lr = tf.train.exponential_decay(self.learning_rate, global_step, self.decay_step, self.learning_rate_decay_factor, staircase=True)
-        opt = tf.train.MomentumOptimizer(learning_rate=self.lr, momentum=0.9)
+        lr = tf.train.exponential_decay(self.lr, global_step, self.decay_step, self.lr, staircase=True)
+        opt = tf.train.MomentumOptimizer(learning_rate=lr, momentum=0.9)
         grads_and_vars = opt.compute_gradients(self.loss, self.train_layers+self.train_last_layer)
         fcgrad, _ = grads_and_vars[-2]
         fbgrad, _ = grads_and_vars[-1]
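For reference, tf.train.exponential_decay computes base_lr * decay_rate ** (global_step / decay_steps), and staircase=True floors the exponent so the rate drops in discrete steps; note that the rewritten call passes self.lr as the decay rate too, so the script's --decay-factor appears unused in this call. A small sketch of the schedule using the script's defaults:

    import tensorflow as tf  # TensorFlow 1.x

    step = tf.placeholder(tf.int64, [])
    decayed = tf.train.exponential_decay(0.005, step, 10000, 0.1, staircase=True)

    with tf.Session() as sess:
        for s in (0, 9999, 10000, 20000):
            print(s, sess.run(decayed, {step: s}))  # 0.005, 0.005, 0.0005, 5e-05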
@@ -174,11 +155,11 @@
         tf.summary.scalar('loss', self.loss)
         tf.summary.scalar('cos_loss', self.cos_loss)
         tf.summary.scalar('q_loss', self.q_loss)
-        tf.summary.scalar('lr', self.lr)
+        tf.summary.scalar('lr', lr)
         self.merged = tf.summary.merge_all()
 
 
-        if self.stage == "train" and self.finetune_all:
+        if self.finetune_all:
             return opt.apply_gradients([(grads_and_vars[0][0], self.train_layers[0]),
                                         (grads_and_vars[1][0]*2, self.train_layers[1]),
                                         (grads_and_vars[2][0], self.train_layers[2]),
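The apply_gradients list above is how the "10 times learning rate" comment is realized: scaling a gradient before applying it is equivalent to giving that variable a larger learning rate. A self-contained sketch of the idiom, assuming TensorFlow 1.x:

    import tensorflow as tf

    x = tf.Variable(1.0)
    y = tf.Variable(1.0)
    loss = tf.square(x) + tf.square(y)
    opt = tf.train.MomentumOptimizer(learning_rate=0.1, momentum=0.9)
    gv = opt.compute_gradients(loss, [x, y])
    # Scale y's gradient 10x before applying: y effectively learns 10x faster.
    train_op = opt.apply_gradients([(gv[0][0], x), (gv[1][0] * 10, y)])

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(train_op)
        print(sess.run([x, y]))  # y has moved ten times as far as x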
@@ -208,13 +189,10 @@ def train(self, img_dataset):
             shutil.rmtree(tflog_path)
         train_writer = tf.summary.FileWriter(tflog_path, self.sess.graph)
 
-        for train_iter in range(self.max_iter):
+        for train_iter in range(self.iter_num):
             images, labels = img_dataset.next_batch(self.batch_size)
             start_time = time.time()
 
-            assign_lambda = self.q_lambda.assign(self.cq_lambda)
-            self.sess.run([assign_lambda])
-
             _, loss, cos_loss, output, summary = self.sess.run([self.train_op, self.loss, self.cos_loss, self.img_last_layer, self.merged],
                                                                feed_dict={self.img: images,
                                                                           self.img_label: labels})
@@ -224,7 +202,7 @@ def train(self, img_dataset):
             img_dataset.feed_batch_output(self.batch_size, output)
             duration = time.time() - start_time
 
-            if train_iter % 1 == 0:
+            if train_iter % 100 == 0:
                 print("%s #train# step %4d, loss = %.4f, cross_entropy loss = %.4f, %.1f sec/batch"
                       %(datetime.now(), train_iter+1, loss, cos_loss, duration))
 
@@ -236,24 +214,29 @@
 
     def validation(self, img_query, img_database, R=100):
         print("%s #validation# start validation" % (datetime.now()))
-        query_batch = int(ceil(img_query.n_samples / self.batch_size))
+        query_batch = int(ceil(img_query.n_samples / float(self.val_batch_size)))
+        img_query.finish_epoch()
         print("%s #validation# totally %d query in %d batches" % (datetime.now(), img_query.n_samples, query_batch))
         for i in range(query_batch):
-            images, labels = img_query.next_batch(self.batch_size)
+            images, labels = img_query.next_batch(self.val_batch_size)
             output, loss = self.sess.run([self.img_last_layer, self.cos_loss],
-                                         feed_dict={self.img: images, self.img_label: labels})
-            img_query.feed_batch_output(self.batch_size, output)
+                                         feed_dict={self.img: images,
+                                                    self.img_label: labels,
+                                                    self.stage: 1})
+            img_query.feed_batch_output(self.val_batch_size, output)
             print('Cosine Loss: %s'%loss)
 
-        database_batch = int(ceil(img_database.n_samples / self.batch_size))
+        database_batch = int(ceil(img_database.n_samples / float(self.val_batch_size)))
+        img_database.finish_epoch()
         print("%s #validation# totally %d database in %d batches" % (datetime.now(), img_database.n_samples, database_batch))
         for i in range(database_batch):
-            images, labels = img_database.next_batch(self.batch_size)
+            images, labels = img_database.next_batch(self.val_batch_size)
 
             output, loss = self.sess.run([self.img_last_layer, self.cos_loss],
-                                         feed_dict={self.img: images, self.img_label: labels})
-            img_database.feed_batch_output(self.batch_size, output)
-            #print output[:10, :10]
+                                         feed_dict={self.img: images,
+                                                    self.img_label: labels,
+                                                    self.stage: 1})
+            img_database.feed_batch_output(self.val_batch_size, output)
             if i % 100 == 0:
                 print('Cosine Loss[%d/%d]: %s'%(i, database_batch, loss))
 
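The added float() casts matter if this runs under Python 2, where / between two integers truncates before ceil can round up, silently dropping the final partial batch. A quick illustration:

    from math import ceil

    n_samples, batch = 54001, 16
    print(int(ceil(n_samples / float(batch))))  # 3376: the partial batch is kept
    # Under Python 2, int(ceil(n_samples / batch)) would give 3375 instead.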

@@ -283,15 +266,3 @@ def validation(self, img_query, img_database, R=100):
             'i2i_map_radius_2': mmap,
         }
 
-def train(train_img, config):
-    model = PruneHash(config, 'train')
-    img_dataset = Dataset(train_img, config['output_dim'])
-    model.train(img_dataset)
-    return model.save_file
-
-def validation(database_img, query_img, config):
-    model = PruneHash(config, 'val')
-    img_database = Dataset(database_img, config['output_dim'])
-    img_query = Dataset(query_img, config['output_dim'])
-    return model.validation(img_query, img_database, config['R'])
-
File renamed without changes.

DeepHash/model/prunehash/__init__.py

Whitespace-only changes.

examples/dch/train_val_script.py

+70

@@ -0,0 +1,70 @@
+import os
+import argparse
+import warnings
+import numpy as np
+import scipy.io as sio
+import model.dch as model
+import data_provider.image as dataset
+
+from pprint import pprint
+
+warnings.filterwarnings("ignore", category = DeprecationWarning)
+warnings.filterwarnings("ignore", category = FutureWarning)
+
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
+
+parser = argparse.ArgumentParser(description='Triplet Hashing')
+parser.add_argument('--lr', '--learning-rate', default=0.005, type=float)
+parser.add_argument('--output-dim', default=64, type=int)  # 256, 128
+parser.add_argument('--alpha', default=0.5, type=float)
+parser.add_argument('--bias', default=0.0, type=float)
+parser.add_argument('--gamma', default=20, type=float)
+parser.add_argument('--iter-num', default=2000, type=int)
+parser.add_argument('--q-lambda', default=0, type=float)
+parser.add_argument('--dataset', default='cifar10', type=str)
+parser.add_argument('--gpus', default='0', type=str)
+parser.add_argument('--log-dir', default='tflog', type=str)
+parser.add_argument('-b', '--batch-size', default=128, type=int)
+parser.add_argument('-vb', '--val-batch-size', default=16, type=int)
+parser.add_argument('--decay-step', default=10000, type=int)
+parser.add_argument('--decay-factor', default=0.1, type=int)
+parser.add_argument('--loss-type', default='pruned_cross_entropy', type=str)
+
+tanh_parser = parser.add_mutually_exclusive_group(required=False)
+tanh_parser.add_argument('--with-tanh', dest='with_tanh', action='store_true')
+tanh_parser.add_argument('--without-tanh', dest='with_tanh', action='store_false')
+parser.set_defaults(with_tanh=True)
+
+parser.add_argument('--img-model', default='alexnet', type=str)
+parser.add_argument('--model-weights', type=str,
+                    default='../../DeepHash/architecture/single_model/pretrained_model/reference_pretrain.npy')
+parser.add_argument('--finetune-all', default=True, type=bool)
+parser.add_argument('--save-dir', default="./models/", type=str)
+parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true')
+
+args = parser.parse_args()
+
+os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
+
+label_dims = {'cifar10': 10, 'cub': 200, 'nuswide_81': 81, 'coco': 80}
+Rs = {'cifar10': 54000, 'nuswide_81': 5000, 'coco': 5000}
+args.R = Rs[args.dataset]
+args.label_dim = label_dims[args.dataset]
+args.img_tr = "/home/caoyue/data/{}/train.txt".format(args.dataset)
+args.img_te = "/home/caoyue/data/{}/test.txt".format(args.dataset)
+args.img_db = "/home/caoyue/data/{}/database.txt".format(args.dataset)
+
+pprint(vars(args))
+
+query_img, database_img = dataset.import_validation(args.img_te, args.img_db)
+
+if not args.evaluate:
+    train_img = dataset.import_train(args.img_tr)
+    model_weights = model.train(train_img, database_img, query_img, args)
+    args.model_weights = model_weights
+
+maps = model.validation(database_img, query_img, args)
+for key in maps:
+    print(("{}\t{}".format(key, maps[key])))
+
+pprint(vars(args))
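One idiom in the new script worth noting: the mutually exclusive --with-tanh / --without-tanh pair maps two flags onto a single boolean that defaults to True. A self-contained sketch of the same construction:

    import argparse

    parser = argparse.ArgumentParser()
    group = parser.add_mutually_exclusive_group(required=False)
    group.add_argument('--with-tanh', dest='with_tanh', action='store_true')
    group.add_argument('--without-tanh', dest='with_tanh', action='store_false')
    parser.set_defaults(with_tanh=True)

    print(parser.parse_args([]).with_tanh)                  # True
    print(parser.parse_args(['--without-tanh']).with_tanh)  # False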

examples/dtq/train_val_script.py

+1 −1

@@ -37,7 +37,7 @@
 
 parser.add_argument('--img-model', default='alexnet', type=str)
 parser.add_argument('--model-weights', type=str,
-                    default='../../core/architecture/single_model/pretrained_model/reference_pretrain.npy')
+                    default='../../DeepHash/architecture/single_model/pretrained_model/reference_pretrain.npy')
 parser.add_argument('--finetune-all', default=True, type=bool)
 parser.add_argument('--max-iter-update-b', default=3, type=int)
 parser.add_argument('--max-iter-update-Cb', default=1, type=int)

examples/prunehash/train_val.sh

−21
This file was deleted.
