NumpyroPlayerModel config validity in test_get_fitted_player_model_numpyro #578

Closed
griff-rees opened this issue Jun 12, 2023 · 2 comments · Fixed by #577
Labels
bug Something isn't working

Comments

@griff-rees

griff-rees commented Jun 12, 2023

The test test_get_fitted_player_model_numpyro fails both inside and outside Docker on the fix/574-docker-build branch. This may be related to a deprecation in the numpyro dependency.

Specific test

def test_get_fitted_player_model_numpyro():
    pm = NumpyroPlayerModel()
    assert isinstance(pm, NumpyroPlayerModel)
    with test_past_data_session_scope() as ts:
        fpm = fit_player_data("FWD", "1819", 12, model=pm, dbsession=ts)  # Fails here
        assert isinstance(fpm, pd.DataFrame)
        assert len(fpm) > 0

Where the error is raised (inside numpyro):

if not_jax_tracer(is_valid):
    if device_get(~jnp.all(is_valid)):
        with numpyro.validation_enabled(), trace() as tr:
            # validate parameters
            substituted_model(*model_args, **model_kwargs)
            # validate values
            for site in tr.values():
                if site["type"] == "sample":
                    with warnings.catch_warnings(record=True) as ws:
                        site["fn"]._validate_sample(site["value"])
                    if len(ws) > 0:
                        for w in ws:
                            # add site information to the warning message
                            w.message.args = (
                                "Site {}: {}".format(
                                    site["name"], w.message.args[0]
                                ),
                            ) + w.message.args[1:]
                            warnings.showwarning(
                                w.message,
                                w.category,
                                w.filename,
                                w.lineno,
                                file=w.file,
                                line=w.line,
                            )
        raise RuntimeError(
            "Cannot find valid initial parameters. Please check your model again."
        )

RuntimeError: Cannot find valid initial parameters. Please check your model again.

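When numpyro cannot find valid initial parameters, the excerpt above re-runs the model with validation enabled, records any warnings raised by each sample site's _validate_sample, prefixes every warning with the site name, re-emits it, and only then raises the RuntimeError. The stdlib-only sketch below mimics that warning-annotation step. The site dicts and the validate callable are hypothetical stand-ins for numpyro's trace entries and distribution objects, and print stands in for warnings.showwarning.

```python
import warnings


def collect_site_warnings(sites, validate):
    """Validate each sample site and return its warnings, prefixed with the site name.

    Hypothetical stand-ins: each site is a dict with "name", "type" and "value"
    keys (mirroring numpyro trace entries), and `validate` plays the role of a
    distribution's `_validate_sample` method.
    """
    messages = []
    for site in sites:
        if site["type"] != "sample":
            continue  # only sample sites are validated in the excerpt
        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")  # record every warning, even repeats
            validate(site["value"])
        for w in ws:
            # Prepend the site name to the first message argument, keeping the rest.
            w.message.args = (
                "Site {}: {}".format(site["name"], w.message.args[0]),
            ) + w.message.args[1:]
            messages.append(str(w.message))
    return messages


def report_invalid_init(sites, validate):
    """Surface the annotated warnings, then fail the way the excerpt does."""
    for message in collect_site_warnings(sites, validate):
        print(message)  # stand-in for warnings.showwarning(...)
    raise RuntimeError(
        "Cannot find valid initial parameters. Please check your model again."
    )
```

In the real trace, a message of the form "Site <name>: ..." identifies which model site's value failed validation, which is likely the first place to look when this error appears.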
@griff-rees
Author

See pull request #579

@griff-rees griff-rees added the bug Something isn't working label Jun 12, 2023
@griff-rees griff-rees self-assigned this Jun 12, 2023
@griff-rees
Author

Pinning the jax and jaxlib dependencies to 0.3.25 in commit dfd76d8 addresses this. I will close this issue once that is merged.
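Before re-running the test, it can help to confirm that an environment (inside or outside Docker) actually has the pinned versions installed. The sketch below uses only the standard library's importlib.metadata; the PINS dict and find_pin_mismatches helper are hypothetical, and ignoring local version suffixes such as "+cuda11" is an assumption aimed at GPU builds of jaxlib.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins taken from the comment above; how they are declared in the project's
# dependency file is not shown in this issue.
PINS = {"jax": "0.3.25", "jaxlib": "0.3.25"}


def find_pin_mismatches(pins=PINS, get_version=version):
    """Return {package: installed version} for packages that do not match their pin.

    A missing package is reported as None. Any local version suffix such as
    "+cuda11" is ignored when comparing (an assumption, not a packaging rule).
    """
    mismatches = {}
    for package, pinned in pins.items():
        try:
            installed = get_version(package)
        except PackageNotFoundError:
            installed = None
        if installed is None or installed.split("+")[0] != pinned:
            mismatches[package] = installed
    return mismatches
```

Calling find_pin_mismatches() should return an empty dict when both pins are in effect. The get_version parameter exists only so the check can be exercised without jax installed.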

Projects
None yet

1 participant