Skip to content

Commit 115409f

Browse files
Zarjagentensorflower-gardener
authored andcommitted
Internal change
PiperOrigin-RevId: 520733955
1 parent 9d3aaa0 commit 115409f

File tree

6 files changed

+35
-6
lines changed

6 files changed

+35
-6
lines changed

official/projects/deepmac_maskrcnn/tasks/deep_mask_head_rcnn.py

+12
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,13 @@ def build_maskrcnn(input_specs: tf.keras.layers.InputSpec,
163163
mask_head=mask_head,
164164
mask_sampler=mask_sampler_obj,
165165
mask_roi_aligner=mask_roi_aligner_obj,
166+
class_agnostic_bbox_pred=detection_head_config.class_agnostic_bbox_pred,
167+
cascade_class_ensemble=detection_head_config.cascade_class_ensemble,
168+
min_level=model_config.min_level,
169+
max_level=model_config.max_level,
170+
num_scales=model_config.anchor.num_scales,
171+
aspect_ratios=model_config.anchor.aspect_ratios,
172+
anchor_size=model_config.anchor.anchor_size,
166173
outer_boxes_scale=model_config.outer_boxes_scale,
167174
use_gt_boxes_for_masks=model_config.use_gt_boxes_for_masks)
168175
return model
@@ -193,4 +200,9 @@ def build_model(self):
193200
if self.task_config.freeze_backbone:
194201
model.backbone.trainable = False
195202

203+
# Builds the model through warm-up call.
204+
dummy_images = tf.keras.Input(self.task_config.model.input_size)
205+
dummy_image_shape = tf.keras.layers.Input([2])
206+
_ = model(dummy_images, image_shape=dummy_image_shape, training=False)
207+
196208
return model

official/projects/panoptic/configs/panoptic_deeplab.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,8 @@ class PanopticDeeplab(hyperparams.Config):
129129
norm_activation: common.NormActivation = common.NormActivation()
130130
backbone: backbones.Backbone = backbones.Backbone(
131131
type='resnet', resnet=backbones.ResNet())
132-
decoder: decoders.Decoder = decoders.Decoder(type='aspp')
132+
decoder: decoders.Decoder = decoders.Decoder(
133+
type='aspp', aspp=decoders.ASPP(level=3))
133134
semantic_head: SemanticHead = SemanticHead()
134135
instance_head: InstanceHead = InstanceHead()
135136
shared_decoder: bool = False

official/projects/panoptic/configs/panoptic_maskrcnn.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
_COCO_VAL_EXAMPLES = 5000
3939

4040
# pytype: disable=wrong-keyword-args
41+
# pylint: disable=unexpected-keyword-arg
4142

4243

4344
@dataclasses.dataclass
@@ -108,10 +109,9 @@ class Backbone(backbones.Backbone):
108109
@dataclasses.dataclass
109110
class PanopticMaskRCNN(deepmac_maskrcnn.DeepMaskHeadRCNN):
110111
"""Panoptic Mask R-CNN model config."""
111-
backbone: Backbone = Backbone()
112-
segmentation_model: semantic_segmentation.SemanticSegmentationModel = (
113-
SEGMENTATION_MODEL(num_classes=2))
114-
include_mask = True
112+
backbone: Backbone = Backbone(type='resnet', resnet=backbones.ResNet())
113+
segmentation_model: SEGMENTATION_MODEL = SEGMENTATION_MODEL(num_classes=2)
114+
include_mask: bool = True
115115
shared_backbone: bool = True
116116
shared_decoder: bool = True
117117
stuff_classes_offset: int = 0

official/projects/panoptic/tasks/panoptic_deeplab.py

+6
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@ def build_model(self):
5050
input_specs=input_specs,
5151
model_config=self.task_config.model,
5252
l2_regularizer=l2_regularizer)
53+
54+
# Builds the model through warm-up call.
55+
dummy_images = tf.keras.Input(self.task_config.model.input_size)
56+
# Note that image_info is always in the shape of [4, 2].
57+
dummy_image_info = tf.keras.layers.Input([4, 2])
58+
_ = model(dummy_images, dummy_image_info, training=False)
5359
return model
5460

5561
def initialize(self, model: tf.keras.Model):

official/projects/panoptic/tasks/panoptic_maskrcnn.py

+6
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,12 @@ def build_model(self) -> tf.keras.Model:
7171
if self.task_config.freeze_backbone:
7272
model.backbone.trainable = False
7373

74+
# Builds the model through warm-up call.
75+
dummy_images = tf.keras.Input(self.task_config.model.input_size)
76+
# Note that image_info is always in the shape of [4, 2].
77+
dummy_image_info = tf.keras.layers.Input([4, 2])
78+
_ = model(dummy_images, image_info=dummy_image_info, training=False)
79+
7480
return model
7581

7682
def initialize(self, model: tf.keras.Model) -> None:

official/vision/tasks/maskrcnn.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,11 @@ def build_model(self):
9191
if self.task_config.freeze_backbone:
9292
model.backbone.trainable = False
9393

94+
# Builds the model through warm-up call.
95+
dummy_images = tf.keras.Input(self.task_config.model.input_size)
96+
dummy_image_shape = tf.keras.layers.Input([2])
97+
_ = model(dummy_images, image_shape=dummy_image_shape, training=False)
98+
9499
return model
95100

96101
def initialize(self, model: tf.keras.Model):
@@ -487,7 +492,6 @@ def validation_step(self,
487492
A dictionary of logs.
488493
"""
489494
images, labels = inputs
490-
491495
outputs = model(
492496
images,
493497
anchor_boxes=labels['anchor_boxes'],

0 commit comments

Comments
 (0)