Skip to content

Commit

Permalink
black formatting & condition pre-saving
Browse files Browse the repository at this point in the history
  • Loading branch information
gengala committed Nov 7, 2024
1 parent 752986d commit fe1d8a6
Showing 1 changed file with 19 additions and 5 deletions.
24 changes: 19 additions & 5 deletions cirkit/backend/torch/layers/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,9 @@ def __init__(
raise ValueError(f"The number of folds and shape of 'probs' must match the layer's")
self.probs = probs
self.logits = logits
self.idx_mode = (
len(torch.unique(self.scope_idx)) > 4096 or self.num_categories > 256
)

def _valid_parameter_shape(self, p: TorchParameter) -> bool:
if p.num_folds != self.num_folds:
Expand Down Expand Up @@ -324,14 +327,25 @@ def log_unnormalized_likelihood(self, x: Tensor) -> Tensor:
if x.is_floating_point():
x = x.long() # The input to Categorical should be discrete
logits = torch.log(self.probs()) if self.logits is None else self.logits()
if len(torch.unique(self.scope_idx)) > 4096 or self.num_categories > 256:
if self.idx_mode:
if self.num_channels == 1:
x = logits[:, :, 0, :].transpose(1, 2)[range(self.num_folds), x[:, 0, :, 0].t()].transpose(0, 1)
x = (

Check warning on line 332 in cirkit/backend/torch/layers/input.py

View check run for this annotation

Codecov / codecov/patch

cirkit/backend/torch/layers/input.py#L332

Added line #L332 was not covered by tests
logits[:, :, 0, :]
.transpose(1, 2)[range(self.num_folds), x[:, 0, :, 0].t()]
.transpose(0, 1)
)
else:
x = x[..., 0].permute(2, 0, 1)
x = logits[
torch.arange(self.num_folds).unsqueeze(1), :, torch.arange(self.num_channels).unsqueeze(0),
x].sum(2).transpose(0, 1)
x = (

Check warning on line 339 in cirkit/backend/torch/layers/input.py

View check run for this annotation

Codecov / codecov/patch

cirkit/backend/torch/layers/input.py#L338-L339

Added lines #L338 - L339 were not covered by tests
logits[
torch.arange(self.num_folds).unsqueeze(1),
:,
torch.arange(self.num_channels).unsqueeze(0),
x,
]
.sum(2)
.transpose(0, 1)
)
else:
x = F.one_hot(x, self.num_categories) # (F, C, B, 1, num_categories)
x = x.squeeze(dim=3) # (F, C, B, num_categories)
Expand Down

0 comments on commit fe1d8a6

Please sign in to comment.