Fix unconditional sampling and add unit test #332

Merged 1 commit on Dec 17, 2024
12 changes: 8 additions & 4 deletions cirkit/backend/torch/circuits.py
@@ -48,15 +48,19 @@ def lookup(
             # That is, we are gathering the inputs of input layers
             layer = entry.module
             assert isinstance(layer, TorchInputLayer)
-            if not layer.num_variables:
-                # Pass the wanted batch dimension to constant layers
-                yield layer, (1 if in_graph is None else in_graph.shape[0],)
-            else:
+            if layer.num_variables:
+                if in_graph is None:
+                    yield layer, ()
+                    continue
                 # in_graph: An input batch (assignments to variables) of shape (B, C, D)
                 # scope_idx: The scope of the layers in each fold, a tensor of shape (F, D'), D' < D
                 # x: (B, C, D) -> (B, C, F, D') -> (F, C, B, D')
                 x = in_graph[..., layer.scope_idx].permute(2, 1, 0, 3)
                 yield layer, (x,)
+                continue
+
+            # Pass the wanted batch dimension to constant layers
+            yield layer, (1 if in_graph is None else in_graph.shape[0],)

     @classmethod
     def from_index_info(
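
For context on the rewritten branch: input layers with variables receive their inputs by gathering each fold's scope from the input batch and reordering the axes, while with `in_graph is None` (the unconditional sampling case) they now receive no inputs instead of falling through to the constant-layer case. A minimal standalone sketch of the gather-and-permute step, with made-up sizes and `scope_idx` standing in for `layer.scope_idx`:

```python
import torch

# Hypothetical sizes: batch B, channels C, variables D, folds F, per-fold scope D'
B, C, D, F, Dp = 5, 1, 4, 3, 2
in_graph = torch.randn(B, C, D)            # an input batch of assignments
scope_idx = torch.randint(0, D, (F, Dp))   # stands in for layer.scope_idx

x = in_graph[..., scope_idx]               # (B, C, D) -> (B, C, F, D')
x = x.permute(2, 1, 0, 3)                  # -> (F, C, B, D')
assert x.shape == (F, C, B, Dp)
```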
3 changes: 1 addition & 2 deletions tests/backend/torch/test_compile_circuit.py
@@ -73,7 +73,7 @@ def check_continuous_ground_truth(
     assert isclose(int_tc(), semiring.map_from(torch.tensor(gt_partition_func), SumProductSemiring))
     df = lambda y, x: torch.exp(tc(torch.Tensor([[[x, y]]]))).squeeze()
     int_a, int_b = -np.inf, np.inf
-    ig, err = integrate.dblquad(df, int_a, int_b, int_a, int_b)
+    ig, err = integrate.dblquad(df, int_a, int_b, int_a, int_b, epsabs=1e-5, epsrel=1e-5)
     assert isclose(ig, gt_partition_func)


@@ -119,7 +119,6 @@ def test_compile_monotonic_structured_categorical_pc(fold: bool, optimize: bool,
     check_discrete_ground_truth(tc, int_tc, compiler.semiring, gt_outputs["evi"], gt_partition_func)
 
 
-@pytest.mark.slow
 def test_compile_monotonic_structured_gaussian_pc():
     compiler = TorchCompiler(fold=True, optimize=True, semiring="lse-sum")
     sc, gt_outputs, gt_partition_func = build_monotonic_bivariate_gaussian_hadamard_dense_pc(
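
The tolerance arguments added to `integrate.dblquad` here and in test_compile_circuit_operators.py below loosen SciPy's defaults (about 1.49e-8) so the adaptive quadrature terminates earlier; this is presumably what makes the Gaussian tests fast enough to drop their `@pytest.mark.slow` markers. A minimal sketch of the same pattern on a known integrand, using only NumPy and SciPy:

```python
import numpy as np
from scipy import integrate

# Integrate a standard bivariate Gaussian density over R^2; the exact value is 1.
df = lambda y, x: np.exp(-0.5 * (x**2 + y**2)) / (2.0 * np.pi)
ig, err = integrate.dblquad(df, -np.inf, np.inf, -np.inf, np.inf, epsabs=1e-5, epsrel=1e-5)
assert np.isclose(ig, 1.0, atol=1e-4)
```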
7 changes: 3 additions & 4 deletions tests/backend/torch/test_compile_circuit_operators.py
@@ -105,15 +105,14 @@ def test_compile_product_integrate_pc_categorical(
     assert allclose(compiler.semiring.prod(each_tc_scores, dim=0), scores)
 
 
-@pytest.mark.slow
 def test_compile_product_integrate_pc_gaussian():
     compiler = TorchCompiler(semiring="lse-sum", fold=True, optimize=True)
     scs, tcs = [], []
     last_sc = None
     num_products = 3
     for i in range(num_products):
         sci = build_bivariate_monotonic_structured_cpt_pc(
-            num_units=2 + i, input_layer="gaussian", normalized=False
+            num_units=1 + i, input_layer="gaussian", normalized=False
         )
         tci = compiler.compile(sci)
         scs.append(sci)
@@ -138,8 +137,8 @@ def test_compile_product_integrate_pc_gaussian():
     z = z.squeeze()
     df = lambda y, x: torch.exp(tc(torch.Tensor([[[x, y]]]))).squeeze()
     int_a, int_b = -np.inf, np.inf
-    ig, err = integrate.dblquad(df, int_a, int_b, int_a, int_b)
-    assert np.isclose(ig, torch.exp(z).item(), atol=1e-15)
+    ig, err = integrate.dblquad(df, int_a, int_b, int_a, int_b, epsabs=1e-5, epsrel=1e-5)
+    assert isclose(ig, torch.exp(z).item())
 
 
 @pytest.mark.parametrize(
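
The assertion at the top of this diff checks that the semiring product of the individual circuit scores matches the score of the compiled product circuit. A minimal sketch of the assumed `lse-sum` semantics (not the cirkit API): scores live in log-space, so the semiring product is plain addition:

```python
import torch

# One log-score per factor circuit for a single input (hypothetical values)
log_scores = torch.tensor([[-1.2], [-0.7], [-2.3]])
log_product = log_scores.sum(dim=0)   # semiring "prod" in log-space is a sum
assert torch.allclose(log_product, torch.tensor([-4.2]))
```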
1 change: 0 additions & 1 deletion tests/backend/torch/test_compile_marginalization.py
@@ -63,7 +63,6 @@ def test_marginalize_monotonic_pc_categorical(semiring: str, fold: bool, optimiz
), f"Input: {x}"


@pytest.mark.slow
def test_marginalize_monotonic_pc_gaussian():
compiler = TorchCompiler(fold=True, optimize=True, semiring="lse-sum")
sc, gt_outputs, gt_partition_func = build_monotonic_bivariate_gaussian_hadamard_dense_pc(
Expand Down
50 changes: 50 additions & 0 deletions tests/backend/torch/test_queries/test_sampling.py
@@ -0,0 +1,50 @@
+import itertools
+
+import pytest
+import torch
+
+from cirkit.backend.torch.circuits import TorchCircuit
+from cirkit.backend.torch.compiler import TorchCompiler
+from cirkit.backend.torch.queries import SamplingQuery
+from tests.floats import allclose
+from tests.symbolic.test_utils import build_multivariate_monotonic_structured_cpt_pc
+
+
+@pytest.mark.parametrize(
+    "fold,optimize",
+    itertools.product([False, True], [False, True]),
+)
+def test_query_unconditional_sampling(fold: bool, optimize: bool):
+    compiler = TorchCompiler(semiring="lse-sum", fold=fold, optimize=optimize)
+    sc = build_multivariate_monotonic_structured_cpt_pc(
+        num_units=2, input_layer="bernoulli", parameterize=True, normalized=True
+    )
+    tc: TorchCircuit = compiler.compile(sc)
+
+    # Compute the probabilities of all possible worlds
+    worlds = torch.tensor(list(itertools.product([0, 1], repeat=tc.num_variables))).unsqueeze(
+        dim=-2
+    )
+    assert worlds.shape == (2**tc.num_variables, 1, tc.num_variables)
+    tc_outputs = tc(worlds)
+    assert tc_outputs.shape == (2**tc.num_variables, 1, 1)
+    assert torch.all(torch.isfinite(tc_outputs))
+    probs = torch.exp(tc_outputs)
+    probs = probs.squeeze(dim=2).squeeze(dim=1)
+
+    # Sample data points unconditionally
+    num_samples = 1_000_000
+    query = SamplingQuery(tc)
+    # samples: (num_samples, C, D)
+    samples, _ = query(num_samples=num_samples)
+    assert samples.shape == (num_samples, 1, tc.num_variables)
+    samples = samples.squeeze(dim=1)
+
+    # Map samples to indices of the probabilities computed above
+    samples_idx = samples * torch.tensor(list(reversed([2**i for i in range(tc.num_variables)])))
+    samples_idx = torch.sum(samples_idx, dim=-1)
+
+    # Compute empirical frequencies and compare with the probabilities
+    _, counts = torch.unique(samples_idx, return_counts=True)
+    ratios = counts / num_samples
+    assert allclose(ratios, probs, atol=1e-3)
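
A small worked example (not part of the PR) of the index-mapping step above: with most-significant-variable-first weights, each binary sample is encoded as the row index of the matching world in the enumeration built from `itertools.product`:

```python
import torch

samples = torch.tensor([[0, 1, 1], [1, 0, 0]])   # two samples over D = 3 variables
weights = torch.tensor([4, 2, 1])                # 2**(D-1), ..., 2**0
idx = (samples * weights).sum(dim=-1)
assert idx.tolist() == [3, 4]                    # rows 3 and 4 of the enumerated worlds
```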