Skip to content

Commit

Permalink
✨ Start multivariate regression
Browse files Browse the repository at this point in the history
  • Loading branch information
o-laurent committed Jul 1, 2023
1 parent 1843fa9 commit f7939a8
Showing 1 changed file with 38 additions and 17 deletions.
55 changes: 38 additions & 17 deletions torch_uncertainty/routines/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def training_step(
logits = self.forward(inputs)

if self.dist_estimation:
means = logits[:, 0]
vars = F.softplus(logits[:, 1])
means = logits[..., 0]
vars = F.softplus(logits[..., 1])
loss = self.criterion(means, targets, vars)
else:
loss = self.criterion(logits, targets)
Expand All @@ -96,8 +96,8 @@ def validation_step(
inputs, targets = batch
logits = self.forward(inputs)
if self.dist_estimation:
means = logits[:, 0]
vars = F.softplus(logits[:, 1])
means = logits[..., 0]
vars = F.softplus(logits[..., 1])
self.val_metrics.gnll.update(means, targets, vars)
else:
means = logits
Expand All @@ -118,8 +118,8 @@ def test_step(
inputs, targets = batch
logits = self.forward(inputs)
if self.dist_estimation:
means = logits[:, 0]
vars = F.softplus(logits[:, 1])
means = logits[..., 0]
vars = F.softplus(logits[..., 1])
self.test_metrics.gnll.update(means, targets, vars)
else:
means = logits
Expand Down Expand Up @@ -149,6 +149,7 @@ def __init__(
optimization_procedure: Any,
num_estimators: int,
mode: Literal["mean", "mixture"],
out_features: Optional[int] = 1,
**kwargs,
) -> None:
super().__init__(
Expand All @@ -166,6 +167,7 @@ def __init__(

self.mode = mode
self.num_estimators = num_estimators
self.out_features = out_features

def training_step(
self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
Expand All @@ -181,16 +183,25 @@ def validation_step( # type: ignore
) -> None:
inputs, targets = batch
logits = self.forward(inputs)
logits = rearrange(
logits, "(m b) dist -> b m dist", m=self.num_estimators
)

if self.out_features == 1:
logits = rearrange(
logits, "(m b) dist -> b m dist", m=self.num_estimators
)
else:
logits = rearrange(
logits,
"(m b) (f dist) -> b f m dist",
m=self.num_estimators,
f=self.out_features,
)

if self.mode == "mean":
logits = logits.mean(dim=1)

if self.dist_estimation:
means = logits[:, 0]
vars = F.softplus(logits[:, 1])
means = logits[..., 0]
vars = F.softplus(logits[..., 1])
self.val_metrics.gnll.update(means, targets, vars)
else:
means = logits
Expand All @@ -205,21 +216,31 @@ def test_step(
) -> None:
if dataloader_idx != 0:
raise NotImplementedError(
"OOD detection not implemented yet. Raise an issue if needed."
"Regression OOD detection not implemented yet. Raise an issue "
"if needed."
)

inputs, targets = batch
logits = self.forward(inputs)
logits = rearrange(
logits, "(m b) dist -> b m dist", m=self.num_estimators
)

if self.out_features == 1:
logits = rearrange(
logits, "(m b) dist -> b m dist", m=self.num_estimators
)
else:
logits = rearrange(
logits,
"(m b) (f dist) -> b f m dist",
m=self.num_estimators,
f=self.out_features,
)

if self.mode == "mean":
logits = logits.mean(dim=1)

if self.dist_estimation:
means = logits[:, 0]
vars = F.softplus(logits[:, 1])
means = logits[..., 0]
vars = F.softplus(logits[..., 1])
self.test_metrics.gnll.update(means, targets, vars)
else:
means = logits
Expand Down

0 comments on commit f7939a8

Please sign in to comment.