Skip to content

Commit f4d05e1

Browse files
Bycobmergify[bot]
authored andcommittedApr 9, 2021
feat(torch): add more backbones to traced detection models
1 parent 2e62b7e commit f4d05e1

File tree

1 file changed

+41
-16
lines changed

1 file changed

+41
-16
lines changed
 

‎tools/torch/trace_torchvision.py

+41-16
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838

3939
parser = argparse.ArgumentParser(description="Trace image processing models from torchvision")
4040
parser.add_argument('models', type=str, nargs='*', help="Models to trace.")
41+
parser.add_argument('--backbone', type=str, help="Backbone for detection models")
4142
parser.add_argument('--print-models', action='store_true', help="Print all the available models names and exit")
4243
parser.add_argument('--to-dd-native', action='store_true', help="Prepare the model so that the weights can be loaded on native model with dede")
4344
parser.add_argument('-a', "--all", action='store_true', help="Export all available models")
@@ -139,7 +140,12 @@ def get_detection_input():
139140
"resnext101_32x8d": M.resnext101_32x8d,
140141
}
141142
detection_model_classes = {
143+
"fasterrcnn": M.detection.FasterRCNN,
142144
"fasterrcnn_resnet50_fpn": M.detection.fasterrcnn_resnet50_fpn,
145+
"fasterrcnn_mobilenet_v3_large_fpn": M.detection.fasterrcnn_mobilenet_v3_large_fpn,
146+
"fasterrcnn_mobilenet_v3_large_320_fpn": M.detection.fasterrcnn_mobilenet_v3_large_320_fpn,
147+
148+
"retinanet": M.detection.RetinaNet,
143149
"retinanet_resnet50_fpn": M.detection.retinanet_resnet50_fpn,
144150
}
145151
model_classes.update(detection_model_classes)
@@ -170,22 +176,41 @@ def get_detection_input():
170176
detection = mname in detection_model_classes
171177

172178
if detection:
173-
model = model_classes[mname](pretrained=args.pretrained, progress=args.verbose)
174-
175-
if args.num_classes:
176-
logging.info("Using num_classes = %d" % args.num_classes)
177-
178-
if "fasterrcnn" in mname:
179-
# get number of input features for the classifier
180-
in_features = model.roi_heads.box_predictor.cls_score.in_features
181-
# replace the pre-trained head with a new one
182-
model.roi_heads.box_predictor = M.detection.faster_rcnn.FastRCNNPredictor(in_features, args.num_classes)
183-
elif "retinanet" in mname:
184-
in_channels = model.backbone.out_channels
185-
num_anchors = model.head.classification_head.num_anchors
186-
# replace pretrained head - does not work
187-
# model.head = M.detection.retinanet.RetinaNetHead(in_channels, num_anchors, args.num_classes)
188-
raise Exception("Retinanet with fixed number of classes is not yet supported")
179+
if mname in ["fasterrcnn", "retinanet"]:
180+
if args.backbone and args.backbone in model_classes:
181+
if "resnet" in args.backbone or "resnext" in args.backbone:
182+
backbone = M.detection.backbone_utils.resnet_fpn_backbone(args.backbone, pretrained = args.pretrained)
183+
elif "mobilenet" in args.backbone:
184+
backbone = M.detection.backbone_utils.mobilenet_backbone(args.backbone, pretrained = args.pretrained, fpn = True)
185+
else:
186+
raise RuntimeError("Backbone not supported: %s. Supported backbones are resnet, resnext or mobilenet." % args.backbone)
187+
else:
188+
raise RuntimeError("Please specify a backbone for model %s" % mname)
189+
190+
if args.pretrained:
191+
logging.warn("Pretrained models are not available for custom backbones. " +
192+
"Output model (except the backbone) will be untrained.")
193+
194+
model = model_classes[mname](backbone, args.num_classes)
195+
else:
196+
if args.backbone:
197+
raise RuntimeError("--backbone is only supported with models \"fasterrcnn\" or \"retinanet\".")
198+
model = model_classes[mname](pretrained=args.pretrained, progress=args.verbose)
199+
200+
if args.num_classes:
201+
logging.info("Using num_classes = %d" % args.num_classes)
202+
203+
if "fasterrcnn" in mname:
204+
# get number of input features for the classifier
205+
in_features = model.roi_heads.box_predictor.cls_score.in_features
206+
# replace the pre-trained head with a new one
207+
model.roi_heads.box_predictor = M.detection.faster_rcnn.FastRCNNPredictor(in_features, args.num_classes)
208+
elif "retinanet" in mname:
209+
in_channels = model.backbone.out_channels
210+
num_anchors = model.head.classification_head.num_anchors
211+
# replace pretrained head - does not work
212+
# model.head = M.detection.retinanet.RetinaNetHead(in_channels, num_anchors, args.num_classes)
213+
raise Exception("Retinanet with fixed number of classes is not yet supported")
189214

190215
model.eval()
191216
detect_model = DetectionModel(model)

0 commit comments

Comments
 (0)