Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account.

special-case enum like types #189

Open
github-actions bot opened this issue May 14, 2024 · 0 comments
Open

special-case enum like types #189

github-actions bot opened this issue May 14, 2024 · 0 comments
Labels

Comments

@github-actions
Copy link

# TODO: special-case enum like types

                if param  seen && hasmethod(variants, (Type{param},))
                    push!(seen, param)
                    push!(to_visit, param)
                    push!(tys, param)
                end
            end
        end
    end
    reverse!(tys) # top order

    type_ctor_to_id = Dict()
    for ty in tys
        for (ctor, _) in variants(ty)
            type_ctor_to_id[(ty, ctor)] = length(type_ctor_to_id)
        end
    end

    tys, type_ctor_to_id
end

"""
    generate(rs::RunState, p, track_return)

Build one value of type `p.root_ty` by recursively choosing constructors.

For every type returned by `collect_types(p.root_ty)`, a generator closure
`(size, stack_tail) -> value` is registered in `type_to_gen`. Each closure passes
the candidate constructor applications (`"ctor" => ctor(args...)`) to
`frequency_for` under the name `"<prefix><ty>_variant"`, with
`(size, stack_tail)` as the dependents. Each constructor argument is built as
follows:

- a parameter of the type being generated recurses with `size - 1`;
- a parameter of another collected type restarts at `p.ty_sizes[param]`;
- an `AnyBool` parameter is a single `flip_for` decision;
- a `DistUInt32` parameter is the sum, over `n in twopowers(p.intwidth)`, of
  `DistUInt32(n)` or `DistUInt32(0)` chosen by a `flip_for` decision.

Parameters of collected types receive
`update_stack_tail(p, stack_tail, type_ctor_to_id[(ty, ctor)])` as their new
stack tail. When `size == 0`, decision names get a `"0_"` prefix and only
constructors with no parameter of the same type are offered, which ends the
recursion.

`track_return` is accepted for interface compatibility but is not read here.

Throws an `ErrorException` when a constructor parameter has a type other than
the ones listed above.
"""
function generate(rs::RunState, p, track_return)
    tys, type_ctor_to_id = collect_types(p.root_ty)
    type_to_gen = Dict()
    for ty in tys
        type_to_gen[ty] = (size, stack_tail) -> begin
            # Size-0 decisions get their own names so they are weighted separately.
            zero_prefix = if size == 0 "0_" else "" end
            dependents = (size, stack_tail)
            frequency_for(rs, "$(zero_prefix)$(ty)_variant", dependents, [
                "$(ctor)" => ctor([
                    if param == ty
                        # TODO: special-case enum like types
                        # TODO: if recursing, pass values of sibling *enumlikes*
                        type_to_gen[param](
                            size - 1,
                            update_stack_tail(p, stack_tail, type_ctor_to_id[(ty, ctor)])
                        )
                    elseif param in tys
                        # TODO: special-case enum like types
                        type_to_gen[param](
                            p.ty_sizes[param],
                            update_stack_tail(p, stack_tail, type_ctor_to_id[(ty, ctor)])
                        )
                    elseif param == AnyBool
                        flip_for(rs, "$(zero_prefix)$(ty)_$(ctor)_$(i)", dependents)
                    elseif param == DistUInt32
                        # One independent bit decision per power of two.
                        sum(
                            @dice_ite if flip_for(rs, "$(zero_prefix)$(ty)_$(ctor)_$(i)_num$(n)", dependents)
                                DistUInt32(n)
                            else
                                DistUInt32(0)
                            end
                            for n in twopowers(p.intwidth)
                        )
                    else
                        error("generate: unsupported parameter type $(param) in constructor $(ctor) of $(ty)")
                    end
                    for (i, param) in enumerate(params)
                ]...)
                for (ctor, params) in variants(ty)
                # At size 0, drop constructors that would recurse into `ty`.
                if size != 0 || all(param != ty for param in params)
            ])
        end
    end
    return type_to_gen[p.root_ty](p.ty_sizes[p.root_ty], empty_stack(p))
