From 4b6afdd4fa3b8b74dbc5156ac7a3978ee6f099d1 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Wed, 4 Sep 2024 04:22:02 +0330 Subject: [PATCH] Fix a bug related to MPS and torch.double --- .../vision/object_detection_average_precision_recall.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ignite/metrics/vision/object_detection_average_precision_recall.py b/ignite/metrics/vision/object_detection_average_precision_recall.py index fc38cfd8d5f..47a06dca09f 100644 --- a/ignite/metrics/vision/object_detection_average_precision_recall.py +++ b/ignite/metrics/vision/object_detection_average_precision_recall.py @@ -345,7 +345,10 @@ def update(self, output: Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, tor ) ) - self._scores.append(pred["scores"][max_best_detections_index].to(self._device)) + scores = pred["scores"][max_best_detections_index] + if self._device == torch.device("mps") and scores.dtype == torch.double: + scores = scores.to(dtype=torch.float32) + self._scores.append(scores.to(self._device)) self._y_pred_labels.append(pred_labels.to(device=self._device)) @sync_all_reduce("_y_true_count")