Skip to content

Commit

Permalink
🎨 Move AUSE out of classification metrics as it is a general metric
Browse files Browse the repository at this point in the history
- fix docstrings
  • Loading branch information
alafage committed Sep 18, 2024
1 parent eb72501 commit 58acaf8
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 6 deletions.
13 changes: 12 additions & 1 deletion docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,6 @@ Others
:nosignatures:
:template: class.rst

AUSE
GroupingLoss

Regression
Expand Down Expand Up @@ -294,6 +293,18 @@ Segmentation

MeanIntersectionOverUnion

Others
^^^^^^

.. currentmodule:: torch_uncertainty.metrics

.. autosummary::
:toctree: generated/
:nosignatures:
:template: class.rst

AUSE

Losses
------

Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion torch_uncertainty/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from .classification import (
AUGRC,
AURC,
AUSE,
FPR95,
AdaptiveCalibrationError,
BrierScore,
Expand All @@ -29,3 +28,4 @@
SILog,
ThresholdAccuracy,
)
from .sparsification import AUSE
1 change: 0 additions & 1 deletion torch_uncertainty/metrics/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,4 @@
RiskAt80Cov,
RiskAtxCov,
)
from .sparsification import AUSE
from .variation_ratio import VariationRatio
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(self, **kwargs) -> None:
Inputs:
- :attr:`scores`: Uncertainty scores of shape :math:`(B,)`. A higher
score means a higher uncertainty.
- :attr:`errors`: Binary errors of shape :math:`(B,)`,
- :attr:`errors`: Errors of shape :math:`(B,)`,
where :math:`B` is the batch size.
Expand All @@ -52,7 +52,7 @@ def update(self, scores: Tensor, errors: Tensor) -> None:
Args:
scores (Tensor): uncertainty scores of shape :math:`(B,)`
errors (Tensor): binary errors of shape :math:`(B,)`
errors (Tensor): errors of shape :math:`(B,)`
"""
self.scores.append(scores)
self.errors.append(errors)
Expand Down Expand Up @@ -149,7 +149,7 @@ def _ause_rejection_rate_compute(
Args:
scores (Tensor): uncertainty scores of shape :math:`(B,)`
errors (Tensor): binary errors of shape :math:`(B,)`
errors (Tensor): errors of shape :math:`(B,)`
"""
num_samples = errors.size(0)

Expand Down

0 comments on commit 58acaf8

Please sign in to comment.