Skip to content

Gibbs test | Fix dynamic model test in Gibbs sampler suite #2579

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/mcmc/external_sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
function ExternalSampler(
sampler::AbstractSampler,
adtype::ADTypes.AbstractADType,
::Val{unconstrained}=Val(true),
(::Val{unconstrained})=Val(true),
) where {unconstrained}
if !(unconstrained isa Bool)
throw(
Expand All @@ -44,9 +44,11 @@

Return `true` if the sampler requires unconstrained space, and `false` otherwise.
"""
requires_unconstrained_space(
function requires_unconstrained_space(

Check warning on line 47 in src/mcmc/external_sampler.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/external_sampler.jl#L47

Added line #L47 was not covered by tests
::ExternalSampler{<:Any,<:Any,Unconstrained}
) where {Unconstrained} = Unconstrained
) where {Unconstrained}
return Unconstrained

Check warning on line 50 in src/mcmc/external_sampler.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/external_sampler.jl#L50

Added line #L50 was not covered by tests
end

"""
externalsampler(sampler::AbstractSampler; adtype=AutoForwardDiff(), unconstrained=true)
Expand Down
10 changes: 5 additions & 5 deletions test/mcmc/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ using Turing
chain = sample(StableRNG(seed), gauss(x), PG(10), 10)
chain = sample(StableRNG(seed), gauss(x), SMC(), 10)

@model function gauss2(::Type{TV}=Vector{Float64}; x) where {TV}
@model function gauss2((::Type{TV})=Vector{Float64}; x) where {TV}
priors = TV(undef, 2)
priors[1] ~ InverseGamma(2, 3) # s
priors[2] ~ Normal(0, sqrt(priors[1])) # m
Expand All @@ -321,7 +321,7 @@ using Turing
StableRNG(seed), gauss2(DynamicPPL.TypeWrap{Vector{Float64}}(); x=x), SMC(), 10
)

@model function gauss3(x, ::Type{TV}=Vector{Float64}) where {TV}
@model function gauss3(x, (::Type{TV})=Vector{Float64}) where {TV}
priors = TV(undef, 2)
priors[1] ~ InverseGamma(2, 3) # s
priors[2] ~ Normal(0, sqrt(priors[1])) # m
Expand Down Expand Up @@ -548,7 +548,7 @@ using Turing
N = 10
alg = HMC(0.01, 5)
x = randn(1000)
@model function vdemo1(::Type{T}=Float64) where {T}
@model function vdemo1((::Type{T})=Float64) where {T}
x = Vector{T}(undef, N)
for i in 1:N
x[i] ~ Normal(0, sqrt(4))
Expand All @@ -563,7 +563,7 @@ using Turing
vdemo1kw(; T) = vdemo1(T)
sample(StableRNG(seed), vdemo1kw(; T=DynamicPPL.TypeWrap{Float64}()), alg, 10)

@model function vdemo2(::Type{T}=Float64) where {T<:Real}
@model function vdemo2((::Type{T})=Float64) where {T<:Real}
x = Vector{T}(undef, N)
@. x ~ Normal(0, 2)
end
Expand All @@ -574,7 +574,7 @@ using Turing
vdemo2kw(; T) = vdemo2(T)
sample(StableRNG(seed), vdemo2kw(; T=DynamicPPL.TypeWrap{Float64}()), alg, 10)

@model function vdemo3(::Type{TV}=Vector{Float64}) where {TV<:AbstractVector}
@model function vdemo3((::Type{TV})=Vector{Float64}) where {TV<:AbstractVector}
x = TV(undef, N)
@. x ~ InverseGamma(2, 3)
end
Expand Down
101 changes: 61 additions & 40 deletions test/mcmc/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ end
end

# A test model that includes several different kinds of tilde syntax.
@model function test_model(val, ::Type{M}=Vector{Float64}) where {M}
@model function test_model(val, (::Type{M})=Vector{Float64}) where {M}
s ~ Normal(0.1, 0.2)
m ~ Poisson()
val ~ Normal(s, 1)
Expand Down Expand Up @@ -507,47 +507,68 @@ end
sample(model, alg, 100; callback=callback)
end

@testset "dynamic model" begin
@model function imm(y, alpha, ::Type{M}=Vector{Float64}) where {M}
N = length(y)
rpm = DirichletProcess(alpha)

z = zeros(Int, N)
cluster_counts = zeros(Int, N)
fill!(cluster_counts, 0)

for i in 1:N
z[i] ~ ChineseRestaurantProcess(rpm, cluster_counts)
cluster_counts[z[i]] += 1
end

Kmax = findlast(!iszero, cluster_counts)
m = M(undef, Kmax)
for k in 1:Kmax
m[k] ~ Normal(1.0, 1.0)
@testset "dynamic model with analytical posterior" begin
# A dynamic model where b ~ Bernoulli determines the dimensionality
# When b=0: single parameter θ₁
# When b=1: two parameters θ₁, θ₂ where we observe their sum
@model function dynamic_bernoulli_normal(y_obs=2.0)
b ~ Bernoulli(0.3)

