diff --git a/torch_uncertainty/metrics/classification/fpr95.py b/torch_uncertainty/metrics/classification/fpr95.py index 23dabd47..eb6bf66b 100644 --- a/torch_uncertainty/metrics/classification/fpr95.py +++ b/torch_uncertainty/metrics/classification/fpr95.py @@ -104,10 +104,10 @@ def compute(self) -> Tensor: false_pos = torch.cat( [ false_pos[: last_ind + 1].flip(0), - torch.tensor([0.0], dtype=self.dtype, device=self.device), + torch.tensor([0.0], device=self.device), ] ) - cutoff = torch.argmin(torch.abs(recall - 0.6)) + cutoff = torch.argmin(torch.abs(recall - self.recall_level)) return false_pos[cutoff] / (~labels).sum()