Skip to content

Commit

Permalink
🔨 Rework stochastic models
Browse files Browse the repository at this point in the history
  • Loading branch information
o-laurent committed Jul 1, 2023
1 parent 84449a3 commit b4a42b7
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
4 changes: 2 additions & 2 deletions torch_uncertainty/models/lenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from ..layers.bayesian_layers import BayesConv2d, BayesLinear
from ..layers.packed_layers import PackedConv2d, PackedLinear
from .variational_model import Variational
from .utils import Stochastic


class LeNet(nn.Module):
Expand Down Expand Up @@ -41,7 +41,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return out


@Variational
@Stochastic
class BayesianLeNet(LeNet):
pass

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from ..layers.bayesian_layers import bayesian_modules


def Variational(Model):
def Stochastic(Model):
"""Decorator for stochastic models. When applied to a model, it adds the
freeze and unfreeze methods to the model. Use freeze to freeze the
stochastic layers and obtain deterministic outputs. Use unfreeze to
unfreeze the stochastic layers and obtain stochastic outputs.
"""

def freeze(self):
for module in self.modules():
if isinstance(module, bayesian_modules):
Expand Down

0 comments on commit b4a42b7

Please sign in to comment.