Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into release-0.35
Browse files Browse the repository at this point in the history
  • Loading branch information
penelopeysm committed Jan 28, 2025
2 parents af65bb0 + 29a6c7e commit f5c5fda
Show file tree
Hide file tree
Showing 8 changed files with 94 additions and 25 deletions.
15 changes: 9 additions & 6 deletions src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ end

const PREFIX_SEPARATOR = Symbol(".")

# TODO(penelopeysm): Prefixing arguably occurs the wrong way round here
function PrefixContext{PrefixInner}(
context::PrefixContext{PrefixOuter}
) where {PrefixInner,PrefixOuter}
Expand All @@ -273,13 +274,15 @@ function PrefixContext{PrefixInner}(
end
end

function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym}
if @generated
return :(VarName{$(QuoteNode(Symbol(Prefix, PREFIX_SEPARATOR, Sym)))}(getoptic(vn)))
else
VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(getoptic(vn))
end
# TODO(penelopeysm): Prefixing arguably occurs the wrong way round here
function prefix(ctx::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym}
return prefix(
childcontext(ctx), VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(getoptic(vn))
)
end
prefix(ctx::AbstractContext, vn::VarName) = prefix(NodeTrait(ctx), ctx, vn)
prefix(::IsLeaf, ::AbstractContext, vn::VarName) = vn
prefix(::IsParent, ctx::AbstractContext, vn::VarName) = prefix(childcontext(ctx), vn)

