Skip to content

Commit

Permalink
improve documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
antonbaumann committed May 2, 2024
1 parent 73bdf9b commit 2ad877d
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 7 deletions.
4 changes: 2 additions & 2 deletions utools/wrappers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from utools.wrappers.ensemble import Ensemble

__all__ = [
'BaseWrapper',
'MonteCarlo',
'Ensemble',
'MonteCarlo',
'BaseWrapper',
]
4 changes: 2 additions & 2 deletions utools/wrappers/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(
criterion (RegressionLoss | HeteroscedasticSoftmax): The criterion to be used for computing probabilistic outputs.
"""
super(Ensemble, self).__init__()
self.wrapper = BaseWrapper(models=models, criterion=criterion, monte_carlo_samples=1)
self.__wrapper = BaseWrapper(models=models, criterion=criterion, monte_carlo_samples=1)

def forward(self, input: torch.Tensor) -> torch.Tensor:
"""
Expand All @@ -37,4 +37,4 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
Returns:
torch.Tensor: The ensemble prediction dictionary.
"""
return self.wrapper(input)
return self.__wrapper(input)
6 changes: 3 additions & 3 deletions utools/wrappers/monte_carlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,17 @@ def __init__(
monte_carlo_samples (int): The number of samples to use for Monte Carlo simulation.
"""
super(MonteCarlo, self).__init__()
self.wrapper = BaseWrapper(models=[model], criterion=criterion, monte_carlo_samples=monte_carlo_samples)
self.__wrapper = BaseWrapper(models=[model], criterion=criterion, monte_carlo_samples=monte_carlo_samples)

# Activate MC Dropout for the model
for model in self.wrapper.models:
for model in self.__wrapper.models:
self._activate_mc_dropout(model)

def forward(self, input: torch.Tensor) -> torch.Tensor:
"""
Computes the Monte Carlo prediction.
"""
return self.wrapper(input)
return self.__wrapper(input)

@staticmethod
def _activate_mc_dropout(model: torch.nn.Module):
Expand Down

0 comments on commit 2ad877d

Please # to comment.