@@ -163,6 +163,13 @@ def build_maskrcnn(input_specs: tf.keras.layers.InputSpec,
163
163
mask_head = mask_head ,
164
164
mask_sampler = mask_sampler_obj ,
165
165
mask_roi_aligner = mask_roi_aligner_obj ,
166
+ class_agnostic_bbox_pred = detection_head_config .class_agnostic_bbox_pred ,
167
+ cascade_class_ensemble = detection_head_config .cascade_class_ensemble ,
168
+ min_level = model_config .min_level ,
169
+ max_level = model_config .max_level ,
170
+ num_scales = model_config .anchor .num_scales ,
171
+ aspect_ratios = model_config .anchor .aspect_ratios ,
172
+ anchor_size = model_config .anchor .anchor_size ,
166
173
outer_boxes_scale = model_config .outer_boxes_scale ,
167
174
use_gt_boxes_for_masks = model_config .use_gt_boxes_for_masks )
168
175
return model
@@ -193,4 +200,9 @@ def build_model(self):
193
200
if self .task_config .freeze_backbone :
194
201
model .backbone .trainable = False
195
202
203
+ # Builds the model through warm-up call.
204
+ dummy_images = tf .keras .Input (self .task_config .model .input_size )
205
+ dummy_image_shape = tf .keras .layers .Input ([2 ])
206
+ _ = model (dummy_images , image_shape = dummy_image_shape , training = False )
207
+
196
208
return model
0 commit comments