"""
prefix(model::Model, x)
Expand Down
23 changes: 12 additions & 11 deletions src/debug_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,50 +239,51 @@ function DynamicPPL.setchildcontext(context::DebugContext, child)
end

function record_varname!(context::DebugContext, varname::VarName, dist)
if haskey(context.varnames_seen, varname)
prefixed_varname = prefix(context, varname)
if haskey(context.varnames_seen, prefixed_varname)
if context.error_on_failure
error("varname $varname used multiple times in model")
error("varname $prefixed_varname used multiple times in model")
else
@warn "varname $varname used multiple times in model"
@warn "varname $prefixed_varname used multiple times in model"
end
context.varnames_seen[varname] += 1
context.varnames_seen[prefixed_varname] += 1
else
# We need to check:
# 1. Does this `varname` subsume any of the other keys.
# 2. Does any of the other keys subsume `varname`.
vns = collect(keys(context.varnames_seen))
# Is `varname` subsumed by any of the other keys?
idx_parent = findfirst(Base.Fix2(subsumes, varname), vns)
idx_parent = findfirst(Base.Fix2(subsumes, prefixed_varname), vns)
if idx_parent !== nothing
varname_parent = vns[idx_parent]
if context.error_on_failure
error(
"varname $(varname_parent) used multiple times in model (subsumes $varname)",
"varname $(varname_parent) used multiple times in model (subsumes $prefixed_varname)",
)
else
@warn "varname $(varname_parent) used multiple times in model (subsumes $varname)"
@warn "varname $(varname_parent) used multiple times in model (subsumes $prefixed_varname)"
end
# Update count of parent.
context.varnames_seen[varname_parent] += 1
else
# Does `varname` subsume any of the other keys?
idx_child = findfirst(Base.Fix1(subsumes, varname), vns)
idx_child = findfirst(Base.Fix1(subsumes, prefixed_varname), vns)
if idx_child !== nothing
varname_child = vns[idx_child]
if context.error_on_failure
error(
"varname $(varname_child) used multiple times in model (subsumed by $varname)",
"varname $(varname_child) used multiple times in model (subsumed by $prefixed_varname)",
)
else
@warn "varname $(varname_child) used multiple times in model (subsumed by $varname)"
@warn "varname $(varname_child) used multiple times in model (subsumed by $prefixed_varname)"
end

# Update count of child.
context.varnames_seen[varname_child] += 1
end
end

context.varnames_seen[varname] = 1
context.varnames_seen[prefixed_varname] = 1
end
end

Expand Down
2 changes: 1 addition & 1 deletion src/values_as_in_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ is_extracting_values(::IsParent, ::AbstractContext) = false
is_extracting_values(::IsLeaf, ::AbstractContext) = false

function Base.push!(context::ValuesAsInModelContext, vn::VarName, value)
return setindex!(context.values, copy(value), vn)
return setindex!(context.values, copy(value), prefix(context, vn))
end

function broadcast_push!(context::ValuesAsInModelContext, vns, values)
Expand Down
10 changes: 6 additions & 4 deletions src/varnamedvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,12 @@ struct VarNamedVector{
transforms::TTrans

"""
vector of booleans indicating whether a variable has been transformed to unconstrained
Euclidean space or not, i.e. whether its domain is all of `ℝ^ⁿ`. Having
`is_unconstrained[varname_to_index[vn]] == false` does not necessarily mean that a
variable is constrained, but rather that it's not guaranteed to not be.
vector of booleans indicating whether a variable has been explicitly transformed to
unconstrained Euclidean space, i.e. whether its domain is all of `ℝ^ⁿ`. If
`is_unconstrained[varname_to_index[vn]]` is true, it guarantees that the variable
`vn` is not constrained. However, the converse does not hold: if `is_unconstrained`
is false, the variable `vn` may still happen to be unconstrained, e.g. if its
original distribution is itself unconstrained (like a normal distribution).
"""
is_unconstrained::BitVector

Expand Down
20 changes: 20 additions & 0 deletions test/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,26 @@ end
@test getoptic(vn_prefixed) === getoptic(vn)
end

@testset "nested within arbitrary context stacks" begin
vn = @varname(x[1])
ctx1 = PrefixContext{:a}(DefaultContext())
ctx2 = SamplingContext(ctx1)
ctx3 = PrefixContext{:b}(ctx2)
ctx4 = DynamicPPL.ValuesAsInModelContext(OrderedDict(), false, ctx3)
vn_prefixed1 = prefix(ctx1, vn)
vn_prefixed2 = prefix(ctx2, vn)
vn_prefixed3 = prefix(ctx3, vn)
vn_prefixed4 = prefix(ctx4, vn)
@test DynamicPPL.getsym(vn_prefixed1) == Symbol("a.x")
@test DynamicPPL.getsym(vn_prefixed2) == Symbol("a.x")
@test DynamicPPL.getsym(vn_prefixed3) == Symbol("a.b.x")
@test DynamicPPL.getsym(vn_prefixed4) == Symbol("a.b.x")
@test DynamicPPL.getoptic(vn_prefixed1) === DynamicPPL.getoptic(vn)
@test DynamicPPL.getoptic(vn_prefixed2) === DynamicPPL.getoptic(vn)
@test DynamicPPL.getoptic(vn_prefixed3) === DynamicPPL.getoptic(vn)
@test DynamicPPL.getoptic(vn_prefixed4) === DynamicPPL.getoptic(vn)
end

context = DynamicPPL.PrefixContext{:prefix}(SamplingContext())
@testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
# Sample with the context.
Expand Down
9 changes: 9 additions & 0 deletions test/debug_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,15 @@
end
model = ModelOuterWorking()
@test check_model(model; error_on_failure=true)

# With manual prefixing, https://github.com/TuringLang/DynamicPPL.jl/issues/785
@model function ModelOuterWorking2()
x1 ~ to_submodel(prefix(ModelInner(), :a), false)
x2 ~ to_submodel(prefix(ModelInner(), :b), false)
return (x1, x2)
end
model = ModelOuterWorking2()
@test check_model(model; error_on_failure=true)
end

@testset "subsumes (x then x[1])" begin
Expand Down
19 changes: 16 additions & 3 deletions test/ext/DynamicPPLJETExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,7 @@
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
# Use debug logging below.
varinfo = DynamicPPL.Experimental.determine_suitable_varinfo(model)
# They should all result in typed.
@test varinfo isa DynamicPPL.TypedVarInfo
# But let's also make sure that they're not lying.
# Check that the inferred varinfo is indeed suitable for evaluation and sampling
f_eval, argtypes_eval = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
model, varinfo
)
Expand All @@ -76,6 +74,21 @@
model, varinfo, DynamicPPL.SamplingContext()
)
JET.test_call(f_sample, argtypes_sample)
# For our demo models, they should all result in typed.
is_typed = varinfo isa DynamicPPL.TypedVarInfo
@test is_typed
# If the test failed, check why it didn't infer a typed varinfo
if !is_typed
typed_vi = VarInfo(model)
f_eval, argtypes_eval = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
model, typed_vi
)
JET.test_call(f_eval, argtypes_eval)
f_sample, argtypes_sample = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
model, typed_vi, DynamicPPL.SamplingContext()
)
JET.test_call(f_sample, argtypes_sample)
end
end
end
end
21 changes: 21 additions & 0 deletions test/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,27 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
end
end
end

@testset "Prefixing" begin
@model inner() = x ~ Normal()

@model function outer_auto_prefix()
a ~ to_submodel(inner(), true)
b ~ to_submodel(inner(), true)
return nothing
end
@model function outer_manual_prefix()
a ~ to_submodel(prefix(inner(), :a), false)
b ~ to_submodel(prefix(inner(), :b), false)
return nothing
end

for model in (outer_auto_prefix(), outer_manual_prefix())
vi = VarInfo(model)
vns = Set(keys(values_as_in_model(model, false, vi)))
@test vns == Set([@varname(var"a.x"), @varname(var"b.x")])
end
end
end

@testset "Erroneous model call" begin
Expand Down

0 comments on commit f5c5fda

Please # to comment.