CKA.py
import torch


def gram_linear(x):
    """Compute Gram (kernel) matrix for a linear kernel.

    Args:
        x: A num_examples x num_features matrix of features.

    Returns:
        A num_examples x num_examples Gram matrix of examples.
    """
    return torch.mm(x, torch.t(x))


def gram_rbf(x, threshold=1.0):
    """Compute Gram (kernel) matrix for an RBF kernel.

    Args:
        x: A num_examples x num_features matrix of features.
        threshold: Fraction of median Euclidean distance to use as RBF kernel
            bandwidth. (This is the heuristic we use in the paper. There are
            other possible ways to set the bandwidth; we didn't try them.)

    Returns:
        A num_examples x num_examples Gram matrix of examples.
    """
    dot_products = torch.mm(x, torch.t(x))
    sq_norms = torch.diag(dot_products)
    sq_distances = -2 * dot_products + sq_norms[:, None] + sq_norms[None, :]
    sq_median_distance = torch.median(sq_distances)
    return torch.exp(-sq_distances / (2 * threshold ** 2 * sq_median_distance + 1e-7))
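

def _rbf_diagonal_sanity_check():
    """Illustrative sanity check (our addition, not part of the original
    module; the helper name is ours). Each example is at squared distance
    zero from itself, so the diagonal of the RBF Gram matrix is exp(0) = 1,
    regardless of the bandwidth.
    """
    x = torch.randn(8, 4)
    k = gram_rbf(x, threshold=0.5)
    assert torch.allclose(torch.diag(k), torch.ones(8), atol=1e-4)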


def center_gram(gram, unbiased=False):
    """Center a symmetric Gram matrix.

    This is equivalent to centering the (possibly infinite-dimensional)
    features induced by the kernel before computing the Gram matrix.

    Args:
        gram: A num_examples x num_examples symmetric matrix.
        unbiased: Whether to adjust the Gram matrix in order to compute an
            unbiased estimate of HSIC. Note that this estimator may be
            negative.

    Returns:
        A symmetric matrix with centered columns and rows.
    """
    if not torch.allclose(gram, torch.t(gram), atol=1e-6):
        raise ValueError('Input must be a symmetric matrix.')
    gram = gram.clone()
    if unbiased:
        # This formulation of the U-statistic, from Szekely, G. J., & Rizzo,
        # M. L. (2014). Partial distance correlation with methods for
        # dissimilarities. The Annals of Statistics, 42(6), 2382-2412, seems
        # to be more numerically stable than the alternative from Song et
        # al. (2007).
        n = gram.shape[0]
        gram = gram - torch.diag(torch.diag(gram))
        means = torch.sum(gram, 0).float() / (n - 2)
        means -= torch.sum(means) / (2 * (n - 1))
        gram -= means[:, None]
        gram -= means[None, :]
        gram = gram - torch.diag(torch.diag(gram))
    else:
        means = torch.mean(gram, 0)
        means -= torch.mean(means) / 2
        gram -= means[:, None]
        gram -= means[None, :]
    return gram
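

def _centering_equivalence_sanity_check():
    """Illustrative sanity check (our addition, not part of the original
    module; the helper name is ours). For a linear kernel, double-centering
    the Gram matrix is equivalent to mean-centering the features first:
    H X X^T H = (H X)(H X)^T, where H is the centering matrix.
    """
    x = torch.randn(10, 5)
    centered_gram = center_gram(gram_linear(x))
    x_centered = x - torch.mean(x, 0, True)
    assert torch.allclose(centered_gram, gram_linear(x_centered), atol=1e-4)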


def cka(gram_x, gram_y, debiased=False):
    """Compute CKA between two Gram matrices.

    Args:
        gram_x: A num_examples x num_examples Gram matrix.
        gram_y: A num_examples x num_examples Gram matrix.
        debiased: Use unbiased estimator of HSIC. CKA may still be biased.

    Returns:
        The value of CKA between X and Y.
    """
    gram_x = center_gram(gram_x, unbiased=debiased)
    gram_y = center_gram(gram_y, unbiased=debiased)
    # Note: To obtain HSIC, this should be divided by (n-1)**2 (biased
    # variant) or n*(n-3) (unbiased variant), but this cancels for CKA.
    scaled_hsic = torch.mm(gram_x.view(1, -1), torch.t(gram_y.view(1, -1)))
    normalization_x = torch.norm(gram_x)
    normalization_y = torch.norm(gram_y)
    return scaled_hsic / (normalization_x * normalization_y)


def linear_CKA(features_x, features_y, debiased=False):
    """Compute CKA with a linear kernel, from feature matrices.

    Args:
        features_x: A num_examples x num_features matrix of features.
        features_y: A num_examples x num_features matrix of features.
        debiased: Use unbiased estimator of HSIC. CKA may still be biased.

    Returns:
        The value of CKA between X and Y.
    """
    gram_x = center_gram(gram_linear(features_x), unbiased=debiased)
    gram_y = center_gram(gram_linear(features_y), unbiased=debiased)
    # Note: To obtain HSIC, this should be divided by (n-1)**2 (biased
    # variant) or n*(n-3) (unbiased variant), but this cancels for CKA.
    scaled_hsic = torch.mm(gram_x.view(1, -1), torch.t(gram_y.view(1, -1)))
    normalization_x = torch.norm(gram_x)
    normalization_y = torch.norm(gram_y)
    return scaled_hsic[0, 0] / (normalization_x * normalization_y + 1e-7)


def rbf_CKA(features_x, features_y, debiased=False, threshold=1.0):
    """Compute CKA with an RBF kernel, from feature matrices.

    Args:
        features_x: A num_examples x num_features matrix of features.
        features_y: A num_examples x num_features matrix of features.
        debiased: Use unbiased estimator of HSIC. CKA may still be biased.
        threshold: Fraction of median Euclidean distance to use as RBF kernel
            bandwidth (passed through to gram_rbf).

    Returns:
        The value of CKA between X and Y.
    """
    gram_x = center_gram(gram_rbf(features_x, threshold), unbiased=debiased)
    gram_y = center_gram(gram_rbf(features_y, threshold), unbiased=debiased)
    # Note: To obtain HSIC, this should be divided by (n-1)**2 (biased
    # variant) or n*(n-3) (unbiased variant), but this cancels for CKA.
    scaled_hsic = torch.mm(gram_x.view(1, -1), torch.t(gram_y.view(1, -1)))
    normalization_x = torch.norm(gram_x)
    normalization_y = torch.norm(gram_y)
    return scaled_hsic[0, 0] / (normalization_x * normalization_y + 1e-7)


def _debiased_dot_product_similarity_helper(
        xty, sum_squared_rows_x, sum_squared_rows_y, squared_norm_x,
        squared_norm_y, n):
    """Helper for computing debiased dot product similarity (i.e. linear HSIC)."""
    # This formula can be derived by manipulating the unbiased estimator from
    # Song et al. (2007).
    return (
        xty - n / (n - 2.) * torch.sum(sum_squared_rows_x * sum_squared_rows_y)
        + squared_norm_x * squared_norm_y / ((n - 1) * (n - 2)))


def feature_space_linear_cka(features_x, features_y, debiased=False):
    """Compute CKA with a linear kernel, in feature space.

    This is typically faster than computing the Gram matrix when there are
    fewer features than examples.

    Args:
        features_x: A num_examples x num_features matrix of features.
        features_y: A num_examples x num_features matrix of features.
        debiased: Use unbiased estimator of dot product similarity. CKA may
            still be biased. Note that this estimator may be negative.

    Returns:
        The value of CKA between X and Y.
    """
    features_x = features_x - torch.mean(features_x, 0, True)
    features_y = features_y - torch.mean(features_y, 0, True)
    dot_product_similarity = torch.norm(torch.mm(torch.t(features_x), features_y)) ** 2
    normalization_x = torch.norm(torch.mm(torch.t(features_x), features_x))
    normalization_y = torch.norm(torch.mm(torch.t(features_y), features_y))
    if debiased:
        n = features_x.shape[0]
        # Equivalent to np.sum(features_x ** 2, 1) but avoids an intermediate
        # array.
        sum_squared_rows_x = torch.einsum('ij,ij->i', features_x, features_x)
        sum_squared_rows_y = torch.einsum('ij,ij->i', features_y, features_y)
        squared_norm_x = torch.sum(sum_squared_rows_x)
        squared_norm_y = torch.sum(sum_squared_rows_y)
        dot_product_similarity = _debiased_dot_product_similarity_helper(
            dot_product_similarity, sum_squared_rows_x, sum_squared_rows_y,
            squared_norm_x, squared_norm_y, n)
        normalization_x = torch.sqrt(_debiased_dot_product_similarity_helper(
            normalization_x ** 2, sum_squared_rows_x, sum_squared_rows_x,
            squared_norm_x, squared_norm_x, n))
        normalization_y = torch.sqrt(_debiased_dot_product_similarity_helper(
            normalization_y ** 2, sum_squared_rows_y, sum_squared_rows_y,
            squared_norm_y, squared_norm_y, n))
    return dot_product_similarity / (normalization_x * normalization_y)
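

if __name__ == '__main__':
    # Minimal usage sketch (our addition; the shapes and seed are arbitrary).
    # The Gram-matrix path and the feature-space path for linear CKA are
    # mathematically identical, so the two results should agree closely, and
    # CKA of a representation with itself should be close to 1.
    torch.manual_seed(0)
    features_x = torch.randn(100, 64)
    features_y = torch.randn(100, 32)

    cka_from_grams = cka(gram_linear(features_x), gram_linear(features_y))
    cka_from_features = feature_space_linear_cka(features_x, features_y)
    print('Linear CKA (Gram space):   ', float(cka_from_grams))
    print('Linear CKA (feature space):', float(cka_from_features))
    assert abs(float(cka_from_grams) - float(cka_from_features)) < 1e-4

    print('RBF CKA (x vs. x):', float(rbf_CKA(features_x, features_x)))
    _rbf_diagonal_sanity_check()
    _centering_equivalence_sanity_check()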