diff --git a/jdaviz/configs/default/plugins/model_fitting/tests/test_fitting.py b/jdaviz/configs/default/plugins/model_fitting/tests/test_fitting.py index 308c8739b6..81f0d959d5 100644 --- a/jdaviz/configs/default/plugins/model_fitting/tests/test_fitting.py +++ b/jdaviz/configs/default/plugins/model_fitting/tests/test_fitting.py @@ -31,12 +31,20 @@ def test_model_params(): model_parameters = {"Gaussian1D": ["amplitude", "stddev", "mean"], "Const1D": ["amplitude"], "Linear1D": ["slope", "intercept"], + "Polynomial1D": ["c0", "c1"], "PowerLaw1D": ["amplitude", "x_0", "alpha"], "Lorentz1D": ["amplitude", "x_0", "fwhm"], "Voigt1D": ["x_0", "amplitude_L", "fwhm_L", "fwhm_G"], + "BlackBody": ["temperature", "scale"], } - for model_name, expected_params in model_parameters.items(): + for model_name in initializers.MODELS.keys(): + if model_name not in model_parameters.keys(): + # this would be caught later by the assertion anyways, + # but raising an error will be more clear that the + # test needs to be updated rather than the code breaking + raise ValueError(f"{model_name} not in test dictionary of expected parameters") + expected_params = model_parameters.get(model_name, []) params = initializers.get_model_parameters(model_name) assert len(params) == len(expected_params) assert np.all([p in expected_params for p in params])