-
Notifications
You must be signed in to change notification settings - Fork 292
/
ctgan.py
483 lines (395 loc) · 18.6 KB
/
ctgan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
"""CTGAN module."""
import warnings
import numpy as np
import pandas as pd
import torch
from packaging import version
from torch import optim
from torch.nn import BatchNorm1d, Dropout, LeakyReLU, Linear, Module, ReLU, Sequential, functional
from ctgan.data_sampler import DataSampler
from ctgan.data_transformer import DataTransformer
from ctgan.synthesizers.base import BaseSynthesizer, random_state
class Discriminator(Module):
    """Critic network of the CTGAN that scores packed groups of rows.

    ``pac`` consecutive rows are viewed as one vector of width ``input_dim * pac``
    and passed through a stack of ``Linear -> LeakyReLU(0.2) -> Dropout(0.5)``
    layers (one per entry of ``discriminator_dim``) followed by a final
    ``Linear`` layer that emits a single score per packed group.
    """

    def __init__(self, input_dim, discriminator_dim, pac=10):
        super().__init__()
        packed_width = input_dim * pac
        self.pac = pac
        self.pacdim = packed_width

        layers = []
        width_in = packed_width
        for width_out in discriminator_dim:
            layers.extend((Linear(width_in, width_out), LeakyReLU(0.2), Dropout(0.5)))
            width_in = width_out
        layers.append(Linear(width_in, 1))
        # Attribute name ``seq`` determines the state_dict keys; keep it stable.
        self.seq = Sequential(*layers)

    def calc_gradient_penalty(self, real_data, fake_data, device='cpu', pac=10, lambda_=10):
        """Compute the WGAN-GP gradient penalty on real/fake interpolations.

        Args:
            real_data: 2-D tensor of real rows; its row count should be a multiple of ``pac``.
            fake_data: tensor of generated rows with the same shape as ``real_data``.
            device: device on which the random mixing coefficients are drawn.
            pac: number of rows per packed group (callers pass ``self.pac``).
            lambda_: scale factor applied to the penalty.

        Returns:
            Scalar tensor ``lambda_ * mean((||grad||_2 - 1) ** 2)`` where the norm is
            taken over each packed group of ``pac`` rows.
        """
        n_features = real_data.size(1)

        # One mixing coefficient per packed group, shared by every row/column inside it.
        alpha = torch.rand(real_data.size(0) // pac, 1, 1, device=device)
        alpha = alpha.repeat(1, pac, n_features).view(-1, n_features)

        interpolates = alpha * real_data + (1 - alpha) * fake_data
        critic_scores = self(interpolates)

        (gradients,) = torch.autograd.grad(
            outputs=critic_scores,
            inputs=interpolates,
            grad_outputs=torch.ones(critic_scores.size(), device=device),
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )

        group_norms = gradients.view(-1, pac * n_features).norm(2, dim=1)
        return ((group_norms - 1) ** 2).mean() * lambda_

    def forward(self, input_):
        """Score ``input_``, whose row count must be divisible by ``self.pac``."""
        assert input_.shape[0] % self.pac == 0
        packed = input_.view(-1, self.pacdim)
        return self.seq(packed)
class Residual(Module):
    """Residual block used by the CTGAN generator.

    Computes ``ReLU(BatchNorm1d(Linear(x)))`` and concatenates it in front of the
    unchanged input along dim 1, so an input of width ``i`` yields width ``o + i``.
    """

    def __init__(self, i, o):
        super().__init__()
        # Submodule names/order define the state_dict layout; keep them stable.
        self.fc = Linear(i, o)
        self.bn = BatchNorm1d(o)
        self.relu = ReLU()

    def forward(self, input_):
        """Return ``relu(bn(fc(input_)))`` concatenated with ``input_`` along dim 1."""
        activated = self.relu(self.bn(self.fc(input_)))
        return torch.cat((activated, input_), dim=1)
class Generator(Module):
    """Generator network of the CTGAN.

    Builds one ``Residual`` block per entry of ``generator_dim``. Each block
    concatenates its output onto its input, so the running width grows by that
    entry after every block; a final ``Linear`` maps the accumulated width to
    ``data_dim``.
    """

    def __init__(self, embedding_dim, generator_dim, data_dim):
        super().__init__()
        blocks = []
        width = embedding_dim
        for hidden in generator_dim:
            blocks.append(Residual(width, hidden))
            width += hidden
        blocks.append(Linear(width, data_dim))
        # Attribute name ``seq`` determines the state_dict keys; keep it stable.
        self.seq = Sequential(*blocks)

    def forward(self, input_):
        """Pass ``input_`` through the residual stack and the output layer."""
        return self.seq(input_)
class CTGAN(BaseSynthesizer):
    """Conditional Table GAN Synthesizer.

    This is the core class of the CTGAN project, where the different components
    are orchestrated together.
    For more details about the process, please check the [Modeling Tabular data using
    Conditional GAN](https://arxiv.org/abs/1907.00503) paper.

    Args:
        embedding_dim (int):
            Size of the random sample passed to the Generator. Defaults to 128.
        generator_dim (tuple or list of ints):
            Size of the output samples for each one of the Residuals. A Residual Layer
            will be created for each one of the values provided. Defaults to (256, 256).
        discriminator_dim (tuple or list of ints):
            Size of the output samples for each one of the Discriminator Layers. A Linear Layer
            will be created for each one of the values provided. Defaults to (256, 256).
        generator_lr (float):
            Learning rate for the generator. Defaults to 2e-4.
        generator_decay (float):
            Generator weight decay for the Adam Optimizer. Defaults to 1e-6.
        discriminator_lr (float):
            Learning rate for the discriminator. Defaults to 2e-4.
        discriminator_decay (float):
            Discriminator weight decay for the Adam Optimizer. Defaults to 1e-6.
        batch_size (int):
            Number of data samples to process in each step. Must be even (asserted in the
            constructor) and, for training, a multiple of ``pac`` because the discriminator
            asserts that its input row count is divisible by ``pac``. Defaults to 500.
        discriminator_steps (int):
            Number of discriminator updates to do for each generator update.
            From the WGAN paper: https://arxiv.org/abs/1701.07875. WGAN paper
            default is 5. Default used is 1 to match original CTGAN implementation.
        log_frequency (boolean):
            Whether to use log frequency of categorical levels in conditional
            sampling. Defaults to ``True``.
        verbose (boolean):
            Whether to have print statements for progress results. Defaults to ``False``.
        epochs (int):
            Number of training epochs. Defaults to 300.
        pac (int):
            Number of samples to group together when applying the discriminator.
            Defaults to 10.
        cuda (bool or str):
            Whether to attempt to use cuda for GPU computation. If a string is given
            (and CUDA is available) it is used directly as the device name.
            If this is False or CUDA is not available, CPU will be used.
            Defaults to ``True``.
    """

    def __init__(self, embedding_dim=128, generator_dim=(256, 256), discriminator_dim=(256, 256),
                 generator_lr=2e-4, generator_decay=1e-6, discriminator_lr=2e-4,
                 discriminator_decay=1e-6, batch_size=500, discriminator_steps=1,
                 log_frequency=True, verbose=False, epochs=300, pac=10, cuda=True):

        assert batch_size % 2 == 0

        self._embedding_dim = embedding_dim
        self._generator_dim = generator_dim
        self._discriminator_dim = discriminator_dim

        self._generator_lr = generator_lr
        self._generator_decay = generator_decay
        self._discriminator_lr = discriminator_lr
        self._discriminator_decay = discriminator_decay

        self._batch_size = batch_size
        self._discriminator_steps = discriminator_steps
        self._log_frequency = log_frequency
        self._verbose = verbose
        self._epochs = epochs
        self.pac = pac

        # ``cuda`` may be a bool or an explicit device string; fall back to CPU
        # whenever it is falsy or CUDA is unavailable.
        if not cuda or not torch.cuda.is_available():
            device = 'cpu'
        elif isinstance(cuda, str):
            device = cuda
        else:
            device = 'cuda'

        self._device = torch.device(device)

        # These are built by ``fit``.
        self._transformer = None
        self._data_sampler = None
        self._generator = None

    @staticmethod
    def _gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
        """Deals with the instability of the gumbel_softmax for older versions of torch.

        For more details about the issue:
        https://drive.google.com/file/d/1AA5wPfZ1kquaRtVruCd6BiYZGcDeNxyP/view?usp=sharing

        Args:
            logits […, num_features]:
                Unnormalized log probabilities
            tau:
                Non-negative scalar temperature
            hard (bool):
                If True, the returned samples will be discretized as one-hot vectors,
                but will be differentiated as if it is the soft sample in autograd
            eps (float):
                Forwarded unchanged to ``torch.nn.functional.gumbel_softmax``.
            dim (int):
                A dimension along which softmax will be computed. Default: -1.

        Returns:
            Sampled tensor of same shape as logits from the Gumbel-Softmax distribution.

        Raises:
            ValueError:
                On torch < 1.2.0, if ten consecutive draws all contain NaN.
        """
        if version.parse(torch.__version__) < version.parse('1.2.0'):
            # Older torch could occasionally emit NaNs here; retry a bounded number of times.
            for _ in range(10):
                transformed = functional.gumbel_softmax(logits, tau=tau, hard=hard,
                                                        eps=eps, dim=dim)
                if not torch.isnan(transformed).any():
                    return transformed

            raise ValueError('gumbel_softmax returning NaN.')

        return functional.gumbel_softmax(logits, tau=tau, hard=hard, eps=eps, dim=dim)

    def _apply_activate(self, data):
        """Apply proper activation function to the output of the generator.

        Walks the spans in ``self._transformer.output_info_list`` left to right:
        ``'tanh'`` spans get ``torch.tanh`` and ``'softmax'`` spans get a
        Gumbel-softmax with ``tau=0.2``. The activated spans are concatenated
        back along dim 1.

        Raises:
            ValueError: if a span declares any other activation function.
        """
        data_t = []
        st = 0
        for column_info in self._transformer.output_info_list:
            for span_info in column_info:
                if span_info.activation_fn == 'tanh':
                    ed = st + span_info.dim
                    data_t.append(torch.tanh(data[:, st:ed]))
                    st = ed
                elif span_info.activation_fn == 'softmax':
                    ed = st + span_info.dim
                    transformed = self._gumbel_softmax(data[:, st:ed], tau=0.2)
                    data_t.append(transformed)
                    st = ed
                else:
                    raise ValueError(f'Unexpected activation function {span_info.activation_fn}.')

        return torch.cat(data_t, dim=1)

    def _cond_loss(self, data, c, m):
        """Compute the cross entropy loss on the fixed discrete column.

        Args:
            data: raw (pre-activation) generator output.
            c: conditional vector; its slices are aligned with the discrete spans only.
            m: weights multiplied element-wise with the per-discrete-column losses
                (presumably a one-hot mask of the conditioned column — confirm in DataSampler).

        Returns:
            Scalar: sum of the weighted per-row, per-column losses divided by the batch size.
        """
        loss = []
        st = 0
        st_c = 0
        for column_info in self._transformer.output_info_list:
            for span_info in column_info:
                if len(column_info) != 1 or span_info.activation_fn != 'softmax':
                    # not discrete column: advance in ``data`` only, ``c`` has no slot for it
                    st += span_info.dim
                else:
                    ed = st + span_info.dim
                    ed_c = st_c + span_info.dim
                    tmp = functional.cross_entropy(
                        data[:, st:ed],
                        torch.argmax(c[:, st_c:ed_c], dim=1),
                        reduction='none'
                    )
                    loss.append(tmp)
                    st = ed
                    st_c = ed_c

        loss = torch.stack(loss, dim=1)  # noqa: PD013

        return (loss * m).sum() / data.size()[0]

    def _validate_discrete_columns(self, train_data, discrete_columns):
        """Check whether ``discrete_columns`` exists in ``train_data``.

        Args:
            train_data (numpy.ndarray or pandas.DataFrame):
                Training Data. It must be a 2-dimensional numpy array or a pandas.DataFrame.
            discrete_columns (list-like):
                List of discrete columns to be used to generate the Conditional
                Vector. If ``train_data`` is a Numpy array, this list should
                contain the integer indices of the columns. Otherwise, if it is
                a ``pandas.DataFrame``, this list should contain the column names.

        Raises:
            TypeError: if ``train_data`` is neither a DataFrame nor an ndarray.
            ValueError: if any entry of ``discrete_columns`` is not a valid column.
        """
        if isinstance(train_data, pd.DataFrame):
            invalid_columns = set(discrete_columns) - set(train_data.columns)
        elif isinstance(train_data, np.ndarray):
            invalid_columns = []
            for column in discrete_columns:
                if column < 0 or column >= train_data.shape[1]:
                    invalid_columns.append(column)
        else:
            raise TypeError('``train_data`` should be either pd.DataFrame or np.array.')

        if invalid_columns:
            raise ValueError(f'Invalid columns found: {invalid_columns}')

    @random_state
    def fit(self, train_data, discrete_columns=(), epochs=None):
        """Fit the CTGAN Synthesizer models to the training data.

        Args:
            train_data (numpy.ndarray or pandas.DataFrame):
                Training Data. It must be a 2-dimensional numpy array or a pandas.DataFrame.
            discrete_columns (list-like):
                List of discrete columns to be used to generate the Conditional
                Vector. If ``train_data`` is a Numpy array, this list should
                contain the integer indices of the columns. Otherwise, if it is
                a ``pandas.DataFrame``, this list should contain the column names.
            epochs (int or None):
                Deprecated; pass ``epochs`` to the constructor instead. When given it
                overrides the constructor value and a ``DeprecationWarning`` is emitted.
        """
        self._validate_discrete_columns(train_data, discrete_columns)

        if epochs is None:
            epochs = self._epochs
        else:
            warnings.warn(
                ('`epochs` argument in `fit` method has been deprecated and will be removed '
                 'in a future version. Please pass `epochs` to the constructor instead'),
                DeprecationWarning
            )

        self._transformer = DataTransformer()
        self._transformer.fit(train_data, discrete_columns)

        train_data = self._transformer.transform(train_data)

        self._data_sampler = DataSampler(
            train_data,
            self._transformer.output_info_list,
            self._log_frequency)

        data_dim = self._transformer.output_dimensions

        self._generator = Generator(
            self._embedding_dim + self._data_sampler.dim_cond_vec(),
            self._generator_dim,
            data_dim
        ).to(self._device)

        discriminator = Discriminator(
            data_dim + self._data_sampler.dim_cond_vec(),
            self._discriminator_dim,
            pac=self.pac
        ).to(self._device)

        optimizerG = optim.Adam(
            self._generator.parameters(), lr=self._generator_lr, betas=(0.5, 0.9),
            weight_decay=self._generator_decay
        )

        optimizerD = optim.Adam(
            discriminator.parameters(), lr=self._discriminator_lr,
            betas=(0.5, 0.9), weight_decay=self._discriminator_decay
        )

        mean = torch.zeros(self._batch_size, self._embedding_dim, device=self._device)
        std = mean + 1

        steps_per_epoch = max(len(train_data) // self._batch_size, 1)
        for i in range(epochs):
            for _ in range(steps_per_epoch):

                # --- Discriminator (critic) updates ---
                for _ in range(self._discriminator_steps):
                    fakez = torch.normal(mean=mean, std=std)

                    condvec = self._data_sampler.sample_condvec(self._batch_size)
                    if condvec is None:
                        c1, m1, col, opt = None, None, None, None
                        real = self._data_sampler.sample_data(self._batch_size, col, opt)
                    else:
                        c1, m1, col, opt = condvec
                        c1 = torch.from_numpy(c1).to(self._device)
                        m1 = torch.from_numpy(m1).to(self._device)
                        fakez = torch.cat([fakez, c1], dim=1)

                        # Real rows are drawn for a shuffled copy of the conditions so the
                        # real batch is paired with a permuted conditional vector ``c2``.
                        perm = np.arange(self._batch_size)
                        np.random.shuffle(perm)
                        real = self._data_sampler.sample_data(
                            self._batch_size, col[perm], opt[perm])
                        c2 = c1[perm]

                    fake = self._generator(fakez)
                    fakeact = self._apply_activate(fake)

                    real = torch.from_numpy(real.astype('float32')).to(self._device)

                    if c1 is not None:
                        fake_cat = torch.cat([fakeact, c1], dim=1)
                        real_cat = torch.cat([real, c2], dim=1)
                    else:
                        real_cat = real
                        fake_cat = fakeact

                    y_fake = discriminator(fake_cat)
                    y_real = discriminator(real_cat)

                    pen = discriminator.calc_gradient_penalty(
                        real_cat, fake_cat, self._device, self.pac)
                    loss_d = -(torch.mean(y_real) - torch.mean(y_fake))

                    optimizerD.zero_grad()
                    pen.backward(retain_graph=True)
                    loss_d.backward()
                    optimizerD.step()

                # --- Generator update ---
                fakez = torch.normal(mean=mean, std=std)
                condvec = self._data_sampler.sample_condvec(self._batch_size)

                if condvec is None:
                    c1, m1, col, opt = None, None, None, None
                else:
                    c1, m1, col, opt = condvec
                    c1 = torch.from_numpy(c1).to(self._device)
                    m1 = torch.from_numpy(m1).to(self._device)
                    fakez = torch.cat([fakez, c1], dim=1)

                fake = self._generator(fakez)
                fakeact = self._apply_activate(fake)

                if c1 is not None:
                    y_fake = discriminator(torch.cat([fakeact, c1], dim=1))
                else:
                    y_fake = discriminator(fakeact)

                if condvec is None:
                    cross_entropy = 0
                else:
                    cross_entropy = self._cond_loss(fake, c1, m1)

                loss_g = -torch.mean(y_fake) + cross_entropy

                optimizerG.zero_grad()
                loss_g.backward()
                optimizerG.step()

            if self._verbose:
                # The two f-string fragments are implicitly concatenated; the first must end
                # with a space, otherwise the output reads "...,Loss D:".
                print(f'Epoch {i + 1}, Loss G: {loss_g.detach().cpu(): .4f}, '  # noqa: T001
                      f'Loss D: {loss_d.detach().cpu(): .4f}',
                      flush=True)

    @random_state
    def sample(self, n, condition_column=None, condition_value=None):
        """Sample data similar to the training data.

        Choosing a condition_column and condition_value will increase the probability of the
        discrete condition_value happening in the condition_column.

        Args:
            n (int):
                Number of rows to sample. Must be non-negative.
            condition_column (string):
                Name of a discrete column.
            condition_value (string):
                Name of the category in the condition_column which we wish to increase the
                probability of happening.

        Returns:
            numpy.ndarray or pandas.DataFrame

        Raises:
            ValueError:
                If ``n`` is negative.
        """
        if n < 0:
            raise ValueError(f'``n`` must be non-negative, got {n}.')

        if condition_column is not None and condition_value is not None:
            condition_info = self._transformer.convert_column_name_value_to_id(
                condition_column, condition_value)
            global_condition_vec = self._data_sampler.generate_cond_from_condition_column_info(
                condition_info, self._batch_size)
        else:
            global_condition_vec = None

        # Ceil division: ``n // batch_size + 1`` generated one wasted batch whenever ``n``
        # was an exact multiple of the batch size. At least one batch is still generated
        # so ``n == 0`` keeps producing an empty, well-formed result.
        steps = max(1, -(-n // self._batch_size))

        # Noise parameters do not change between batches; noise is drawn on CPU and moved.
        mean = torch.zeros(self._batch_size, self._embedding_dim)
        std = mean + 1

        data = []
        # Sampling never backpropagates, so skip building the autograd graph.
        with torch.no_grad():
            for _ in range(steps):
                fakez = torch.normal(mean=mean, std=std).to(self._device)

                if global_condition_vec is not None:
                    condvec = global_condition_vec.copy()
                else:
                    condvec = self._data_sampler.sample_original_condvec(self._batch_size)

                if condvec is not None:
                    c1 = torch.from_numpy(condvec).to(self._device)
                    fakez = torch.cat([fakez, c1], dim=1)

                fake = self._generator(fakez)
                fakeact = self._apply_activate(fake)
                data.append(fakeact.detach().cpu().numpy())

        data = np.concatenate(data, axis=0)
        data = data[:n]

        return self._transformer.inverse_transform(data)

    def set_device(self, device):
        """Set the `device` to be used ('GPU' or 'CPU').

        Args:
            device:
                Target device, presumably a ``torch.device`` or a device string such as
                ``'cpu'`` / ``'cuda'``. It is stored as given; if ``fit`` has already
                built the generator, the generator is moved to it.
        """
        self._device = device
        if self._generator is not None:
            self._generator.to(self._device)