Skip to content

Commit

Permalink
added more options to the parse_reduction in EnsembleMLP
Browse files Browse the repository at this point in the history
  • Loading branch information
DomInvivo committed Dec 14, 2023
1 parent dd157c4 commit cbd3758
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions graphium/nn/ensemble_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,9 @@ def __init__(
- "sum": Sum reduction
- "max": Max reduction
- "min": Min reduction
- "median": Median reduction
- `Callable`: Any callable function. Must take `dim` as a keyword argument.
activation:
activation:
Activation function to use in all the layers except the last.
if `layers==1`, this parameter is ignored
last_activation:
Expand Down Expand Up @@ -389,7 +390,12 @@ def _parse_reduction(self, reduction: Optional[Union[str, Callable]]) -> Optiona
return torch.mean
elif reduction == "sum":
return torch.sum

elif reduction == "max":
return torch.max
elif reduction == "min":
return torch.min
elif reduction == "median":
return torch.median
elif callable(reduction):
return reduction
else:
Expand Down

0 comments on commit cbd3758

Please # to comment.