Skip to content

Commit

Permalink
[cm] Small bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
christhetree committed Mar 4, 2024
1 parent fe90dd7 commit 7f8d2ac
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions neutone_sdk/non_realtime_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ def __init__(self, model: nn.Module, use_debug_mode: bool = True) -> None:

self.n_numerical_params = self.n_cont_params + self.n_cat_params

assert self.get_default_param_values().size(0) == self.n_numerical_params, (
f"Default parameter values tensor first dimension must have the same "
f"size as the number of numerical parameters. Expected size "
f"{self.n_numerical_params}, got {self.get_default_param_values().size(0)}"
)
assert self.n_numerical_params <= constants.NEUTONE_GEN_N_NUMERICAL_PARAMS, (
f"Too many numerical (continuous and categorical) parameters. "
f"Max allowed is {constants.NEUTONE_GEN_N_NUMERICAL_PARAMS}"
Expand Down Expand Up @@ -146,10 +151,7 @@ def __init__(self, model: nn.Module, use_debug_mode: bool = True) -> None:
]

# TODO(cm): this statement will also be removed once core is refactored
assert (
len(self.get_default_param_names())
== constants.NEUTONE_GEN_N_NUMERICAL_PARAMS
)
assert len(self.get_default_param_names()) == self.n_numerical_params

assert all(
1 <= n <= 2 for n in self.get_audio_in_channels()
Expand Down Expand Up @@ -194,7 +196,6 @@ def _create_default_param_values(self) -> Tensor:
elif p.type == NeutoneParameterType.CATEGORICAL:
# Convert to float to match the type of the continuous parameters
numerical_default_values.append(float(p.default_value))
assert len(numerical_default_values) == self.n_numerical_params
numerical_default_values = tr.tensor(numerical_default_values)
return numerical_default_values

Expand Down

0 comments on commit 7f8d2ac

Please # to comment.