Skip to content

Commit

Permalink
Resolve MPS's lack of cummax
Browse files Browse the repository at this point in the history
  • Loading branch information
sadra-barikbin committed Sep 4, 2024
1 parent 3658f95 commit 9a45edc
Showing 1 changed file with 4 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Union

import torch
Expand Down Expand Up @@ -234,7 +235,10 @@ def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tens
Returns:
average_precision: (n-1)-dimensional tensor containing the average precision for mean dimensions.
"""
mps_cpu_fallback = os.environ.get("PYTORCH_ENABLE_MPS_FALLBACK", False)
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = True
precision_integrand = precision.flip(-1).cummax(dim=-1).values.flip(-1)
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = mps_cpu_fallback
rec_thresholds = cast(torch.Tensor, self.rec_thresholds).repeat((*recall.shape[:-1], 1))
rec_thresh_indices = (
torch.searchsorted(recall, rec_thresholds)
Expand Down

0 comments on commit 9a45edc

Please # to comment.