Skip to content

Commit

Permalink
Fix a bug related to MPS and torch.double
Browse files Browse the repository at this point in the history
  • Loading branch information
sadra-barikbin committed Sep 4, 2024
1 parent 4038c2b commit 4b6afdd
Showing 1 changed file with 4 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,10 @@ def update(self, output: Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, tor
)
)

self._scores.append(pred["scores"][max_best_detections_index].to(self._device))
scores = pred["scores"][max_best_detections_index]
if self._device == torch.device("mps") and scores.dtype == torch.double:
scores = scores.to(dtype=torch.float32)
self._scores.append(scores.to(self._device))
self._y_pred_labels.append(pred_labels.to(device=self._device))

@sync_all_reduce("_y_true_count")
Expand Down

0 comments on commit 4b6afdd

Please # to comment.