@@ -38,11 +38,7 @@ def get_config_dict(self) -> Dict[str, Any]:
38
38
Dict[str, Any]: JSON-serializable dict of params
39
39
"""
40
40
config = super ().get_config_dict ()
41
- config .update (
42
- {
43
- "num_bins" : self .num_bins
44
- }
45
- )
41
+ config .update ({"num_bins" : self .num_bins })
46
42
47
43
return config
48
44
@@ -68,26 +64,40 @@ def forward(
68
64
device = embeddings .device # get the device of the embeddings tensor
69
65
70
66
# 1. get positive and negative masks
71
- pos_mask = get_anchor_positive_mask (groups ).to (device ) # (batch_size, batch_size)
72
- neg_mask = get_anchor_negative_mask (groups ).to (device ) # (batch_size, batch_size)
67
+ pos_mask = get_anchor_positive_mask (groups ).to (
68
+ device
69
+ ) # (batch_size, batch_size)
70
+ neg_mask = get_anchor_negative_mask (groups ).to (
71
+ device
72
+ ) # (batch_size, batch_size)
73
73
n_pos = torch .sum (pos_mask , dim = 1 ) # Sum over all columns (for each row)
74
74
75
75
# 2. compute distances from embeddings squared Euclidean distance matrix
76
- embeddings = F .normalize (embeddings , p = 2 , dim = 1 ).to (device ) # normalize embeddings
76
+ embeddings = F .normalize (embeddings , p = 2 , dim = 1 ).to (
77
+ device
78
+ ) # normalize embeddings
77
79
dist_matrix = (
78
80
self .distance_metric .distance_matrix (embeddings ).to (device ) ** 2
79
81
) # (batch_size, batch_size)
80
82
81
83
# 3. estimate discrete histograms
82
84
histogram_delta = torch .tensor (4.0 / self .num_bins , device = device )
83
- mid_points = torch .linspace (0.0 , 4.0 , steps = self .num_bins + 1 , device = device ).view (- 1 , 1 , 1 )
85
+ mid_points = torch .linspace (
86
+ 0.0 , 4.0 , steps = self .num_bins + 1 , device = device
87
+ ).view (- 1 , 1 , 1 )
84
88
85
89
pulse = F .relu (
86
90
input = 1 - torch .abs (dist_matrix - mid_points ) / histogram_delta
87
- ).to (device ) # max(0, input)
88
-
89
- pos_hist = torch .t (torch .sum (pulse * pos_mask , dim = 2 )).to (device ) # positive histograms
90
- neg_hist = torch .t (torch .sum (pulse * neg_mask , dim = 2 )).to (device ) # negative histograms
91
+ ).to (
92
+ device
93
+ ) # max(0, input)
94
+
95
+ pos_hist = torch .t (torch .sum (pulse * pos_mask , dim = 2 )).to (
96
+ device
97
+ ) # positive histograms
98
+ neg_hist = torch .t (torch .sum (pulse * neg_mask , dim = 2 )).to (
99
+ device
100
+ ) # negative histograms
91
101
92
102
total_pos_hist = torch .cumsum (pos_hist , dim = 1 ).to (device )
93
103
total_hist = torch .cumsum (pos_hist + neg_hist , dim = 1 ).to (device )
0 commit comments