Skip to content

Commit cbbbd99

Browse files
Bycobmergify[bot]
authored andcommitted
fix(torch): fix faster rcnn model export for training
1 parent 7d292db commit cbbbd99

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

tools/torch/trace_torchvision.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def forward(self, x, bboxes = None, labels = None):
9191
# Sum of all losses for finetuning (as done in vision/references/detection/engine.py)
9292
losses = [l for l in losses.values()]
9393
loss = torch.zeros((1,), device=x.device, dtype=x.dtype)
94-
for i in range(1, len(losses)):
94+
for i in range(len(losses)):
9595
loss += losses[i]
9696
else:
9797
losses, predictions = self.model(l_x)
@@ -212,7 +212,6 @@ def get_detection_input():
212212
# model.head = M.detection.retinanet.RetinaNetHead(in_channels, num_anchors, args.num_classes)
213213
raise Exception("Retinanet with fixed number of classes is not yet supported")
214214

215-
model.eval()
216215
detect_model = DetectionModel(model)
217216
detect_model.train()
218217
script_module = torch.jit.script(detect_model)

0 commit comments

Comments
 (0)