end

# Coq type names used when rendering the primitive distribution types in
# generated Coq code.
function to_coq(::Type{DistUInt32})
    return "nat"
end

function to_coq(::Type{DistInt32})
    return "Z"
end

function to_coq(::Type{AnyBool})
    return "bool"
end

"""
    sandwichjoin(pairs; middle, sep)

Join the fragments of `pairs` into one string, separated by `sep`. The order is:

1. every left fragment, in iteration order;
2. `middle`;
3. every right fragment, in reverse order.

Each element of `pairs` is destructured as `(left, right)`, so `Pair`s and
2-tuples both work. The reversed right halves close the left halves in nested
order. For example, `["a" => "A", "b" => "B"]` with `middle = "m"` and
`sep = ","` yields `"a,b,m,B,A"`.
"""
function sandwichjoin(pairs; middle, sep)
    openers = Any[]
    closers = Any[]
    for pr in pairs
        opener, closer = pr
        push!(openers, opener)
        # Prepending keeps the closers in reverse order without a separate reverse pass.
        pushfirst!(closers, closer)
    end
    return join(vcat(openers, Any[middle], closers), sep)
end

"""
    derived_to_coq(p, adnodes_vals, io)

Render a Coq generator program for `p.root_ty` whose choice weights come from
`adnodes_vals`. The emitted code uses `G`, `freq`, `bindGen` and `returnGen`,
presumably QuickChick combinators.

# Arguments
- `p`: the problem description.
  - Reads `p.root_ty`, `p.ty_sizes`, `p.stack_size` and `p.intwidth`.
  - `workload_of(typeof(p))` selects the surrounding `before`/`after` text via
    `sandwich`.
- `adnodes_vals`: iterated as `(name, val)` pairs.
  - Each `name` is split on `"%%"` into a match id and a case.
  - The case's `%`-separated parts are parsed and evaluated, converted with
    `tocoq`, and joined into a parenthesized Coq pattern.
  - `val` is converted with `thousandths` (presumably a weight out of 1000; verify).
- `io`: not used; the program text is returned rather than written.

# Behavior
- For every type from `collect_types(p.root_ty)`, emits a
  `Fixpoint gen_<ty> (size : nat) <stack vars>` with a `0` branch and an `S size'`
  branch. The `0` branch only offers constructors that have no parameter of the
  same type.
- Constructor choice and each `AnyBool`/`DistUInt32` bit use a weight produced by
  `mk_match`. That is a Coq `match` on `(size, stack vars)` over the learned
  cases, falling back to 500.

# Returns
The full program as a `String`: `before`, the generators, a
`Definition gSized`, then `after`.

# Throws
- `KeyError` if a needed match id has no entry derived from `adnodes_vals`.
- `ErrorException` for a constructor parameter of an unsupported type.

NOTE(review): the `DistUInt32` branch binds `p<i>` only when `j == p.intwidth`.
That assumes `twopowers(p.intwidth)` yields exactly `p.intwidth` values — confirm.
"""
function derived_to_coq(p, adnodes_vals, io)
    # Group learned weights by match id: match id => [(coq_pattern, weight), ...].
    matchid_to_cases = Dict()
    for (name, val) in adnodes_vals
        matchid, case = split(name, "%%")
        # NOTE(review): `eval(Meta.parse(x))` executes text taken from the adnode
        # names; this is only safe while those names are produced internally.
        case = "(" * join([tocoq(eval(Meta.parse(x))) for x in split(case, "%")], ", ") * ")"
        val = thousandths(val)
        push!(get!(matchid_to_cases, matchid, []), (case, val))
    end

    tys, type_ctor_to_id = collect_types(p.root_ty)

    workload = workload_of(typeof(p))
    generators = []

    stack_vars = ["(stack$(i) : nat)" for i in 1:p.stack_size]
    # Emit a Coq `match` returning the learned weight for `matchid`, defaulting to 500.
    function mk_match(matchid)
        cases = matchid_to_cases[matchid]
        cases = sort(cases)
        "match (size, ($(join(stack_vars, ", ")))) with 
