Skip to content

Ah/gibbs closed form cond #2597

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft

Conversation

AoifeHughes
Copy link

Addresses -> #2547

Summary

This PR implements analytical conditional sampling support for the GibbsConditional sampler, restoring and enhancing functionality that was mentioned in the HISTORY.md. The new interface allows users to specify analytical conditional distributions as functions,
enabling more efficient Gibbs sampling when conditional distributions are known in closed form.

New Interface

GibbsConditional(sym, conditional)

Where:

  • sym: A variable name (Symbol or VarName) to be sampled
  • conditional: A function that takes a NamedTuple of conditioned variables and returns a Distribution

Example Usage

  # Define conditional functions
  function cond_m(c)
      λ_n = c.λ * (N + 1)
      σ_n = sqrt(1 / λ_n)
      return Normal(m_n, σ_n)
  end

  function cond_λ(c)
      α_n = α_0 + (N - 1) / 2 + 1
      β_n = s2 * N / 2 + c.m^2 / 2 + inv(θ_0)
      return Gamma(α_n, inv(β_n))
  end

  # Use in Gibbs sampler
  sampler = Gibbs(
      GibbsConditional(:λ, cond_λ),
      GibbsConditional(:m, cond_m)
  )

  chain = sample(model, sampler, 1000)

Key Features

  • Backward Compatible: Existing GibbsConditional(sampler, varnames) interface unchanged
  • Mixed Usage: Can be combined with regular sampler-based GibbsConditional and Pair syntax
  • Type Safe: Validates conditional functions and provides clear error messages
  • Efficient: Direct sampling from analytical distributions without MCMC overhead

Implementation Details

Files Modified

  • src/mcmc/gibbs.jl: Core implementation with extended struct and sampling logic
  • src/mcmc/Inference.jl: Added exports for new helper functions
  • src/Turing.jl: Added exports to main module
  • test/mcmc/gibbs.jl: Comprehensive test suite including the inverse Gamma-Normal example

New Types and Functions

  • Extended GibbsConditional{S,V,C} struct with conditional function field
  • AnalyticalConditionalState for proper state management
  • sample_analytical_conditional() for core sampling logic
  • is_analytical_conditional() and is_sampler_conditional() helper functions

@AoifeHughes AoifeHughes self-assigned this Jun 23, 2025
::ExternalSampler{<:Any,<:Any,Unconstrained}
) where {Unconstrained} = Unconstrained
) where {Unconstrained}
Unconstrained
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter v1.0.62] reported by reviewdog 🐶

Suggested change
Unconstrained
return Unconstrained


# Add wrap_in_sampler method for analytical GibbsConditional
function wrap_in_sampler(gc::GibbsConditional)
is_analytical_conditional(gc) ? gc : wrap_in_sampler(gc.sampler)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter v1.0.62] reported by reviewdog 🐶

Suggested change
is_analytical_conditional(gc) ? gc : wrap_in_sampler(gc.sampler)
return is_analytical_conditional(gc) ? gc : wrap_in_sampler(gc.sampler)

Comment on lines +144 to +146
unwrap_sampler(sampler::DynamicPPL.Sampler{<:AlgWrapper}) = DynamicPPL.Sampler(
sampler.alg.inner
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter v1.0.62] reported by reviewdog 🐶

Suggested change
unwrap_sampler(sampler::DynamicPPL.Sampler{<:AlgWrapper}) = DynamicPPL.Sampler(
sampler.alg.inner
)
unwrap_sampler(sampler::DynamicPPL.Sampler{<:AlgWrapper}) =
DynamicPPL.Sampler(sampler.alg.inner)


@model function simple_model()
x ~ Normal(0, 1)
y ~ Normal(x, 1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter v1.0.62] reported by reviewdog 🐶

Suggested change
y ~ Normal(x, 1)
return y ~ Normal(x, 1)

Comment on lines +110 to +111
wo ~
MvNormal([0.0; 0; 0], [var_prior 0.0 0.0; 0.0 var_prior 0.0; 0.0 0.0 var_prior])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter v1.0.62] reported by reviewdog 🐶

Suggested change
wo ~
MvNormal([0.0; 0; 0], [var_prior 0.0 0.0; 0.0 var_prior 0.0; 0.0 0.0 var_prior])
wo ~ MvNormal(
[0.0; 0; 0], [var_prior 0.0 0.0; 0.0 var_prior 0.0; 0.0 0.0 var_prior]
)

Comment on lines +36 to +38
DynamicPPL.setchildcontext(parent::OverrideContext, child) = OverrideContext(
child, parent.logprior_weight, parent.loglikelihood_weight
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter v1.0.62] reported by reviewdog 🐶

Suggested change
DynamicPPL.setchildcontext(parent::OverrideContext, child) = OverrideContext(
child, parent.logprior_weight, parent.loglikelihood_weight
)
DynamicPPL.setchildcontext(parent::OverrideContext, child) =
OverrideContext(child, parent.logprior_weight, parent.loglikelihood_weight)

Comment on lines +219 to +221
check_success(result, check_retcode=true) = check_optimisation_result(
result, true_value, true_logp, check_retcode
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter v1.0.62] reported by reviewdog 🐶

Suggested change
check_success(result, check_retcode=true) = check_optimisation_result(
result, true_value, true_logp, check_retcode
)
check_success(result, check_retcode=true) =
check_optimisation_result(result, true_value, true_logp, check_retcode)

Comment on lines +280 to +282
check_success(result, check_retcode=true) = check_optimisation_result(
result, true_value, true_logp, check_retcode
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter v1.0.62] reported by reviewdog 🐶

Suggested change
check_success(result, check_retcode=true) = check_optimisation_result(
result, true_value, true_logp, check_retcode
)
check_success(result, check_retcode=true) =
check_optimisation_result(result, true_value, true_logp, check_retcode)

Comment on lines +342 to +344
check_success(result, check_retcode=true) = check_optimisation_result(
result, true_value, true_logp, check_retcode
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter v1.0.62] reported by reviewdog 🐶

Suggested change
check_success(result, check_retcode=true) = check_optimisation_result(
result, true_value, true_logp, check_retcode
)
check_success(result, check_retcode=true) =
check_optimisation_result(result, true_value, true_logp, check_retcode)

Comment on lines +396 to +398
check_success(result, check_retcode=true) = check_optimisation_result(
result, true_value, true_logp, check_retcode
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter v1.0.62] reported by reviewdog 🐶

Suggested change
check_success(result, check_retcode=true) = check_optimisation_result(
result, true_value, true_logp, check_retcode
)
check_success(result, check_retcode=true) =
check_optimisation_result(result, true_value, true_logp, check_retcode)

Copy link
Contributor

Turing.jl documentation for PR #2597 is available at:
https://TuringLang.github.io/Turing.jl/previews/PR2597/

# for free to join this conversation on GitHub. Already have an account? # to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant