-
Notifications
You must be signed in to change notification settings - Fork 226
AdvancedPS v0.7 (and thus Libtask v0.9) support #2585
New issue
Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? # to your account
base: main
Are you sure you want to change the base?
Conversation
The tests that I had the patience to run locally now pass. Waiting for the AdvancedPS release to be able to run the full test suite on CI. Some indicators of speed: julia> module MWE
using Turing
@model function gdemo(x, y)
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
x ~ Normal(m, sqrt(s))
y ~ Normal(m, sqrt(s))
return s, m
end
@time chn = sample(gdemo(2.5, 1.0), PG(10), 10_000)
describe(chn)
end On main:
On this branch:
julia> module MWE
using Turing
@model function f(dim=20, ::Type{T}=Float64) where T
s = Vector{Bool}(undef, dim)
x = Vector{T}(undef, dim)
for i in 1:dim
s[i] ~ Bernoulli()
if s[i]
x[i] ~ Normal()
else
x[i] ~ Beta()
end
0.0 ~ Normal(x[i])
end
return nothing
end
alg = Gibbs(
@varname(s)=>PG(10),
@varname(x)=>HMC(0.1, 5),
)
@time chn = sample(f(), alg, 1_000)
end On main:
On this branch:
Obviously the speed gains are all due to @willtebbutt's fantastic work on Libtask, everything else is just wrapping that work. |
Turing.jl documentation for PR #2585 is available at: |
@@ -85,7 +85,7 @@ Statistics = "1.6" | |||
StatsAPI = "1.6" | |||
StatsBase = "0.32, 0.33, 0.34" | |||
StatsFuns = "0.8, 0.9, 1" | |||
julia = "1.10.2" | |||
julia = "1.10.8" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Libtask requires 1.10.8 at a minimum.
@@ -402,11 +391,11 @@ end | |||
|
|||
function trace_local_varinfo_maybe(varinfo) | |||
try | |||
trace = AdvancedPS.current_trace() | |||
return trace.model.f.varinfo | |||
trace = Libtask.get_taped_globals(Any).other |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we change Libtask.get_taped_globals
to return nothing
if not inside a running TapedTask
, the following try .. catch ... end
can be removed.
@@ -416,11 +405,10 @@ end | |||
|
|||
function trace_local_rng_maybe(rng::Random.AbstractRNG) | |||
try | |||
trace = AdvancedPS.current_trace() | |||
return trace.rng | |||
return Libtask.get_taped_globals(Any).rng |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same with above.
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2585 +/- ##
===========================================
- Coverage 85.57% 50.44% -35.13%
===========================================
Files 22 22
Lines 1456 1447 -9
===========================================
- Hits 1246 730 -516
- Misses 210 717 +507 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
Pull Request Test Coverage Report for Build 15835391573Details
💛 - Coveralls |
Is this reviewable? The tests are failing, there's a method ambiguity that Aqua complains about, there's a Gibbs failure on 1.12 which should be disabled with
I don't want to speak for @mhauru in his absence but last time we spoke about this PR, it was clear that there were still a few gaps to bridge. If I were to review it at this stage, my sole comment would be to fix the tests. |
The complement PR of TuringLang/AdvancedPS.jl#114, which adds support for the newly rewritten Libtask.
Work in progress, currently blocked by TuringLang/Libtask.jl#186