Skip to content

Commit

Permalink
🎨 Continue improve typing & misc
Browse files Browse the repository at this point in the history
  • Loading branch information
o-laurent committed Jul 4, 2023
1 parent d1354ba commit a2d44b8
Show file tree
Hide file tree
Showing 12 changed files with 47 additions and 25 deletions.
22 changes: 19 additions & 3 deletions experiments/regression/uci_datasets/kin8nm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,34 @@
from pathlib import Path

import torch.nn as nn
import torch.optim as optim

from torch_uncertainty import cls_main, init_args
from torch_uncertainty.baselines.regression.mlp import MLP
from torch_uncertainty.datamodules.uci_regression import UCIDataModule
from torch_uncertainty.optimization_procedures import optim_regression


# fmt: on
def optim_regression(
    model: nn.Module,
    learning_rate: float = 5e-3,
    weight_decay: float = 0.0,
) -> dict:
    """Build the optimizer configuration for regression training.

    NOTE(review): a function of the same name is also imported from
    torch_uncertainty.optimization_procedures above; this local definition
    shadows it — confirm the import is meant to be removed.

    Args:
        model: The model whose parameters will be optimized.
        learning_rate: Adam learning rate. Defaults to 5e-3.
        weight_decay: L2 penalty passed to Adam. Defaults to 0.0
            (no regularization), preserving the previous behavior.

    Returns:
        A Lightning-style dict with a single ``"optimizer"`` entry.
    """
    optimizer = optim.Adam(
        model.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay,
    )
    return {
        "optimizer": optimizer,
    }


if __name__ == "__main__":
root = Path(__file__).parent.absolute().parents[2]

args = init_args(MLP, UCIDataModule)

net_name = "mlp-10neurons-2layers-kin8nm"
net_name = "mlp-kin8nm"

# datamodule
args.root = str(root / "data")
Expand All @@ -24,9 +39,10 @@
model = MLP(
num_outputs=2,
in_features=8,
hidden_dims=[100],
loss=nn.GaussianNLLLoss,
optimization_procedure=optim_regression,
**vars(args),
)

cls_main(model, dm, root, net_name, "regression", args)
cls_main(model, dm, root, net_name, args)
8 changes: 4 additions & 4 deletions torch_uncertainty/baselines/classification/resnet.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# fmt: off
from argparse import ArgumentParser, BooleanOptionalAction
from pathlib import Path
from typing import Any, Literal, Optional, Union
from typing import Any, Literal, Optional, Type, Union

