From 7a140bc7aedb1bc3e94141d650c08e6736780fbf Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 27 Jan 2025 13:09:55 +0000 Subject: [PATCH 1/3] Expand JET test (#782) * Fix JET test * Add output if a typed varinfo isn't inferred --- test/ext/DynamicPPLJETExt.jl | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/test/ext/DynamicPPLJETExt.jl b/test/ext/DynamicPPLJETExt.jl index b95107b2d..933bfb1d1 100644 --- a/test/ext/DynamicPPLJETExt.jl +++ b/test/ext/DynamicPPLJETExt.jl @@ -64,9 +64,7 @@ @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS # Use debug logging below. varinfo = DynamicPPL.Experimental.determine_suitable_varinfo(model) - # They should all result in typed. - @test varinfo isa DynamicPPL.TypedVarInfo - # But let's also make sure that they're not lying. + # Check that the inferred varinfo is indeed suitable for evaluation and sampling f_eval, argtypes_eval = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( model, varinfo ) @@ -76,6 +74,21 @@ model, varinfo, DynamicPPL.SamplingContext() ) JET.test_call(f_sample, argtypes_sample) + # For our demo models, they should all result in typed. + is_typed = varinfo isa DynamicPPL.TypedVarInfo + @test is_typed + # If the test failed, check why it didn't infer a typed varinfo + if !is_typed + typed_vi = VarInfo(model) + f_eval, argtypes_eval = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( + model, typed_vi + ) + JET.test_call(f_eval, argtypes_eval) + f_sample, argtypes_sample = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( + model, typed_vi, DynamicPPL.SamplingContext() + ) + JET.test_call(f_sample, argtypes_sample) + end end end end From 00e7ee3b6514e1588e95ab8c59807c58208490b1 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 27 Jan 2025 13:10:34 +0000 Subject: [PATCH 2/3] Clarify is_unconstrained docstring (#790) --- src/varnamedvector.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index b324e9134..565e82480 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -159,10 +159,12 @@ struct VarNamedVector{ transforms::TTrans """ - vector of booleans indicating whether a variable has been transformed to unconstrained - Euclidean space or not, i.e. whether its domain is all of `ℝ^ⁿ`. Having - `is_unconstrained[varname_to_index[vn]] == false` does not necessarily mean that a - variable is constrained, but rather that it's not guaranteed to not be. + vector of booleans indicating whether a variable has been explicitly transformed to + unconstrained Euclidean space, i.e. whether its domain is all of `ℝ^ⁿ`. If + `is_unconstrained[varname_to_index[vn]]` is true, it guarantees that the variable + `vn` is not constrained. However, the converse does not hold: if `is_unconstrained` + is false, the variable `vn` may still happen to be unconstrained, e.g. if its + original distribution is itself unconstrained (like a normal distribution). """ is_unconstrained::BitVector From 29a6c7ec0cd62d3a4d1dc18a304d5e4d1e024cfb Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 27 Jan 2025 15:23:49 +0000 Subject: [PATCH 3/3] Handle nested PrefixContext (#787) * Prefix varnames appropriately inside check_model_and_trace * Fix values_as_in_model as well * Add test for check_model with manual prefix * Add values_as_in_model tests * Add tests for prefix nesting * Bump Project.toml --- Project.toml | 2 +- src/contexts.jl | 15 +++++++++------ src/debug_utils.jl | 23 ++++++++++++----------- src/values_as_in_model.jl | 2 +- test/contexts.jl | 20 ++++++++++++++++++++ test/debug_utils.jl | 9 +++++++++ test/model.jl | 21 +++++++++++++++++++++ 7 files changed, 73 insertions(+), 19 deletions(-) diff --git a/Project.toml b/Project.toml index bd553c0cc..3df611824 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.34.1" +version = "0.34.2" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/contexts.jl b/src/contexts.jl index a9470fbb6..99b2136f3 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -261,6 +261,7 @@ end const PREFIX_SEPARATOR = Symbol(".") +# TODO(penelopeysm): Prefixing arguably occurs the wrong way round here function PrefixContext{PrefixInner}( context::PrefixContext{PrefixOuter} ) where {PrefixInner,PrefixOuter} @@ -273,13 +274,15 @@ function PrefixContext{PrefixInner}( end end -function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} - if @generated - return :(VarName{$(QuoteNode(Symbol(Prefix, PREFIX_SEPARATOR, Sym)))}(getoptic(vn))) - else - VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(getoptic(vn)) - end +# TODO(penelopeysm): Prefixing arguably occurs the wrong way round here +function prefix(ctx::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} + return prefix( + childcontext(ctx), VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(getoptic(vn)) + ) end +prefix(ctx::AbstractContext, vn::VarName) = prefix(NodeTrait(ctx), ctx, vn) +prefix(::IsLeaf, ::AbstractContext, vn::VarName) = vn +prefix(::IsParent, ctx::AbstractContext, vn::VarName) = prefix(childcontext(ctx), vn) """ prefix(model::Model, x) diff --git a/src/debug_utils.jl b/src/debug_utils.jl index f486482a9..43b5054d5 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -239,42 +239,43 @@ function DynamicPPL.setchildcontext(context::DebugContext, child) end function record_varname!(context::DebugContext, varname::VarName, dist) - if haskey(context.varnames_seen, varname) + prefixed_varname = prefix(context, varname) + if haskey(context.varnames_seen, prefixed_varname) if context.error_on_failure - error("varname $varname used multiple times in model") + error("varname $prefixed_varname used multiple times in model") else - @warn "varname $varname used multiple times in model" + @warn "varname $prefixed_varname used multiple times in model" end - context.varnames_seen[varname] += 1 + context.varnames_seen[prefixed_varname] += 1 else # We need to check: # 1. Does this `varname` subsume any of the other keys. # 2. Does any of the other keys subsume `varname`. vns = collect(keys(context.varnames_seen)) # Is `varname` subsumed by any of the other keys? - idx_parent = findfirst(Base.Fix2(subsumes, varname), vns) + idx_parent = findfirst(Base.Fix2(subsumes, prefixed_varname), vns) if idx_parent !== nothing varname_parent = vns[idx_parent] if context.error_on_failure error( - "varname $(varname_parent) used multiple times in model (subsumes $varname)", + "varname $(varname_parent) used multiple times in model (subsumes $prefixed_varname)", ) else - @warn "varname $(varname_parent) used multiple times in model (subsumes $varname)" + @warn "varname $(varname_parent) used multiple times in model (subsumes $prefixed_varname)" end # Update count of parent. context.varnames_seen[varname_parent] += 1 else # Does `varname` subsume any of the other keys? - idx_child = findfirst(Base.Fix1(subsumes, varname), vns) + idx_child = findfirst(Base.Fix1(subsumes, prefixed_varname), vns) if idx_child !== nothing varname_child = vns[idx_child] if context.error_on_failure error( - "varname $(varname_child) used multiple times in model (subsumed by $varname)", + "varname $(varname_child) used multiple times in model (subsumed by $prefixed_varname)", ) else - @warn "varname $(varname_child) used multiple times in model (subsumed by $varname)" + @warn "varname $(varname_child) used multiple times in model (subsumed by $prefixed_varname)" end # Update count of child. @@ -282,7 +283,7 @@ function record_varname!(context::DebugContext, varname::VarName, dist) end end - context.varnames_seen[varname] = 1 + context.varnames_seen[prefixed_varname] = 1 end end diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index ca8cc1cb3..4cef5fa4e 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ -45,7 +45,7 @@ is_extracting_values(::IsParent, ::AbstractContext) = false is_extracting_values(::IsLeaf, ::AbstractContext) = false function Base.push!(context::ValuesAsInModelContext, vn::VarName, value) - return setindex!(context.values, copy(value), vn) + return setindex!(context.values, copy(value), prefix(context, vn)) end function broadcast_push!(context::ValuesAsInModelContext, vns, values) diff --git a/test/contexts.jl b/test/contexts.jl index dd3b4c90c..ef55335d0 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -162,6 +162,26 @@ end @test getoptic(vn_prefixed) === getoptic(vn) end + @testset "nested within arbitrary context stacks" begin + vn = @varname(x[1]) + ctx1 = PrefixContext{:a}(DefaultContext()) + ctx2 = SamplingContext(ctx1) + ctx3 = PrefixContext{:b}(ctx2) + ctx4 = DynamicPPL.ValuesAsInModelContext(OrderedDict(), false, ctx3) + vn_prefixed1 = prefix(ctx1, vn) + vn_prefixed2 = prefix(ctx2, vn) + vn_prefixed3 = prefix(ctx3, vn) + vn_prefixed4 = prefix(ctx4, vn) + @test DynamicPPL.getsym(vn_prefixed1) == Symbol("a.x") + @test DynamicPPL.getsym(vn_prefixed2) == Symbol("a.x") + @test DynamicPPL.getsym(vn_prefixed3) == Symbol("a.b.x") + @test DynamicPPL.getsym(vn_prefixed4) == Symbol("a.b.x") + @test DynamicPPL.getoptic(vn_prefixed1) === DynamicPPL.getoptic(vn) + @test DynamicPPL.getoptic(vn_prefixed2) === DynamicPPL.getoptic(vn) + @test DynamicPPL.getoptic(vn_prefixed3) === DynamicPPL.getoptic(vn) + @test DynamicPPL.getoptic(vn_prefixed4) === DynamicPPL.getoptic(vn) + end + context = DynamicPPL.PrefixContext{:prefix}(SamplingContext()) @testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS # Sample with the context. diff --git a/test/debug_utils.jl b/test/debug_utils.jl index 294364758..d4f6601f5 100644 --- a/test/debug_utils.jl +++ b/test/debug_utils.jl @@ -60,6 +60,15 @@ end model = ModelOuterWorking() @test check_model(model; error_on_failure=true) + + # With manual prefixing, https://github.com/TuringLang/DynamicPPL.jl/issues/785 + @model function ModelOuterWorking2() + x1 ~ to_submodel(prefix(ModelInner(), :a), false) + x2 ~ to_submodel(prefix(ModelInner(), :b), false) + return (x1, x2) + end + model = ModelOuterWorking2() + @test check_model(model; error_on_failure=true) end @testset "subsumes (x then x[1])" begin diff --git a/test/model.jl b/test/model.jl index 45c770cc4..118f60a40 100644 --- a/test/model.jl +++ b/test/model.jl @@ -429,6 +429,27 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() end end end + + @testset "Prefixing" begin + @model inner() = x ~ Normal() + + @model function outer_auto_prefix() + a ~ to_submodel(inner(), true) + b ~ to_submodel(inner(), true) + return nothing + end + @model function outer_manual_prefix() + a ~ to_submodel(prefix(inner(), :a), false) + b ~ to_submodel(prefix(inner(), :b), false) + return nothing + end + + for model in (outer_auto_prefix(), outer_manual_prefix()) + vi = VarInfo(model) + vns = Set(keys(values_as_in_model(model, false, vi))) + @test vns == Set([@varname(var"a.x"), @varname(var"b.x")]) + end + end end @testset "Erroneous model call" begin