diff --git a/se3_transformer_pytorch/se3_transformer_pytorch.py b/se3_transformer_pytorch/se3_transformer_pytorch.py index 4bc59c7..06eccca 100644 --- a/se3_transformer_pytorch/se3_transformer_pytorch.py +++ b/se3_transformer_pytorch/se3_transformer_pytorch.py @@ -517,7 +517,7 @@ def forward(self, features, edge_info, rel_dist, basis, global_feats = None, pos if exists(neighbor_mask): num_left_pad = sim.shape[-1] - neighbor_mask.shape[-1] mask = F.pad(neighbor_mask, (num_left_pad, 0), value = True) - sim.masked_fill_(~mask, -torch.finfo(sim.dtype).max) + sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) attn = sim.softmax(dim = -1) out = einsum('b h i j, b h i j d m -> b h i d m', attn, v) @@ -652,7 +652,7 @@ def forward(self, features, edge_info, rel_dist, basis, global_feats = None, pos if exists(neighbor_mask): num_left_pad = sim.shape[-1] - neighbor_mask.shape[-1] mask = F.pad(neighbor_mask, (num_left_pad, 0), value = True) - sim.masked_fill_(~mask, -torch.finfo(sim.dtype).max) + sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) attn = sim.softmax(dim = -1) out = einsum('b h i j, b i j d m -> b h i d m', attn, v) @@ -690,8 +690,8 @@ def forward(self, features, edge_info, rel_dist, basis, global_feats = None, pos if exists(mask): mask = rearrange(mask, 'b n -> b () n ()') - k.masked_fill_(~mask, -torch.finfo(k.dtype).max) - v.masked_fill_(~mask, 0.) + k = k.masked_fill(~mask, -torch.finfo(k.dtype).max) + v = v.masked_fill(~mask, 0.) q = q.softmax(dim = -1) k = k.softmax(dim = -2) @@ -1101,7 +1101,6 @@ def __init__( dim_out = default(dim_out, dim) assert exists(num_degrees) or exists(hidden_fiber_dict), 'either num_degrees or hidden_fiber_dict must be specified' - assert exists(output_degrees) or exists(out_fiber_dict), 'either output_degrees or out_fiber_dict must be specified' fiber_in = Fiber.create(input_degrees, dim_in) @@ -1242,7 +1241,7 @@ def forward( next_degree_adj_mat = (adj_mat.float() @ adj_mat.float()) > 0 next_degree_mask = (next_degree_adj_mat.float() - adj_mat.float()).bool() - adj_indices.masked_fill_(next_degree_mask, degree) + adj_indices = adj_indices.masked_fill(next_degree_mask, degree) adj_mat = next_degree_adj_mat.clone() adj_indices = adj_indices.masked_select(exclude_self_mask).reshape(b, n, n - 1) @@ -1311,18 +1310,18 @@ def forward( if max_neighbors > neighbors: print(f'neighbor_mask shows maximum number of neighbors as {max_neighbors} but specified number of neighbors is {neighbors}') - modified_rel_dist.masked_fill_(~neighbor_mask, max_value) + modified_rel_dist = modified_rel_dist.masked_fill(~neighbor_mask, max_value) # use sparse neighbor mask to assign priority of bonded if exists(sparse_neighbor_mask): - modified_rel_dist.masked_fill_(sparse_neighbor_mask, 0.) + modified_rel_dist = modified_rel_dist.masked_fill(sparse_neighbor_mask, 0.) # mask out future nodes to high distance if causal turned on if self.causal: causal_mask = torch.ones(n, n - 1, device = device).triu().bool() - modified_rel_dist.masked_fill_(causal_mask[None, ...], max_value) + modified_rel_dist = modified_rel_dist.masked_fill(causal_mask[None, ...], max_value) # if number of local neighbors by distance is set to 0, then only fetch the sparse neighbors defined by adjacency matrix diff --git a/setup.py b/setup.py index 522c177..11a72ce 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ name = 'se3-transformer-pytorch', packages = find_packages(), include_package_data = True, - version = '0.8.9', + version = '0.8.10', license='MIT', description = 'SE3 Transformer - Pytorch', author = 'Phil Wang',