# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""TODO: Edit."""
from __future__ import absolute_import
from __future__ import division
from absl import app
from absl import flags
from absl import logging
import tensorflow.compat.v1 as tf
import representation_lib
FLAGS = flags.FLAGS
flags.DEFINE_string('workdir', 'work-dir/', 'Where to store files.')
flags.DEFINE_integer('num_gpus', 1, 'Number of GPUs to use.')
flags.DEFINE_string('tpu_name', None, 'Name of the TPU node to use.')
flags.DEFINE_bool('run_eval', False, 'Whether to run in eval mode.')
flags.DEFINE_string('checkpoint', None, 'Optional path to a specific '
                    'checkpoint to evaluate.')
flags.DEFINE_string('path_to_initial_ckpt', None, 'Optional checkpoint from '
                    'which to initialize the model before training.')
flags.DEFINE_string('vars_to_restore', '.*', 'Regex matching the variables to '
                    'restore from `path_to_initial_ckpt`.')
flags.DEFINE_integer('eval_timeout_mins', 20, 'How many minutes to wait for '
                     'new checkpoints before eval terminates.')
flags.DEFINE_bool('use_summaries', True, 'Whether to write TensorBoard '
                  'summaries.')
flags.DEFINE_string('dataset', 'imagenet', 'Which dataset to use: `imagenet`')
flags.DEFINE_string('dataset_dir', None, 'Location of the dataset files.')
flags.mark_flag_as_required('dataset_dir')
flags.DEFINE_string('pseudo_label_key', None, 'Key used to retrieve pseudo '
                    'labels.')
flags.DEFINE_string('original_label_key', None, 'Key used to retrieve '
                    'original labels if available.')
flags.DEFINE_bool('cache_dataset', False, 'Whether to cache the dataset after '
                  'filtering. When using 10% of ImageNet, caching somehow '
                  'makes Cloud TPU crash (OOM), so it has to stay disabled. '
                  'But when running locally, or on TPU with the 1% subset, '
                  'caching is necessary for decent speed: it gives a 10x '
                  'speed-up in the 1% TPU setting.')
flags.DEFINE_string('filename_list_template', None, 'Template for the files '
                    'listing which examples to keep when training on a subset '
                    'of the data.')
flags.DEFINE_integer('num_supervised_examples', None, 'Number of labeled '
                     'examples to use for the supervised losses.')
flags.DEFINE_float('unsup_batch_mult', None, 'Multiplier applied to '
                   '`batch_size` to obtain the unsupervised batch size.')
flags.DEFINE_integer('enable_sup_data', 1, 'Use supervised data.')
flags.DEFINE_integer('rot_loss_sup', 1, 'Enable rotation loss for sup. data.')
flags.DEFINE_integer('rot_loss_unsup', 1, 'Enable rotation loss for unsup. data.')
flags.DEFINE_float('rot_loss_weight', 1.0, 'Weight of the rotation loss.')
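# Note: the rotation losses refer to the self-supervised rotation task
# (Gidaris et al., 2018): each image is rotated by 0, 90, 180 or 270 degrees
# and the network predicts which rotation was applied (4-way classification).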
flags.DEFINE_integer('triplet_loss_sup', 1, 'Enable triplet loss for sup. data.')
flags.DEFINE_integer('triplet_loss_unsup', 1, 'Enable triplet loss for unsup. data.')
flags.DEFINE_float('triplet_loss_weight', 1.0, 'Weight of the triplet loss.')
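# Note: the standard triplet formulation pulls an anchor embedding closer to
# a positive than to a negative by some margin,
#   loss = max(0, d(anchor, positive) - d(anchor, negative) + margin),
# though the exact variant used here (distance, margin, mining strategy) is
# defined by the training library, not by this file.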
flags.DEFINE_integer('save_checkpoints_steps', 1000, 'How often, in steps, to '
                     'save a checkpoint. Defaults to 1000.')
flags.DEFINE_string('serving_input_key', 'image', 'The name of the input '
                    'tensor in the generated hub module. Just leave it at the '
                    'default.')
flags.DEFINE_string('serving_input_shape', 'None,None,None,3', 'The shape of '
                    'the input tensor in the stored hub module.')
flags.DEFINE_integer('random_seed', None, 'Seed to use. None is random.')
flags.DEFINE_string('task', None, 'Which pretext-task to learn from. Can be '
'one of `rotation`, `exemplar`, `jigsaw`, '
'`relative_patch_location`, `linear_eval`, `supervised`.')
flags.mark_flag_as_required('task')
flags.DEFINE_string('train_split', 'train', 'Which dataset split to train on. '
'Should only be `train` (default) or `trainval`.')
flags.DEFINE_string('val_split', 'val', 'Which dataset split to eval on. '
'Should only be `val` (default) or `test`.')
flags.DEFINE_string('test_split', None, 'Optionally evaluate the last '
'checkpoint on this split.')
flags.DEFINE_integer('batch_size', None, 'The global batch-size to use.')
flags.mark_flag_as_required('batch_size')
flags.DEFINE_integer('eval_batch_size', None, 'Optional batch-size to use for '
                     'evaluation; defaults to the same as `batch_size`.')
flags.DEFINE_string('preprocessing', None, 'A comma-separated list of '
'pre-processing steps to perform on unlabeled data, '
'see preprocess.py.')
flags.mark_flag_as_required('preprocessing')
flags.DEFINE_string('sup_preprocessing', None, 'A comma-separated list of '
'pre-processing steps to perform on labeled data, '
'see preprocess.py.')
flags.mark_flag_as_required('sup_preprocessing')
flags.DEFINE_string('preprocessing_eval', None, 'Optionally, a different pre-'
                    'processing for the unlabeled data during evaluation.')
flags.DEFINE_string('sup_preprocessing_eval', None, 'Optionally, a different '
                    'pre-processing for the labeled data during evaluation.')
flags.DEFINE_string('schedule', None, 'Learning-rate decay schedule.')
flags.mark_flag_as_required('schedule')
flags.DEFINE_string('architecture', None, 'Which basic network architecture '
                    'to use. One of vgg19, resnet50, revnet50.')
flags.DEFINE_integer('filters_factor', 4, 'Widening factor for network '
'filters. For ResNet, default = 4 = vanilla ResNet.')
flags.DEFINE_float('weight_decay', 1e-4, 'Strength of weight-decay. '
'Defaults to 1e-4, and may be set to 0.')
flags.DEFINE_bool('polyak_averaging', False, 'If true, use polyak averaging.')
flags.DEFINE_float('lr', None, 'The base learning-rate to use for training.')
flags.mark_flag_as_required('lr')
flags.DEFINE_float('lr_scale_batch_size', None, 'The batch-size for which the '
                   'base learning-rate `lr` is defined. For batch-sizes '
                   'different from that, it is scaled linearly accordingly. '
                   'For example lr=0.1, batch_size=128, lr_scale_batch_size=32 '
                   'gives an actual lr of 0.4.')
flags.mark_flag_as_required('lr_scale_batch_size')
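

# A minimal sketch of the linear scaling described by the two flags above.
# This helper is illustrative only (the authoritative computation happens in
# the training library) and is not called anywhere in this file.
def _linearly_scaled_lr(base_lr, batch_size, lr_scale_batch_size):
  """Returns the effective learning rate under the linear scaling rule."""
  # E.g. base_lr=0.1 defined at batch-size 32, training at batch-size 128,
  # gives an effective lr of 0.1 * 128 / 32 = 0.4.
  return base_lr * batch_size / lr_scale_batch_size
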
flags.DEFINE_float('lr_decay_factor', 0.1, 'Factor by which to decay the '
'learning-rate at each decay step. Default 0.1.')
flags.DEFINE_string('optimizer', 'sgd', 'Which optimizer to use. '
'Only `sgd` (default) or `adam` are supported.')
flags.DEFINE_integer('triplet_embed_dim', 1000, 'Size of the embedding for '
                     'the triplet loss.')
flags.DEFINE_float('sup_weight', 0.3, 'Only for MOAM: weight of the '
                   'supervised loss.')
# Flags for VAT (virtual adversarial training):
flags.DEFINE_float('vat_weight', 0.0, 'Weight multiplier for the VAT loss.')
flags.DEFINE_float('entmin_factor', 0.0, 'Weight multiplier for the entropy-'
                   'minimization (EntMin) loss.')
flags.DEFINE_float('vat_eps', 1.0, 'Epsilon for the finite-difference '
                   'approximation used in VAT.')
flags.DEFINE_integer('vat_num_power_method_iters', 1, 'Number of power-method '
                     'iterations used in VAT to approximate the dominant '
                     'eigenvector, i.e. the most sensitive direction.')
flags.DEFINE_boolean('apply_vat_to_labeled', False, 'Whether to also apply '
                     'the VAT loss to labeled examples.')
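

# A compact sketch of how the VAT flags above are typically used (Miyato et
# al., 2018): power iteration approximates the most sensitive perturbation
# direction, which is then scaled to norm `vat_eps`. This is an assumption
# about what the training library implements; it is illustrative only,
# normalizes over all axes for brevity, and is not called in this file.
def _vat_perturbation_sketch(logits_fn, x, eps, num_iters, xi=1e-6):
  """Approximates the adversarial perturbation direction used by VAT."""
  p = tf.stop_gradient(tf.nn.softmax(logits_fn(x)))  # Current predictions.
  d = tf.random.normal(tf.shape(x))  # Random initial direction.
  for _ in range(num_iters):
    d = xi * tf.math.l2_normalize(d)
    # Cross-entropy H(p, q); since p is fixed, this has the same gradient as
    # KL(p || q), whose most-increasing direction VAT is after.
    xent = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=p, logits=logits_fn(x + d)))
    d = tf.stop_gradient(tf.gradients(xent, [d])[0])
  return eps * tf.math.l2_normalize(d)
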
def main(unused_argv):
  del unused_argv  # Unused.

  logging.info('workdir: %s', FLAGS.workdir)

  # Make sure the work directory exists before training starts.
  if not FLAGS.run_eval and not tf.gfile.IsDirectory(FLAGS.workdir):
    tf.gfile.MakeDirs(FLAGS.workdir)

  results = representation_lib.train_and_eval()
  if results:
    logging.info('Result: %s', results)

  logging.info('I\'m done with my work, ciao!')


if __name__ == '__main__':
  app.run(main)
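
# Example invocation covering all required flags (illustrative placeholder
# values in angle brackets; the exact formats of --schedule and the
# --*preprocessing flags are defined by the training library and
# preprocess.py):
#
#   python main.py \
#     --task=rotation \
#     --dataset_dir=<path/to/dataset> \
#     --batch_size=256 \
#     --lr=0.1 \
#     --lr_scale_batch_size=256 \
#     --schedule=<decay-schedule> \
#     --preprocessing=<unsup-preprocessing-steps> \
#     --sup_preprocessing=<sup-preprocessing-steps>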