Skip to content

Commit 351d6c6

Browse files
Bycobmergify[bot]
authored andcommitted
fix(torch): retinanet now trains correctly
Was actually fixed in #1265, removed the exception in the export script
1 parent 91bde66 commit 351d6c6

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

tools/torch/trace_torchvision.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -226,9 +226,8 @@ def get_detection_input():
226226
elif "retinanet" in mname:
227227
in_channels = model.backbone.out_channels
228228
num_anchors = model.head.classification_head.num_anchors
229-
# replace pretrained head - does not work
230-
# model.head = M.detection.retinanet.RetinaNetHead(in_channels, num_anchors, args.num_classes)
231-
raise Exception("Retinanet with fixed number of classes is not yet supported")
229+
# replace pretrained head
230+
model.head = M.detection.retinanet.RetinaNetHead(in_channels, num_anchors, args.num_classes)
232231

233232
detect_model = DetectionModel(model)
234233
detect_model.train()

0 commit comments

Comments
 (0)