diff --git a/test/ext/DynamicPPLJETExt.jl b/test/ext/DynamicPPLJETExt.jl index b95107b2d..933bfb1d1 100644 --- a/test/ext/DynamicPPLJETExt.jl +++ b/test/ext/DynamicPPLJETExt.jl @@ -64,9 +64,7 @@ @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS # Use debug logging below. varinfo = DynamicPPL.Experimental.determine_suitable_varinfo(model) - # They should all result in typed. - @test varinfo isa DynamicPPL.TypedVarInfo - # But let's also make sure that they're not lying. + # Check that the inferred varinfo is indeed suitable for evaluation and sampling f_eval, argtypes_eval = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( model, varinfo ) @@ -76,6 +74,21 @@ model, varinfo, DynamicPPL.SamplingContext() ) JET.test_call(f_sample, argtypes_sample) + # For our demo models, they should all result in typed. + is_typed = varinfo isa DynamicPPL.TypedVarInfo + @test is_typed + # If the test failed, check why it didn't infer a typed varinfo + if !is_typed + typed_vi = VarInfo(model) + f_eval, argtypes_eval = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( + model, typed_vi + ) + JET.test_call(f_eval, argtypes_eval) + f_sample, argtypes_sample = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( + model, typed_vi, DynamicPPL.SamplingContext() + ) + JET.test_call(f_sample, argtypes_sample) + end end end end