Skip to content

Commit

Permalink
fixed the parse_reduction to take values from min, max, median
Browse files Browse the repository at this point in the history
  • Loading branch information
DomInvivo committed Dec 14, 2023
1 parent cbd3758 commit 2b1f9ab
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 6 deletions.
13 changes: 10 additions & 3 deletions graphium/nn/architectures/global_architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,16 +599,23 @@ def _parse_reduction(self, reduction: Optional[Union[str, Callable]]) -> Optiona
elif reduction == "sum":
return torch.sum
elif reduction == "max":
return torch.max
def max_vals(x, dim):
return torch.max(x, dim=dim).values
return max_vals
elif reduction == "min":
return torch.min
def min_vals(x, dim):
return torch.min(x, dim=dim).values
return min_vals
elif reduction == "median":
return torch.median
def median_vals(x, dim):
return torch.median(x, dim=dim).values
return median_vals
elif callable(reduction):
return reduction
else:
raise ValueError(f"Unknown reduction {reduction}")


def _parse_layers(self, layer_type, residual_type):
# Parse the layer and residuals
from graphium.utils.spaces import ENSEMBLE_LAYERS_DICT, RESIDUALS_DICT
Expand Down
12 changes: 9 additions & 3 deletions graphium/nn/ensemble_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,11 +391,17 @@ def _parse_reduction(self, reduction: Optional[Union[str, Callable]]) -> Optiona
elif reduction == "sum":
return torch.sum
elif reduction == "max":
return torch.max
def max_vals(x, dim):
return torch.max(x, dim=dim).values
return max_vals
elif reduction == "min":
return torch.min
def min_vals(x, dim):
return torch.min(x, dim=dim).values
return min_vals
elif reduction == "median":
return torch.median
def median_vals(x, dim):
return torch.median(x, dim=dim).values
return median_vals
elif callable(reduction):
return reduction
else:
Expand Down

0 comments on commit 2b1f9ab

Please # to comment.