$(join([" " ^ 9 * "| $(name) => $(w)" for (name, w) in cases], "\n"))
         | _ => 500
         end"
    end

    # Shift the stack window: drop the oldest stack var and append `loc`.
    update_stack_vars(loc) = join(stack_vars[2:end], " ") * " $(loc)"
    # Constructors offered in a branch; the size-0 branch excludes directly recursive ones.
    variants2(ty, zero_case) = if zero_case
        [
            (ctor, params)
            for (ctor, params) in variants(ty)
            if all(param != ty for param in params) 
        ]
    else
        variants(ty)
    end


    # One Fixpoint per collected type. Everything below is a single string template,
    # so no comments can be placed inside it.
    for ty in tys
        push!(generators, "
Fixpoint gen_$(to_coq(ty)) (size : nat) $(join(stack_vars, " ")) : G $(to_coq(ty)) :=
  match size with
$(join([
"  | $(if zero_case 0 else "S size'" end) => 
    $(if length(variants2(ty, zero_case)) > 1 "freq [" else "" end)
    $(join([
"    (* $(ctor) *)

    $(if length(variants2(ty, zero_case)) > 1
        "(
         $(mk_match("$(if zero_case "0_" else "" end)$(ty)_variant_$(ctor)")),
         " else "" end)
            $(sandwichjoin(
                Iterators.flatten([
                if param == ty
                    ["bindGen (gen_$(to_coq(param)) size' $(
                        update_stack_vars(type_ctor_to_id[(ty, ctor)])
                    )) (fun p$(i) : $(to_coq(param)) =>" => ")"]
                elseif param in tys
                    ["bindGen (gen_$(to_coq(param)) $(p.ty_sizes[param]) $(
                        update_stack_vars(type_ctor_to_id[(ty, ctor)])
                    )) (fun p$(i) : $(to_coq(param)) =>" => ")"]
                elseif param == AnyBool
                    ["let weight_true := $(mk_match("$(if zero_case "0_" else "" end)$(ty)_$(ctor)_$(i)")) in
                    bindGen (freq [
                        (weight_true, true);
                        (1000-weight_true, false)
                    ]) (fun p$(i) : $(to_coq(param)) =>" => ")"]
                elseif param == DistUInt32
                    [
                        "let weight_$(n) := $(mk_match("$(if zero_case "0_" else "" end)$(ty)_$(ctor)_$(i)_num$(n)")) in
                        bindGen (freq [
                            (weight_$(n), returnGen $(n));
                            (1000-weight_$(n), returnGen 0)
                        ])
                        (fun n$(n) : nat => $(if j == p.intwidth "
                        let p$(i) := $(join(["n$(n)" for n in twopowers(p.intwidth)], "+ ")) in " else "" end)
                        " => ")"
                        for (j, n) in enumerate(twopowers(p.intwidth))
                    ]
                else
                    error("derived_to_coq: unsupported parameter type $(param) in constructor $(ctor) of $(ty)")
                end
                for (i, param) in enumerate(params)
                ]),
            middle="returnGen ($(ctor) $(join(["p$(i)" for i in 1:length(params)], " ")))",
            sep="\n"))
    $(if length(variants2(ty, zero_case)) > 1 ")" else "" end)
        "
        for (ctor, params) in variants2(ty, zero_case)
    ], ";\n"))
    $(if length(variants2(ty, zero_case)) > 1 "]" else "" end)"
    for zero_case in [true, false]
  ], "\n" ))
    end.")
    end

    before, after = sandwich(workload)
    return "$(before)
    $(join(generators, "\n"))

Definition gSized :=
  gen_$(to_coq(p.root_ty)) $(p.ty_sizes[p.root_ty])$(" 0" ^ p.stack_size).

    $(after)"
end
@github-actions github-actions bot added the todo label May 14, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
Labels
Projects
None yet
Development

No branches or pull requests

0 participants