-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathDisco.py
46 lines (33 loc) · 1.95 KB
/
Disco.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
def _double_centered_distances(values, weight):
    """Return the weighted, double-centred matrix of pairwise distances.

    Builds d_ij = |values_i - values_j|, then subtracts the weighted mean of
    column j, subtracts the weighted mean of row i, and adds back the weighted
    grand mean.

    ``values`` and ``weight`` are expected to be flat tensors of the same
    length N; the result has shape (N, N).
    """
    # Broadcasting (N, 1) against (1, N) yields the full pairwise-difference
    # matrix without materialising repeated copies of the input.
    dist = (values.unsqueeze(1) - values.unsqueeze(0)).abs()
    # Weighted row means: mean over j of w_j * d_ij. torch.mean divides by N,
    # not by sum(w), which is why the weights must sum to N.
    row_avg = torch.mean(dist * weight, dim=1)
    # dist is symmetric, so its column means equal its row means.
    return (
        dist
        - row_avg.unsqueeze(0)
        - row_avg.unsqueeze(1)
        + torch.mean(row_avg * weight)
    )


def _weighted_grand_mean(matrix, weight):
    """Return mean_i(w_i * mean_j(w_j * matrix_ij)) as a 0-dim tensor."""
    return torch.mean(torch.mean(matrix * weight, dim=1) * weight)


def distance_corr(var_1, var_2, normedweight, power=1):
    """Weighted distance correlation between two variables, usable as a loss term.

    Args:
        var_1: First variable to decorrelate (e.g. mass).
        var_2: Second variable to decorrelate (e.g. classifier output).
        normedweight: Per-example weight. The weights should sum to N, the
            number of examples: the weighted means below are computed with
            ``torch.mean`` (which divides by N), so they are properly
            normalised only under that condition.
        power: Exponent applied to the correlation ratio
            ``mean(A*B) / sqrt(mean(A*A) * mean(B*B))``, where A and B are the
            double-centred distance matrices of ``var_1`` and ``var_2``.
            ``power=1`` returns that ratio (the weighted sample statistic
            written R_n^2 by Székely et al.), ``power=2`` returns its square,
            and any other value returns the ratio raised to ``power``.

    ``var_1``, ``var_2`` and ``normedweight`` must contain the same number of
    entries N. All three are flattened first, so 1-D tensors of shape (N,) and
    column tensors of shape (N, 1) are treated identically.

    Returns:
        A 0-dim tensor. It stays differentiable with respect to the inputs, so
        it can be back-propagated through.

    Notes:
        Several N x N matrices are built, so time and memory grow as O(N^2).
        If either variable is constant, numerator and denominator are both zero
        and the result is NaN.

    Usage: Add to your loss function, e.g.
        ``total_loss = BCE_loss + lambda * distance_corr(...)``
    """
    # Flatten so that column-shaped inputs (common for classifier outputs) and
    # column-shaped weights broadcast along the correct axis.
    x = var_1.reshape(-1)
    y = var_2.reshape(-1)
    w = normedweight.reshape(-1)

    a_mat = _double_centered_distances(x, w)
    b_mat = _double_centered_distances(y, w)

    ab = _weighted_grand_mean(a_mat * b_mat, w)
    aa = _weighted_grand_mean(a_mat * a_mat, w)
    bb = _weighted_grand_mean(b_mat * b_mat, w)

    if power == 1:
        return ab / torch.sqrt(aa * bb)
    if power == 2:
        # Squared form computed directly, without a sqrt.
        return ab ** 2 / (aa * bb)
    return (ab / torch.sqrt(aa * bb)) ** power