Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit e08d335

Browse files
committedMar 27, 2023
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent e4427c7 commit e08d335

File tree

1 file changed

+23
-13
lines changed

1 file changed

+23
-13
lines changed
 

‎quaterion/loss/fast_ap_loss.py

+23-13
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,7 @@ def get_config_dict(self) -> Dict[str, Any]:
3838
Dict[str, Any]: JSON-serializable dict of params
3939
"""
4040
config = super().get_config_dict()
41-
config.update(
42-
{
43-
"num_bins": self.num_bins
44-
}
45-
)
41+
config.update({"num_bins": self.num_bins})
4642

4743
return config
4844

@@ -68,26 +64,40 @@ def forward(
6864
device = embeddings.device # get the device of the embeddings tensor
6965

7066
# 1. get positive and negative masks
71-
pos_mask = get_anchor_positive_mask(groups).to(device) # (batch_size, batch_size)
72-
neg_mask = get_anchor_negative_mask(groups).to(device) # (batch_size, batch_size)
67+
pos_mask = get_anchor_positive_mask(groups).to(
68+
device
69+
) # (batch_size, batch_size)
70+
neg_mask = get_anchor_negative_mask(groups).to(
71+
device
72+
) # (batch_size, batch_size)
7373
n_pos = torch.sum(pos_mask, dim=1) # Sum over all columns (for each row)
7474

7575
# 2. compute distances from embeddings squared Euclidean distance matrix
76-
embeddings = F.normalize(embeddings, p=2, dim=1).to(device) # normalize embeddings
76+
embeddings = F.normalize(embeddings, p=2, dim=1).to(
77+
device
78+
) # normalize embeddings
7779
dist_matrix = (
7880
self.distance_metric.distance_matrix(embeddings).to(device) ** 2
7981
) # (batch_size, batch_size)
8082

8183
# 3. estimate discrete histograms
8284
histogram_delta = torch.tensor(4.0 / self.num_bins, device=device)
83-
mid_points = torch.linspace(0.0, 4.0, steps=self.num_bins + 1, device=device).view(-1, 1, 1)
85+
mid_points = torch.linspace(
86+
0.0, 4.0, steps=self.num_bins + 1, device=device
87+
).view(-1, 1, 1)
8488

8589
pulse = F.relu(
8690
input=1 - torch.abs(dist_matrix - mid_points) / histogram_delta
87-
).to(device) # max(0, input)
88-
89-
pos_hist = torch.t(torch.sum(pulse * pos_mask, dim=2)).to(device) # positive histograms
90-
neg_hist = torch.t(torch.sum(pulse * neg_mask, dim=2)).to(device) # negative histograms
91+
).to(
92+
device
93+
) # max(0, input)
94+
95+
pos_hist = torch.t(torch.sum(pulse * pos_mask, dim=2)).to(
96+
device
97+
) # positive histograms
98+
neg_hist = torch.t(torch.sum(pulse * neg_mask, dim=2)).to(
99+
device
100+
) # negative histograms
91101

92102
total_pos_hist = torch.cumsum(pos_hist, dim=1).to(device)
93103
total_hist = torch.cumsum(pos_hist + neg_hist, dim=1).to(device)

0 commit comments

Comments
 (0)