# FSHA_distcor.py
import tensorflow as tf
import numpy as np
import tqdm

import datasets, architectures
import defense

from FSHA import *
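# ---------------------------------------------------------------------------
# `defense.dist_corr` (defense.py, not shown in this file) measures the
# statistical dependence between the client's raw inputs and its smashed data.
# For reference, below is a minimal sketch of a biased, batch-level
# distance-correlation estimator; the repo's actual implementation may differ,
# so `dist_corr_sketch` and its helper are assumptions, not the repo's code.

def _pairwise_dist_sketch(a):
    # a: (batch, features) -> (batch, batch) matrix of Euclidean distances
    r = tf.reduce_sum(tf.square(a), axis=1, keepdims=True)
    d = r - 2.0 * tf.matmul(a, a, transpose_b=True) + tf.transpose(r)
    return tf.sqrt(tf.maximum(d, 1e-12))

def dist_corr_sketch(x, z):
    # flatten both batches to (batch, -1) and compute distance matrices
    a = _pairwise_dist_sketch(tf.reshape(x, (tf.shape(x)[0], -1)))
    b = _pairwise_dist_sketch(tf.reshape(z, (tf.shape(z)[0], -1)))
    # double-center each distance matrix
    A = a - tf.reduce_mean(a, 0, keepdims=True) - tf.reduce_mean(a, 1, keepdims=True) + tf.reduce_mean(a)
    B = b - tf.reduce_mean(b, 0, keepdims=True) - tf.reduce_mean(b, 1, keepdims=True) + tf.reduce_mean(b)
    # squared distance covariance / variances (biased V-statistics)
    dcov2_xz = tf.reduce_mean(A * B)
    dcov2_xx = tf.reduce_mean(A * A)
    dcov2_zz = tf.reduce_mean(B * B)
    # dCor = sqrt( dCov^2(x,z) / sqrt(dVar^2(x) * dVar^2(z)) )
    return tf.sqrt(tf.maximum(dcov2_xz, 0.0) / tf.sqrt(dcov2_xx * dcov2_zz + 1e-12) + 1e-12)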
class FSHA_dc(FSHA):

    @tf.function
    def train_step(self, x_private, x_public, label_private, label_public):

        with tf.GradientTape(persistent=True) as tape:

            #### Virtually, ON THE CLIENT SIDE:
            # clients' smashed data
            z_private = self.f(x_private, training=True)

            # optional client-side defense: penalize the distance correlation
            # between raw inputs and smashed data, weighted by alpha1
            def_loss = 0.
            if 'alpha1' in self.hparams:
                w_dc = self.hparams['alpha1']
                # prints execute once, when this tf.function is traced
                print(r"With client \alpha_1: %f" % w_dc)
                def_loss = defense.dist_corr(x_private, z_private) * w_dc
            ####################################
####################################
#### SERVER-SIDE:
# map to data space (for evaluation and style loss)
rec_x_private = self.decoder(z_private, training=True)
## adversarial loss (f's output must similar be to \tilde{f}'s output):
adv_private_logits = self.D(z_private)
if self.hparams['WGAN']:
print("Use WGAN loss")
f_loss = tf.reduce_mean(adv_private_logits)
else:
f_loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(tf.ones_like(adv_private_logits), adv_private_logits, from_logits=True))
##
            # optionally rescale the adversarial loss driving the client's network
            if 'alpha2' in self.hparams:
                w_loss = self.hparams['alpha2']
                print("Scale adversarial loss server %f" % w_loss)
                f_loss *= w_loss
            adv_f_loss = f_loss
            # attacker's pilot network on public data
            z_public = self.tilde_f(x_public, training=True)

            # invertibility loss
            rec_x_public = self.decoder(z_public, training=True)
            public_rec_loss = distance_data_loss(x_public, rec_x_public)
            tilde_f_loss = public_rec_loss

            # discriminator on attacker's feature-space
            adv_public_logits = self.D(z_public, training=True)

            if self.hparams['WGAN']:
                loss_discr_true = tf.reduce_mean(adv_public_logits)
                loss_discr_fake = -tf.reduce_mean(adv_private_logits)
                # discriminator's loss
                D_loss = loss_discr_true + loss_discr_fake
            else:
                loss_discr_true = tf.reduce_mean(tf.keras.losses.binary_crossentropy(
                    tf.ones_like(adv_public_logits), adv_public_logits, from_logits=True))
                loss_discr_fake = tf.reduce_mean(tf.keras.losses.binary_crossentropy(
                    tf.zeros_like(adv_private_logits), adv_private_logits, from_logits=True))
                # discriminator's loss
                D_loss = (loss_discr_true + loss_discr_fake) / 2
            # optional WGAN gradient penalty (inherited from FSHA)
            if 'gradient_penalty' in self.hparams:
                print("Use GP")
                w = float(self.hparams['gradient_penalty'])
                D_gradient_penalty = self.gradient_penalty(z_private, z_public)
                D_loss += D_gradient_penalty * w

            ##################################################################
            ## attack validation: reconstruction error on the private data ###
            loss_c_verification = distance_data(x_private, rec_x_private)
            ##################################################################
        # train attacker's autoencoder (pilot network + decoder) on public data
        var = self.tilde_f.trainable_variables + self.decoder.trainable_variables
        gradients = tape.gradient(tilde_f_loss, var)
        self.optimizer1.apply_gradients(zip(gradients, var))

        # train discriminator
        var = self.D.trainable_variables
        gradients = tape.gradient(D_loss, var)
        self.optimizer2.apply_gradients(zip(gradients, var))

        # train client's network on the privacy (distance-correlation) loss;
        # skipped when alpha1 is unset, since def_loss is then a constant and
        # tape.gradient would return None for every variable
        if 'alpha1' in self.hparams:
            var = self.f.trainable_variables
            gradients = tape.gradient(def_loss, var)
            self.optimizer0.apply_gradients(zip(gradients, var))

        # train client's network on the adversarial loss
        var = self.f.trainable_variables
        gradients = tape.gradient(f_loss, var)
        self.optimizer0.apply_gradients(zip(gradients, var))

        return adv_f_loss, tilde_f_loss, D_loss, loss_c_verification
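# ---------------------------------------------------------------------------
# `self.gradient_penalty` is inherited from the parent FSHA class (FSHA.py),
# whose implementation is not shown here. For reference, a standard WGAN-GP
# penalty over random interpolates of the two feature batches looks like the
# sketch below (assuming 4-D feature maps); the parent's code may differ.

def gradient_penalty_sketch(D, z_private, z_public):
    eps = tf.random.uniform([tf.shape(z_private)[0], 1, 1, 1], 0.0, 1.0)
    z_hat = eps * z_private + (1.0 - eps) * z_public  # random interpolates
    with tf.GradientTape() as t:
        t.watch(z_hat)
        d_hat = D(z_hat, training=True)
    grads = t.gradient(d_hat, z_hat)
    norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-12)
    return tf.reduce_mean(tf.square(norm - 1.0))  # (||grad D||_2 - 1)^2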
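# ---------------------------------------------------------------------------
# Hypothetical usage sketch (not part of the original file). The hyperparameter
# keys below are exactly the ones train_step reads; the FSHA_dc constructor and
# training loop are inherited from FSHA (FSHA.py), so their signatures are only
# assumed here and are left commented out.

if __name__ == "__main__":
    hparams = {
        'WGAN': True,              # Wasserstein losses for D and the client loss
        'gradient_penalty': 500.,  # weight of the WGAN gradient penalty (illustrative value)
        'alpha1': 0.3,             # client-side distance-correlation weight (illustrative value)
        'alpha2': 1.0,             # extra scale on the adversarial loss (illustrative value)
        # ... remaining FSHA hyperparameters (architectures, learning rates, ...)
    }
    # xpriv, xpub = ...                        # private/public tf.data datasets (datasets.py)
    # attacker = FSHA_dc(xpriv, xpub, id_setup, batch_size, hparams)  # signature assumed
    # losses = attacker.train_step(x_private, x_public, y_private, y_public)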