# train_celeba_01_gan.py
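"""Train a DCGAN on the CelebA dataset.

Loads CelebA images from an ImageFolder, trains the Generator and
Discriminator defined in va.celeba, logs losses and sample image grids to
Weights & Biases, and saves a generator checkpoint after every epoch.
"""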
import random
from pathlib import Path
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import wandb
from va.celeba.discriminator import Discriminator
from va.celeba.generator import Generator
from va.celeba.utils import weights_init


def load_data(dataroot, image_size, batch_size, workers):
    # We can use an ImageFolder dataset with the directory layout we already have.
    # Create the dataset
    dataset = dset.ImageFolder(
        root=dataroot,
        transform=transforms.Compose(
            [
                transforms.Resize(image_size),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        ),
    )
    # Create the dataloader
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size,
        shuffle=True, num_workers=workers
    )
    return dataloader


def main():
    wandb.init()
    data_dir = "data/celeba/"
    exp = Path("exp/dcgan/")
    exp.mkdir(parents=True, exist_ok=True)

    # Set random seed for reproducibility
    manualSeed = 999
    random.seed(manualSeed)
    torch.manual_seed(manualSeed)

    # Hyper-parameters
    image_size = 64
    workers = 16  # Number of workers for the dataloader
    batch_size = 128  # Batch size during training
    nc = 3  # Number of channels in the training images (3 for color images)
    nz = 100  # Size of the z latent vector (i.e. size of the generator input)
    ngf = 64  # Size of feature maps in the generator
    ndf = 64  # Size of feature maps in the discriminator
    num_epochs = 100  # Number of training epochs
    lr = 0.0002  # Learning rate for the optimizers
    beta1 = 0.5  # Beta1 hyperparameter for the Adam optimizers
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Using {} device".format(device))

    wandb.config.update({
        "learning_rate": lr,
        "epochs": num_epochs,
        "batch_size": batch_size,
    })

    dataloader = load_data(data_dir, image_size, batch_size, workers)
    # Create the Generator
    netG = Generator(nz, ngf, nc).to(device)
    wandb.watch(netG)
    # Apply the weights_init function to randomly initialize all weights with mean=0, stdev=0.02.
    netG.apply(weights_init)
    print(netG)

    # Create the Discriminator
    netD = Discriminator(ndf, nc).to(device)
    wandb.watch(netD)
    # Apply the weights_init function to randomly initialize all weights with mean=0, stdev=0.02.
    netD.apply(weights_init)
    print(netD)
    # Initialize BCELoss function
    criterion = nn.BCELoss()

    # Create a batch of latent vectors that we will use to visualize
    # the progression of the generator
    fixed_noise = torch.randn(64, nz, 1, 1, device=device)
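
    # Establish the convention for real and fake labels during training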
    real_label = 1.
    fake_label = 0.
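
    # Set up Adam optimizers for both G and D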
    optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

    img_list = []
    iters = 0
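
    # Training loop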
    for epoch in range(num_epochs):
        begin = time.time()
        netG.train()
        netD.train()
        # For each batch in the dataloader
        for i, data in enumerate(dataloader, 0):
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ############################
            ## Train with all-real batch
            netD.zero_grad()
            # Format batch
            real_cpu = data[0].to(device)
            b_size = real_cpu.size(0)
            label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
            # Forward pass real batch through D
            output = netD(real_cpu).view(-1)
            # Calculate loss on the all-real batch
            errD_real = criterion(output, label)
            # Calculate gradients for D in backward pass
            errD_real.backward()
            D_x = output.mean().item()

            ## Train with all-fake batch
            # Generate batch of latent vectors
            noise = torch.randn(b_size, nz, 1, 1, device=device)
            # Generate fake image batch with G
            fake = netG(noise)
            label.fill_(fake_label)
            # Classify the all-fake batch with D (detach so no gradients flow back into G here)
            output = netD(fake.detach()).view(-1)
            # Calculate D's loss on the all-fake batch
            errD_fake = criterion(output, label)
            # Calculate the gradients for this batch, accumulated (summed) with previous gradients
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            # Compute error of D as sum over the fake and the real batches
            errD = errD_real + errD_fake
            # Update D
            optimizerD.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ############################
            netG.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            # Since we just updated D, perform another forward pass of the all-fake batch through D
            output = netD(fake).view(-1)
            # Calculate G's loss based on this output
            errG = criterion(output, label)
            # Calculate gradients for G
            errG.backward()
            D_G_z2 = output.mean().item()
            # Update G
            optimizerG.step()

            stats = {"lossG": errG.item(), "lossD": errD.item(), "D_x": D_x, "D_G_z1": D_G_z1, "D_G_z2": D_G_z2}
            if i % 50 == 0:
                print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                      % (epoch, num_epochs, i, len(dataloader),
                         errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
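
            # Check how the generator is doing by saving G's output on the fixed noise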
            if (iters % 1000 == 0) or ((epoch == num_epochs - 1) and (i == len(dataloader) - 1)):
                with torch.no_grad():
                    fake = netG(fixed_noise).detach().cpu()
                img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
                stats["generated"] = wandb.Image(img_list[-1])
            wandb.log(stats)
            iters += 1
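
        # Checkpoint the generator weights at the end of each epoch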
        filename = exp / f"generator-{epoch}.pt"
        torch.save(netG.state_dict(), str(filename))
        end = time.time()
        print("saved generator to", filename)
        print("epoch took", end - begin, "seconds")


if __name__ == "__main__":
    main()