diff --git a/ignite/metrics/vision/object_detection_average_precision_recall.py b/ignite/metrics/vision/object_detection_average_precision_recall.py index fc38cfd8d5f..47a06dca09f 100644 --- a/ignite/metrics/vision/object_detection_average_precision_recall.py +++ b/ignite/metrics/vision/object_detection_average_precision_recall.py @@ -345,7 +345,10 @@ def update(self, output: Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, tor ) ) - self._scores.append(pred["scores"][max_best_detections_index].to(self._device)) + scores = pred["scores"][max_best_detections_index] + if self._device == torch.device("mps") and scores.dtype == torch.double: + scores = scores.to(dtype=torch.float32) + self._scores.append(scores.to(self._device)) self._y_pred_labels.append(pred_labels.to(device=self._device)) @sync_all_reduce("_y_true_count")