import torch
import torch.nn as nn
Expand Down Expand Up @@ -142,16 +142,16 @@ def __new__(
cls,
num_classes: int,
in_channels: int,
loss: nn.Module,
loss: Type[nn.Module],
optimization_procedure: Any,
version: Literal["vanilla", "packed", "batched", "masked"],
arch: int,
style: str = "imagenet",
num_estimators: Optional[int] = None,
groups: Optional[int] = 1,
groups: int = 1,
scale: Optional[float] = None,
alpha: Optional[float] = None,
gamma: Optional[int] = 1,
gamma: int = 1,
use_entropy: bool = False,
use_logits: bool = False,
use_mi: bool = False,
Expand Down
8 changes: 4 additions & 4 deletions torch_uncertainty/baselines/classification/vgg.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# fmt: off
from argparse import ArgumentParser
from typing import Any, Literal, Optional
from typing import Any, Literal, Optional, Type

import torch.nn as nn
from pytorch_lightning import LightningModule
Expand Down Expand Up @@ -100,15 +100,15 @@ def __new__(
cls,
num_classes: int,
in_channels: int,
loss: nn.Module,
loss: Type[nn.Module],
optimization_procedure: Any,
version: Literal["vanilla", "packed"],
arch: int,
num_estimators: Optional[int] = None,
style: str = "imagenet",
groups: Optional[int] = 1,
groups: int = 1,
alpha: Optional[float] = None,
gamma: Optional[int] = 1,
gamma: int = 1,
use_entropy: bool = False,
use_logits: bool = False,
use_mi: bool = False,
Expand Down
4 changes: 2 additions & 2 deletions torch_uncertainty/baselines/classification/wideresnet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# fmt: off
from argparse import ArgumentParser, BooleanOptionalAction
from typing import Any, Literal, Optional
from typing import Any, Literal, Optional, Type

import torch.nn as nn
from pytorch_lightning import LightningModule
Expand Down Expand Up @@ -92,7 +92,7 @@ def __new__(
cls,
num_classes: int,
in_channels: int,
loss: nn.Module,
loss: Type[nn.Module],
optimization_procedure: Any,
version: Literal["vanilla", "packed", "batched", "masked"],
style: str = "imagenet",
Expand Down
8 changes: 7 additions & 1 deletion torch_uncertainty/baselines/deep_ensembles.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@ class DeepEnsembles:

def __new__(
cls,
in_channels: int,
num_classes: int,
task: Literal["classification", "regression"],
log_path: Union[str, Path],
versions: List[int],
backbone: Literal["resnet"],
num_estimators: int,
use_entropy: bool = False,
use_logits: bool = False,
use_mi: bool = False,
Expand Down Expand Up @@ -48,9 +51,12 @@ def __new__(

if task == "classification":
return ClassificationEnsemble(
in_channels=in_channels,
num_classes=num_classes,
model=de,
loss=None,
loss=None, # TODO: Why None? We won't support training?
optimization_procedure=None,
num_estimators=num_estimators,
use_entropy=use_entropy,
use_logits=use_logits,
use_mi=use_mi,
Expand Down
2 changes: 1 addition & 1 deletion torch_uncertainty/datamodules/cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(
batch_size: int,
val_split: float = 0.0,
num_workers: int = 1,
cutout: int = Optional[None],
cutout: Optional[int] = None,
auto_augment: Optional[str] = None,
test_alt: Optional[Literal["c", "h"]] = None,
corruption_severity: int = 1,
Expand Down
4 changes: 2 additions & 2 deletions torch_uncertainty/datamodules/cifar100.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ def __init__(
batch_size: int,
val_split: float = 0.0,
num_workers: int = 1,
cutout: int = None,
cutout: Optional[int] = None,
enable_randaugment: bool = False,
auto_augment: str = None,
auto_augment: Optional[str] = None,
test_alt: Optional[Literal["c"]] = None,
corruption_severity: int = 1,
num_dataloaders: int = 1,
Expand Down
2 changes: 1 addition & 1 deletion torch_uncertainty/datamodules/tiny_imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(
root: Union[str, Path],
ood_detection: bool,
batch_size: int,
rand_augment_opt: str = None,
rand_augment_opt: Optional[str] = None,
num_workers: int = 1,
pin_memory: bool = True,
persistent_workers: bool = True,
Expand Down
2 changes: 1 addition & 1 deletion torch_uncertainty/datamodules/uci_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(
root: Union[str, Path],
batch_size: int,
dataset_name: str,
val_split: Optional[float] = 0.0,
val_split: float = 0.0,
num_workers: int = 1,
pin_memory: bool = True,
persistent_workers: bool = True,
Expand Down
2 changes: 1 addition & 1 deletion torch_uncertainty/datasets/aggregated_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def __init__(self, dataset: Dataset, n_dataloaders: int) -> None:
super().__init__()
self.dataset = dataset
self.n_dataloaders = n_dataloaders
self.dataset_size = len(self.dataset)
self.dataset_size = len(dataset)
self.offset = self.dataset_size // self.n_dataloaders

def __getitem__(self, idx: int):
Expand Down
8 changes: 4 additions & 4 deletions torch_uncertainty/datasets/cifar/cifar_h.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# fmt:off
import os
from typing import Any, Callable
from typing import Any, Callable, Optional

import torch
from torchvision.datasets import CIFAR10
Expand Down Expand Up @@ -37,9 +37,9 @@ class CIFAR10_H(CIFAR10):
def __init__(
self,
root: str,
train: bool = None,
transform: Callable[..., Any] = None,
target_transform: Callable[..., Any] = None,
train: Optional[bool] = None,
transform: Optional[Callable[..., Any]] = None,
target_transform: Optional[Callable[..., Any]] = None,
download: bool = False,
) -> None:
print(
Expand Down
2 changes: 1 addition & 1 deletion torch_uncertainty/post_processing/temperature_scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def fit(
[self.temperature], lr=self.lr, max_iter=self.max_iter
)

def eval() -> torch.Tensor:
def eval() -> float:
optimizer.zero_grad()
loss = self.criterion(self._scale(logits), labels)
loss.backward()
Expand Down

0 comments on commit a2d44b8

Please sign in to comment.