use DistFix #228

Open
github-actions bot opened this issue Dec 25, 2024 · 0 comments

using Revise
using Dice
using Random

# Finding accuracy errors in an approximate inverse square root function

# Current output of this program (nondeterministic):
# 15.127460 seconds (141.17 M allocations: 6.819 GiB, 6.44% gc time, 19.28% compilation time)
# Initial p(bugfound) = 0.011
# Trained p(bugfound) = 0.99

################################################################################
# HYPERPARAMETERS AND CONFIG
################################################################################

# How many steps of Newton's method to take. Increase to make the approximate
# function more accurate.
STEPS = 5

# Precision (number of bits) of the fixed-point number generator. Increase to
# generate numbers closer to 0; the smallest nonzero value is 2^-PREC.
PREC = 10

# We consider a bug to be found if the relative error exceeds BUG_THRESH.
# Note: relative error is measured between approx_inv_sqrt(x)^2 and 1/x, which
# avoids needing a preexisting implementation of sqrt.
# Maybe this helps the story, but it also makes "relative error" a bit harder to
# explain, so maybe we should just use `sqrt` and measure error between
# approx_inv_sqrt(x) and 1/sqrt(x).
BUG_THRESH = 0.01

# Training
NUM_EPOCHS = 100
NUM_SAMPLES = 1000
LR = 0.03

################################################################################
# DistZeroToOne
################################################################################

# Probabilistic value in [0, 1): a W-bit mantissa m represents m / 2^W.
# TODO: use DistFix
import Dice: tobits, frombits, prob_equals
struct DistZeroToOne <: Dist{Any}
    mantissa::DistUInt
end
tobits(x::DistZeroToOne) = tobits(x.mantissa)
frombits(x::DistZeroToOne, world) =
    float(frombits(x.mantissa, world)) / 2^float(length(x.mantissa.bits))
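# Constant with mantissa x * 2^W. x must be an exact multiple of 2^-W (e.g. a
# value returned by frombits); otherwise Int throws an InexactError.
DistZeroToOne(x::Float64, W) = DistZeroToOne(DistUInt{W}(Int(x * 2^W)))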
prob_equals(x::DistZeroToOne, y::DistZeroToOne) =
    prob_equals(x.mantissa, y.mantissa)


################################################################################
# Approximate inverse square root and its error
################################################################################

# Because gradients are estimated with REINFORCE, these functions are only
# evaluated on concrete sampled values, so none of them needs to be implemented
# in Dice.

# Approximate sqrt(x) with `steps` iterations of Newton's method, starting from
# an initial guess of 1.
function approx_sqrt(x, steps)
    guess = 1.0
    for _ in 1:steps
        guess = 1/2 * (guess + x/guess)
    end
    guess
end

approx_inv_sqrt(x, steps) = 1/approx_sqrt(x, steps)

rel_error(actual, expected) = abs((actual - expected) / expected)

# Note we don't need an "oracle" (an existing `sqrt` function) to target error!
# We instead compare approx_inv_sqrt(x)^2 to 1/x.
# Maybe this helps the story, but it also makes "relative error" a bit harder to
# explain, so maybe we should just use `sqrt`.
# For x = 0, 1/x is Inf and the result is NaN; the training loop skips such samples.
approx_inv_sqrt_error(x, steps) = rel_error(approx_inv_sqrt(x, steps)^2, 1/x)

################################################################################
# Generator
################################################################################

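# Trainable parameters: each weight is sigmoid(v) for a logit variable v that
# starts at 0, so every weight starts at 0.5.
var_vals = Valuation()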
adnodes_of_interest = Dict{String, ADNode}()
function register_weight!(s)
    var = Var("$(s)_before_sigmoid")
    var_vals[var] = 0
    weight = sigmoid(var)
    adnodes_of_interest[s] = weight
    weight
end

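# Generator: each of the PREC mantissa bits is an independent flip with its own
# trainable weight. With the initial weights of 0.5, g is uniform over the
# multiples of 2^-PREC in [0, 1).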
g = DistZeroToOne(DistUInt{PREC}([
    flip(register_weight!("x_$(i)"))
    for i in 1:PREC
]))

################################################################################
# Train to maximize error
################################################################################

history_err = []
history_p_bugfound = []

@time for epoch in 1:NUM_EPOCHS
    # Draw NUM_SAMPLES concrete values from g under the current weights.
    samples = Dice.with_concrete_ad_flips(var_vals, g) do
        [sample(Random.default_rng(), g) for _ in 1:NUM_SAMPLES]
    end

    l = Dice.LogPrExpander(WMC(BDDCompiler([
        prob_equals(g, DistZeroToOne(sample, PREC))
        for sample in samples
    ])))
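    # REINFORCE surrogate loss: minimizing the sum of -log Pr(g == sample) * error
    # ascends the expected error. Samples with NaN error (x = 0) contribute nothing.
    loss, total_dist, valid_samples, samples_finding_bug = sum(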
        begin
            lpr_eq = Dice.expand_logprs(l, LogPr(prob_equals(g, DistZeroToOne(sample, PREC))))
            dist = approx_inv_sqrt_error(sample, STEPS)
            if isnan(dist)
                [Dice.Constant(0), 0, 0, 0]
            else
                [-lpr_eq * dist, dist, 1, if dist > BUG_THRESH 1 else 0 end]
            end
        end
        for sample in samples
    )
    push!(history_err, total_dist / valid_samples)
    push!(history_p_bugfound, samples_finding_bug / NUM_SAMPLES)

    vals, derivs = differentiate(
        var_vals,
        Derivs(loss => 1)
    )
    # Gradient descent on the logits, i.e. ascent on the expected error.
    for (adnode, d) in derivs
        if adnode isa Var
            var_vals[adnode] -= d * LR
        end
    end
end

println("Initial p(bugfound) = $(first(history_p_bugfound))")
println("Trained p(bugfound) = $(last(history_p_bugfound))")
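The comments above weigh two error measures: the oracle-free one the script uses, rel_error(approx_inv_sqrt(x)^2, 1/x), and the alternative rel_error(approx_inv_sqrt(x), 1/sqrt(x)). Writing s = approx_sqrt(x) and r = sqrt(x)/s, the first equals |r^2 - 1| and the second |r - 1|, so for small errors the oracle-free measure is roughly twice the oracle-based one, and BUG_THRESH = 0.01 on the squared measure corresponds to about 0.5% error in the inverse square root. The stand-alone sketch below copies the script's helper definitions (with a Float64 initial guess) and adds a hypothetical `oracle_error` helper purely for comparison.

```julia
# Sketch: compare the script's oracle-free error measure with an oracle-based one.
# `oracle_error` is a hypothetical helper for illustration, not part of the script.

# Same definitions as in the script.
function approx_sqrt(x, steps)
    guess = 1.0
    for _ in 1:steps
        guess = 1/2 * (guess + x/guess)   # Newton step for guess^2 - x = 0
    end
    guess
end
approx_inv_sqrt(x, steps) = 1/approx_sqrt(x, steps)
rel_error(actual, expected) = abs((actual - expected) / expected)

# Oracle-free: compares approx_inv_sqrt(x)^2 with 1/x (NaN when x == 0).
approx_inv_sqrt_error(x, steps) = rel_error(approx_inv_sqrt(x, steps)^2, 1/x)

# Oracle-based alternative: compares approx_inv_sqrt(x) with 1/sqrt(x).
oracle_error(x, steps) = rel_error(approx_inv_sqrt(x, steps), 1/sqrt(x))

for x in (1/1024, 1/256, 1/64, 0.5)
    println("x = $x: squared measure = $(approx_inv_sqrt_error(x, 5)), ",
            "oracle measure = $(oracle_error(x, 5))")
end
```

While the guess is much larger than sqrt(x), a Newton step from an initial guess of 1 roughly halves it, so five steps are not enough for the smallest values the generator can produce; that is where errors above BUG_THRESH come from.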
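DistZeroToOne, which the TODO proposes replacing with DistFix, treats its DistUInt as a fixed-point mantissa: `frombits` decodes a W-bit value m as m / 2^W, and the Float64 constructor inverts that with Int(x * 2^W). The plain-Julia sketch below reproduces only that arithmetic on a vector of Bool bits, to show which values are representable. The most-significant-bit-first ordering and the helper names `decode` and `encode` are assumptions of this sketch, not Dice's API.

```julia
# Sketch of the fixed-point arithmetic behind DistZeroToOne, without Dice.
# Assumptions: bits are stored most-significant first, and `decode`/`encode`
# are hypothetical helpers, not part of Dice.

# A W-bit mantissa m, given as Bool bits, represents m / 2^W.
function decode(bits::Vector{Bool})
    m = 0
    for b in bits
        m = 2m + b    # accumulate the unsigned integer, most-significant bit first
    end
    m / 2.0^length(bits)
end

# Inverse of decode. As in DistZeroToOne(x, W), Int(x * 2^W) throws an
# InexactError unless x is an exact multiple of 2^-W.
function encode(x::Float64, W::Int)
    m = Int(x * 2^W)
    [((m >> (W - i)) & 1) == 1 for i in 1:W]
end

W = 10
println(decode(encode(0.75, W)))    # prints 0.75
println(decode(fill(true, W)))      # largest value: (2^W - 1) / 2^W
```

Every value the generator can produce is therefore an exact multiple of 2^-PREC, which is why the script can safely rebuild a constant from each sample with DistZeroToOne(sample, PREC).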
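The training loop is a REINFORCE (score-function) estimator: minimizing the sum of -log Pr(g == sample) * error moves the weights along an estimate of the gradient of the expected error. For independent bits with P(b_i = 1) = sigmoid(θ_i), the score is d/dθ_i log p(b) = b_i - sigmoid(θ_i), so the estimate can be written without Dice. The sketch below is a plain-Julia approximation of that computation, assuming the generator's bits are independent flips as in `g`; the names `logistic`, `decode_bits`, and `reinforce_grad` are hypothetical. It checks the estimator against the exact gradient of E[x] for a 3-bit generator.

```julia
using Random

# Sketch of a REINFORCE (score-function) gradient for a generator of
# independent bits, written without Dice. Assumptions: bit i is 1 with
# probability logistic(θ[i]); bits are most-significant first; all function
# names here are hypothetical.

logistic(t) = 1 / (1 + exp(-t))

# Value in [0, 1) encoded by the bits: sum of bits[i] * 2^-i.
decode_bits(bits) = sum(b * 2.0^(-i) for (i, b) in enumerate(bits))

# Monte Carlo estimate of d/dθ E[f(x)] = E[f(x) * d/dθ log p(bits)], using
# d/dθ[i] log p(bits) = bits[i] - logistic(θ[i]) for independent flips.
function reinforce_grad(f, θ, n, rng)
    p = logistic.(θ)
    grad = zeros(length(θ))
    for _ in 1:n
        bits = [rand(rng) < p_i for p_i in p]
        r = f(decode_bits(bits))
        isnan(r) && continue    # like the script, NaN rewards contribute nothing
        grad .+= r .* (bits .- p)
    end
    grad ./ n
end

# Check against an exact gradient: with f(x) = x, E[x] = sum of logistic(θ[i]) * 2^-i,
# so d/dθ[i] E[x] = 2^-i * p_i * (1 - p_i), which is 2^-i / 4 at θ = 0.
rng = MersenneTwister(0)
θ = zeros(3)
est = reinforce_grad(identity, θ, 200_000, rng)
exact = [2.0^(-i) / 4 for i in 1:3]
println("estimate: ", est, "  exact: ", exact)
```

The script's loss is a sum over samples rather than a mean, so, assuming Dice differentiates the exact log-probability of each sample, each update there corresponds to LR times NUM_SAMPLES times this estimate; passing `x -> approx_inv_sqrt_error(x, STEPS)` as `f` targets the same error the script maximizes.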
@github-actions github-actions bot added the todo label Dec 25, 2024