|
38 | 38 |
|
39 | 39 | parser = argparse.ArgumentParser(description="Trace image processing models from torchvision")
|
40 | 40 | parser.add_argument('models', type=str, nargs='*', help="Models to trace.")
|
| 41 | +parser.add_argument('--backbone', type=str, help="Backbone for detection models") |
41 | 42 | parser.add_argument('--print-models', action='store_true', help="Print all the available models names and exit")
|
42 | 43 | parser.add_argument('--to-dd-native', action='store_true', help="Prepare the model so that the weights can be loaded on native model with dede")
|
43 | 44 | parser.add_argument('-a', "--all", action='store_true', help="Export all available models")
|
@@ -139,7 +140,12 @@ def get_detection_input():
|
139 | 140 | "resnext101_32x8d": M.resnext101_32x8d,
|
140 | 141 | }
|
141 | 142 | detection_model_classes = {
|
| 143 | + "fasterrcnn": M.detection.FasterRCNN, |
142 | 144 | "fasterrcnn_resnet50_fpn": M.detection.fasterrcnn_resnet50_fpn,
|
| 145 | + "fasterrcnn_mobilenet_v3_large_fpn": M.detection.fasterrcnn_mobilenet_v3_large_fpn, |
| 146 | + "fasterrcnn_mobilenet_v3_large_320_fpn": M.detection.fasterrcnn_mobilenet_v3_large_320_fpn, |
| 147 | + |
| 148 | + "retinanet": M.detection.RetinaNet, |
143 | 149 | "retinanet_resnet50_fpn": M.detection.retinanet_resnet50_fpn,
|
144 | 150 | }
|
145 | 151 | model_classes.update(detection_model_classes)
|
@@ -170,22 +176,41 @@ def get_detection_input():
|
170 | 176 | detection = mname in detection_model_classes
|
171 | 177 |
|
172 | 178 | if detection:
|
173 |
| - model = model_classes[mname](pretrained=args.pretrained, progress=args.verbose) |
174 |
| - |
175 |
| - if args.num_classes: |
176 |
| - logging.info("Using num_classes = %d" % args.num_classes) |
177 |
| - |
178 |
| - if "fasterrcnn" in mname: |
179 |
| - # get number of input features for the classifier |
180 |
| - in_features = model.roi_heads.box_predictor.cls_score.in_features |
181 |
| - # replace the pre-trained head with a new one |
182 |
| - model.roi_heads.box_predictor = M.detection.faster_rcnn.FastRCNNPredictor(in_features, args.num_classes) |
183 |
| - elif "retinanet" in mname: |
184 |
| - in_channels = model.backbone.out_channels |
185 |
| - num_anchors = model.head.classification_head.num_anchors |
186 |
| - # replace pretrained head - does not work |
187 |
| - # model.head = M.detection.retinanet.RetinaNetHead(in_channels, num_anchors, args.num_classes) |
188 |
| - raise Exception("Retinanet with fixed number of classes is not yet supported") |
| 179 | + if mname in ["fasterrcnn", "retinanet"]: |
| 180 | + if args.backbone and args.backbone in model_classes: |
| 181 | + if "resnet" in args.backbone or "resnext" in args.backbone: |
| 182 | + backbone = M.detection.backbone_utils.resnet_fpn_backbone(args.backbone, pretrained = args.pretrained) |
| 183 | + elif "mobilenet" in args.backbone: |
| 184 | + backbone = M.detection.backbone_utils.mobilenet_backbone(args.backbone, pretrained = args.pretrained, fpn = True) |
| 185 | + else: |
| 186 | + raise RuntimeError("Backbone not supported: %s. Supported backbones are resnet, resnext or mobilenet." % args.backbone) |
| 187 | + else: |
| 188 | + raise RuntimeError("Please specify a backbone for model %s" % mname) |
| 189 | + |
| 190 | + if args.pretrained: |
| 191 | + logging.warn("Pretrained models are not available for custom backbones. " + |
| 192 | + "Output model (except the backbone) will be untrained.") |
| 193 | + |
| 194 | + model = model_classes[mname](backbone, args.num_classes) |
| 195 | + else: |
| 196 | + if args.backbone: |
| 197 | + raise RuntimeError("--backbone is only supported with models \"fasterrcnn\" or \"retinanet\".") |
| 198 | + model = model_classes[mname](pretrained=args.pretrained, progress=args.verbose) |
| 199 | + |
| 200 | + if args.num_classes: |
| 201 | + logging.info("Using num_classes = %d" % args.num_classes) |
| 202 | + |
| 203 | + if "fasterrcnn" in mname: |
| 204 | + # get number of input features for the classifier |
| 205 | + in_features = model.roi_heads.box_predictor.cls_score.in_features |
| 206 | + # replace the pre-trained head with a new one |
| 207 | + model.roi_heads.box_predictor = M.detection.faster_rcnn.FastRCNNPredictor(in_features, args.num_classes) |
| 208 | + elif "retinanet" in mname: |
| 209 | + in_channels = model.backbone.out_channels |
| 210 | + num_anchors = model.head.classification_head.num_anchors |
| 211 | + # replace pretrained head - does not work |
| 212 | + # model.head = M.detection.retinanet.RetinaNetHead(in_channels, num_anchors, args.num_classes) |
| 213 | + raise Exception("Retinanet with fixed number of classes is not yet supported") |
189 | 214 |
|
190 | 215 | model.eval()
|
191 | 216 | detect_model = DetectionModel(model)
|
|
0 commit comments