Skip to content

Commit

Permalink
Merge pull request #159 from Juice-jl/f-arithmetic-cleanup
Browse files Browse the repository at this point in the history
cleanup and info theory queries
  • Loading branch information
PoorvaGarg authored Dec 26, 2023
2 parents 9a767df + 8580a78 commit 9ceda87
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 33 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
run:
sudo apt update && sudo apt install -y pdf2svg texlive-latex-base texlive-binaries texlive-latex-extra texlive-luatex

- name: Install IRTools master
- name: Instantiate Packages
run: |
julia --project -e 'using Pkg; Pkg.instantiate(); Pkg.build();'
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ DataStructures = "0.18"
DirectedAcyclicGraphs = "0.1.1"
Distributions = "0.25"
Graphs = "1"
IRTools = "0.4.7"
IRTools = "0.4.11"
Jive = "0.2"
MacroTools = "0.5"
PrecompileTools = "1"
Expand Down
29 changes: 1 addition & 28 deletions src/dist/number/uint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ function variance(x::DistUInt{W}; kwargs...) where W
return ans
end


##################################
# methods
##################################
Expand Down Expand Up @@ -422,37 +423,9 @@ end

# Uniform from 0 to hi, exclusive
function unif_half(hi::DistUInt{W})::DistUInt{W} where W
# max_hi = maxvalue(hi)
# max_hi > 60 && error("Likely to time out")
# prod = BigInt(1)
# for prime in primes_at_most(max_hi)
# prod *= prime ^ floor_log(prime, max_hi)
# end

# note: # could use path cond too
prod = lcm([BigInt(x) for x in keys(pr(hi)) if x != 0])
u = uniform(DistUInt{ndigits(prod, base=2)}, 0, prod)
rem_trunc(u, hi)
end

# function primes_at_most(n::Int)
# isprime = [true for _ in 1:n]
# for p in 2:trunc(Int, sqrt(n))
# if isprime[p]
# for i in p^2:p:n
# isprime[i] = false
# end
# end
# end
# [i for i in 2:n if isprime[i]]
# end

# function floor_log(base, n)
# v = 1
# pow = 0
# while v * base <= n
# v *= base
# pow += 1
# end
# pow
# end
31 changes: 29 additions & 2 deletions src/inference/inference.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
export condprobs, condprob, Cudd, CuddDebugInfo, ProbException, allobservations, JointQuery, returnvalue, expectation, variance
export condprobs, condprob, Cudd, CuddDebugInfo, ProbException, allobservations, JointQuery,
returnvalue, expectation, variance, kldivergence, tvdistance, entropy

using DataStructures: DefaultDict, DefaultOrderedDict, OrderedDict

Expand Down Expand Up @@ -44,11 +45,37 @@ function pr(queries...; kwargs...)
for (world, p) in worlds
dist[frombits(query, world)] += p
end
DefaultOrderedDict(0., OrderedDict(sort(collect(dist); by=last, rev=true))) # by decreasing probability
DefaultOrderedDict(0., OrderedDict(sort(collect(dist);
by= t -> (-t[2], t[1])))) # by decreasing probability
end
length(queries) == 1 ? ans[1] : ans
end

"""Compute the entropy of a random variable"""
function entropy(p)
-sum(pr(p)) do (_, prob)
prob * log(prob)
end
end

"""Compute the KL-divergence between two random variables"""
function kldivergence(p, q)
prp = pr(p)
prq = pr(q)
sum(prp) do (value, prob)
prob * (log(prob) - log(prq[value]))
end
end

"""Compute the total variation distance between two random variables"""
function tvdistance(p, q)
prp = pr(p)
prq = pr(q)
0.5 * sum(keys(prp) keys(prq)) do value
abs(prp[value] - prq[value])
end
end

##################################
# Inference with metadata distributions from DSL
##################################
Expand Down
2 changes: 1 addition & 1 deletion src/inference/pr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,4 +91,4 @@ function pr(cudd::Cudd, evidence, queries::Vector{JointQuery}, errors, dots)
end

results
end
end
1 change: 1 addition & 0 deletions src/inference/sample.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
export sample
using DirectedAcyclicGraphs: foldup

"""Run vanilla rejection sampling without any compilation"""
function sample(x; evidence=true)
vals = Dict{ADNode, Real}()
while true
Expand Down

0 comments on commit 9ceda87

Please # to comment.