Skip to content

Commit

Permalink
fix tests for pytorch 1.8.1 marking inplace operations as unsafe
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jun 10, 2021
1 parent ff9fe96 commit 84b1566
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 10 deletions.
17 changes: 8 additions & 9 deletions se3_transformer_pytorch/se3_transformer_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ def forward(self, features, edge_info, rel_dist, basis, global_feats = None, pos
if exists(neighbor_mask):
num_left_pad = sim.shape[-1] - neighbor_mask.shape[-1]
mask = F.pad(neighbor_mask, (num_left_pad, 0), value = True)
sim.masked_fill_(~mask, -torch.finfo(sim.dtype).max)
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

attn = sim.softmax(dim = -1)
out = einsum('b h i j, b h i j d m -> b h i d m', attn, v)
Expand Down Expand Up @@ -652,7 +652,7 @@ def forward(self, features, edge_info, rel_dist, basis, global_feats = None, pos
if exists(neighbor_mask):
num_left_pad = sim.shape[-1] - neighbor_mask.shape[-1]
mask = F.pad(neighbor_mask, (num_left_pad, 0), value = True)
sim.masked_fill_(~mask, -torch.finfo(sim.dtype).max)
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

attn = sim.softmax(dim = -1)
out = einsum('b h i j, b i j d m -> b h i d m', attn, v)
Expand Down Expand Up @@ -690,8 +690,8 @@ def forward(self, features, edge_info, rel_dist, basis, global_feats = None, pos

if exists(mask):
mask = rearrange(mask, 'b n -> b () n ()')
k.masked_fill_(~mask, -torch.finfo(k.dtype).max)
v.masked_fill_(~mask, 0.)
k = k.masked_fill(~mask, -torch.finfo(k.dtype).max)
v = v.masked_fill(~mask, 0.)

q = q.softmax(dim = -1)
k = k.softmax(dim = -2)
Expand Down Expand Up @@ -1101,7 +1101,6 @@ def __init__(
dim_out = default(dim_out, dim)

assert exists(num_degrees) or exists(hidden_fiber_dict), 'either num_degrees or hidden_fiber_dict must be specified'
assert exists(output_degrees) or exists(out_fiber_dict), 'either output_degrees or out_fiber_dict must be specified'

fiber_in = Fiber.create(input_degrees, dim_in)

Expand Down Expand Up @@ -1242,7 +1241,7 @@ def forward(

next_degree_adj_mat = (adj_mat.float() @ adj_mat.float()) > 0
next_degree_mask = (next_degree_adj_mat.float() - adj_mat.float()).bool()
adj_indices.masked_fill_(next_degree_mask, degree)
adj_indices = adj_indices.masked_fill(next_degree_mask, degree)
adj_mat = next_degree_adj_mat.clone()

adj_indices = adj_indices.masked_select(exclude_self_mask).reshape(b, n, n - 1)
Expand Down Expand Up @@ -1311,18 +1310,18 @@ def forward(
if max_neighbors > neighbors:
print(f'neighbor_mask shows maximum number of neighbors as {max_neighbors} but specified number of neighbors is {neighbors}')

modified_rel_dist.masked_fill_(~neighbor_mask, max_value)
modified_rel_dist = modified_rel_dist.masked_fill(~neighbor_mask, max_value)

# use sparse neighbor mask to assign priority of bonded

if exists(sparse_neighbor_mask):
modified_rel_dist.masked_fill_(sparse_neighbor_mask, 0.)
modified_rel_dist = modified_rel_dist.masked_fill(sparse_neighbor_mask, 0.)

# mask out future nodes to high distance if causal turned on

if self.causal:
causal_mask = torch.ones(n, n - 1, device = device).triu().bool()
modified_rel_dist.masked_fill_(causal_mask[None, ...], max_value)
modified_rel_dist = modified_rel_dist.masked_fill(causal_mask[None, ...], max_value)

# if number of local neighbors by distance is set to 0, then only fetch the sparse neighbors defined by adjacency matrix

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
name = 'se3-transformer-pytorch',
packages = find_packages(),
include_package_data = True,
version = '0.8.9',
version = '0.8.10',
license='MIT',
description = 'SE3 Transformer - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 84b1566

Please # to comment.