From 0444933caad1014495a182fcd3cd251eeff03e11 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Wed, 4 Sep 2024 13:15:36 +0330 Subject: [PATCH] Resolve MPS's lack of cummax --- .../vision/object_detection_average_precision_recall.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ignite/metrics/vision/object_detection_average_precision_recall.py b/ignite/metrics/vision/object_detection_average_precision_recall.py index a56a18c164f..3a7ae52f6d8 100644 --- a/ignite/metrics/vision/object_detection_average_precision_recall.py +++ b/ignite/metrics/vision/object_detection_average_precision_recall.py @@ -1,3 +1,4 @@ +import os from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Union import torch @@ -125,7 +126,7 @@ def box_iou(pred_boxes: torch.Tensor, gt_boxes: torch.Tensor, iscrowd: torch.Boo class_mean=None, ) precision = torch.double if torch.device(device).type != "mps" else torch.float32 - self.rec_thresholds = self.rec_thresholds.to(device=device, dtype=precision) + self.rec_thresholds = cast(torch.Tensor, self.rec_thresholds).to(device=device, dtype=precision) @reinit__is_reduced def reset(self) -> None: @@ -234,7 +235,10 @@ def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tens Returns: average_precision: (n-1)-dimensional tensor containing the average precision for mean dimensions. """ + mps_cpu_fallback = os.environ.get("PYTORCH_ENABLE_MPS_FALLBACK", "0") + os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" precision_integrand = precision.flip(-1).cummax(dim=-1).values.flip(-1) + os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = mps_cpu_fallback rec_thresholds = cast(torch.Tensor, self.rec_thresholds).repeat((*recall.shape[:-1], 1)) rec_thresh_indices = ( torch.searchsorted(recall, rec_thresholds)