-
Notifications
You must be signed in to change notification settings - Fork 226
Update to the AdvancedVI@0.3 interface #2506
New issue
Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? # to your account
base: main
Are you sure you want to change the base?
Conversation
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
…nto update_advancedvi
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
…nto update_advancedvi
@Red-Portal, can you fix the tests before I take a look? |
@yebai I marked the PR as a draft so that we can first agree on an interface, and then I flesh out the implementation and the tests. Do we wish we proceed in another way? |
Let's address the interface later or in a separate PR since that might require more discussions. For this PR, let's try to keep the VI interface non-breaking where possible. |
…nto update_advancedvi
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@Red-Portal can you take a look at the following error:
|
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2506 +/- ##
===========================================
- Coverage 86.04% 55.13% -30.92%
===========================================
Files 21 21
Lines 1455 1471 +16
===========================================
- Hits 1252 811 -441
- Misses 203 660 +457 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @Red-Portal. I left some comments below.
One high-level comment: I suggest we unify vi_fullrank_gaussian
and vi_meanfield_gaussian
into a single function, q_distribution(...; gaussian=true|false, fullrank=true|false)
, to reduce code redundancy.
Note on CI errors:
- CI compains about missing
ADVI()
andTruncatedADAGrad()
.
|
||
if isfinite(energy) | ||
return scale | ||
elseif n_trial == num_max_trials |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
elseif n_trial == num_max_trials | |
else |
end | ||
end | ||
|
||
function meanfield_gaussian( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
function meanfield_gaussian( | |
function q_meanfield_gaussian( |
return Bijectors.transformed(q, Bijectors.inverse(b)) | ||
end | ||
|
||
function meanfield_gaussian( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
function meanfield_gaussian( | |
function q_meanfield_gaussian( |
return meanfield_gaussian(Random.default_rng(), model; location, scale, kwargs...) | ||
end | ||
|
||
function fullrank_gaussian( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
function fullrank_gaussian( | |
function q_fullrank_gaussian( |
# Use linked `varinfo` to determine the correct number of parameters. | ||
# TODO: Replace with `length` once this is implemented for `VarInfo`. | ||
varinfo_linked = DynamicPPL.link(varinfo, model) | ||
num_params = length(varinfo_linked[:]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we get the dimentionality via num_params = length(varinfo_linked)
instead of length(varinfo_linked[:])
?
cc @mkarikom
end | ||
|
||
# VI algorithms | ||
include("advi.jl") | ||
function fullrank_gaussian( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
function fullrank_gaussian( | |
function q_fullrank_gaussian( |
return reshape_outer ∘ f ∘ reshape_inner | ||
end | ||
|
||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please move Bijectors.bijector(model::DynamicPPL.Model,...)
to DynamicPPL.
cc @mhauru
Wraps a bijector `f` such that it operates on vectors of length `prod(in_size)` and produces | ||
a vector of length `prod(Bijectors.output(f, in_size))`. | ||
""" | ||
function wrap_in_vec_reshape(f, in_size) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since this function is only used once, I suggest we inline it and add comments to explain its behaviour.
Sorry for the delay! I've been traveling in the past weeks, but will start working on this now |
This PR aims to update Turing's
Variational
module to match AdvancedVI's new interface starting fromv0.3
. I will try not to change the interface too much, but given the new features inAdvancedVI
, I think breaking changes will be inevitable. Though the focus will be to provide a good default setting rather than to expose all the features.Currently proposed interface:
Closes #2507
Closes #2508
Closes #2430