
Better parameter learning and inductive types #191

Merged · 232 commits merged into main · Jul 19, 2024
Conversation

@rtjoa (Contributor) commented Jul 19, 2024

TL;DR

We add support for learning arbitrary objectives in terms of probabilistic queries. To find the value of θ that maximizes pr(flip(θ) & flip(θ) & !flip(θ)):

θ = Var("θ")
x = flip(θ) & flip(θ) & !flip(θ)
var_vals = Valuation(θ => 0.5)  # initial assignments
loss = -LogPr(x)
train!(var_vals, loss)  # mutates var_vals
@test compute_mixed(var_vals, θ) ≈ 2/3
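
(As a sanity check: the objective is θ²(1 − θ), whose derivative 2θ − 3θ² vanishes at θ = 2/3, matching the learned value.)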

Macros make it easy to work with probabilistic inductive types:

# Peano-style naturals: Z() is zero, S(n) is the successor of n.
@inductive Nat Z() S(Nat)
function Base.:(+)(x::Nat, y::Nat)
    @match y [
        Z() -> x,
        S(y′) -> S(x) + y′,  # x + S(y′) = S(x) + y′
    ]
end
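
As a quick usage sketch (the exact format of pr's output is an assumption here; see the tours for the canonical interface), addition on concrete Nats composes with prob_equals:

two, one = S(S(Z())), S(Z())
pr(prob_equals(two + one, S(S(S(Z())))))  # expect true => 1.0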

To ensure the documentation stays up to date, tours of the core of Dice.jl and of parameter learning have been added to the tests. I recommend looking at these first to get a sense of the interface.

Better parameter learning

We update autodiff to represent the log-probabilities of Dist{Bool}s symbolically, so that arbitrary loss functions can be trained rather than just MLE. In fact, we support "arbitrary interleavings": computation dependent on log-probabilities can be used as flip parameters, which can create more symbolic log-probabilities, and so on.

The core construct we add is the struct LogPr(::Dist{Bool}) <: ADNode.

  • To compute an ADNode containing a LogPr, use compute_mixed rather than compute.
  • To perform inference on a Dist containing a flip whose probability is dependent on a LogPr, use pr_mixed rather than pr.
  • train!(::Valuation, loss::ADNode; epochs, learning_rate) updates a valuation (a dict from Vars to values) to minimize loss by gradient descent

Examples and tests are given in test/autodiff_pr.
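
For instance, here is a sketch of a non-MLE objective, driving pr(x) toward a target value rather than maximizing it (the exp and ^ overloads on ADNodes are assumptions here):

θ = Var("θ")
x = flip(θ) & flip(θ)
var_vals = Valuation(θ => 0.1)
loss = (exp(LogPr(x)) - 0.5) ^ 2  # assumes exp and ^ are overloaded on ADNodes
train!(var_vals, loss; epochs=1000, learning_rate=0.1)
compute_mixed(var_vals, θ)  # expect ≈ sqrt(0.5), since pr(x) = θ²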

Other improvements

We add the functions sample_as_dist and frombits_as_dist, which work the same as their non-_as_dist counterparts, except they return deterministic Dists instead of Julia primitives (e.g. DistUInt32(3) instead of 3), allowing us to feed the results back into programs (e.g. passing them to prob_equals to check the probability of a particular sample/grounding).
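
A hypothetical round-trip along these lines (the argument conventions and the Dist-level ifelse overload are assumptions; exact signatures may differ):

x = ifelse(flip(0.9), DistUInt32(3), DistUInt32(7))  # assumed Dist-level ifelse
s = sample_as_dist(x)   # a deterministic Dist, e.g. DistUInt32(3)
pr(prob_equals(x, s))   # probability of re-drawing that particular sample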

Future work

Ideally, parameter learning would integrate with a dedicated AD library like Zygote.jl. However, making that play well with CUDD requires care, and we already have our own tiny autodiff framework, so keeping it does not add much complexity to this PR.

@rtjoa rtjoa marked this pull request as ready for review July 19, 2024 03:29
@rtjoa merged commit dcd3406 into main Jul 19, 2024
3 checks passed