if b == 0
θ = Vector{Float64}(undef, 1)
θ[1] ~ Normal(0.0, 1.0)
y_obs ~ Normal(θ[1], 0.5)
else
θ = Vector{Float64}(undef, 2)
θ[1] ~ Normal(0.0, 1.0)
θ[2] ~ Normal(0.0, 1.0)
y_obs ~ Normal(θ[1] + θ[2], 0.5)
end
end
num_zs = 100
num_samples = 10_000
model = imm(Random.randn(num_zs), 1.0)
# https://github.com/TuringLang/Turing.jl/issues/1725
# sample(model, Gibbs(:z => MH(), :m => HMC(0.01, 4)), 100);

# Run the sampler - focus on testing that it works rather than exact convergence
model = dynamic_bernoulli_normal(2.0)
chn = sample(
StableRNG(23), model, Gibbs(:z => PG(10), :m => HMC(0.01, 4)), num_samples
StableRNG(42),
model,
Gibbs(:b => MH(), :θ => HMC(0.1, 10)),
1000;
discard_initial=500,
)
# The number of m variables that have a non-zero value in a sample.
num_ms = count(ismissing.(Array(chn[:, (num_zs + 1):end, 1])); dims=2)
# The below are regression tests. The values we are comparing against are from
# running the above model on the "old" Gibbs sampler that was in place still on
# 2024-11-20. The model was run 5 times with 10_000 samples each time. The values
# to compare to are the mean of those 5 runs, atol is roughly estimated from the
# standard deviation of those 5 runs.
# TODO(mhauru) Could we do something smarter here? Maybe a dynamic model for which
# the posterior is analytically known? Doing 10_000 samples to run the test suite
# is not ideal
# Issue ref: https://github.com/TuringLang/Turing.jl/issues/2402
@test isapprox(mean(num_ms), 8.6087; atol=0.8)
@test isapprox(std(num_ms), 1.8865; atol=0.03)

# Test that sampling completes without error
@test size(chn, 1) == 1000

# Test that both states are explored (basic functionality test)
b_samples = chn[:b]
unique_b_values = unique(skipmissing(b_samples))
@test length(unique_b_values) >= 1 # At least one value should be sampled

# Test that θ[1] values are reasonable when they exist
theta1_samples = collect(skipmissing(chn[:, Symbol("θ[1]"), 1]))
if length(theta1_samples) > 0
@test all(isfinite, theta1_samples) # All samples should be finite
@test std(theta1_samples) > 0.1 # Should show some variation
end

# Test that when b=0, only θ[1] exists, and when b=1, both θ[1] and θ[2] exist
theta2_col_exists = Symbol("θ[2]") in names(chn)
if theta2_col_exists
theta2_samples = chn[:, Symbol("θ[2]"), 1]
# θ[2] should have some missing values (when b=0) and some non-missing (when b=1)
n_missing_theta2 = sum(ismissing.(theta2_samples))
n_present_theta2 = sum(.!ismissing.(theta2_samples))

# At least some θ[2] values should be missing (corresponding to b=0 states)
# This is a basic structural test - we're not testing exact analytical results
@test n_missing_theta2 > 0 || n_present_theta2 > 0 # One of these should be true
end
end

# Helper function for logsumexp
function logsumexp(x)
max_x = maximum(x)
return max_x + log(sum(exp.(x .- max_x)))
end
Comment on lines +568 to 572
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor, but it might be better to use LogExpFunctions.logsumexp here rather than rolling our own version.

Copy link
Member

@yebai yebai Jun 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes


# The below test used to sample incorrectly before
Expand All @@ -574,7 +595,7 @@ end

@testset "dynamic model with dot tilde" begin
@model function dynamic_model_with_dot_tilde(
num_zs=10, ::Type{M}=Vector{Float64}
num_zs=10, (::Type{M})=Vector{Float64}
) where {M}
z = Vector{Int}(undef, num_zs)
z .~ Poisson(1.0)
Expand Down Expand Up @@ -720,7 +741,7 @@ end
struct Wrap{T}
a::T
end
@model function model1(::Type{T}=Float64) where {T}
@model function model1((::Type{T})=Float64) where {T}
x = Vector{T}(undef, 1)
x[1] ~ Normal()
y = Wrap{T}(0.0)
Expand Down
2 changes: 1 addition & 1 deletion test/mcmc/hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ using Turing
end

@testset "(partially) issue: #2095" begin
@model function vector_of_dirichlet(::Type{TV}=Vector{Float64}) where {TV}
@model function vector_of_dirichlet((::Type{TV})=Vector{Float64}) where {TV}
xs = Vector{TV}(undef, 2)
xs[1] ~ Dirichlet(ones(5))
return xs[2] ~ Dirichlet(ones(5))
Expand Down