exp_replay.py
import torch
import numpy as np
from importlib import import_module
from .default import NormalNN
from .regularization import SI, L2, EWC, MAS
from dataloaders.wrapper import Storage
class Naive_Rehearsal(NormalNN):
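    """Naive rehearsal baseline.

    Keeps a fixed-size episodic memory of samples from previous tasks and
    mixes it into every new task's training set (replicated so old and new
    data appear in roughly a 1:1 ratio), then trains with the standard
    NormalNN procedure. The memory budget (self.memory_size) is split evenly
    across the tasks seen so far.
    """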
def __init__(self, agent_config):
super(Naive_Rehearsal, self).__init__(agent_config)
self.task_count = 0
self.memory_size = 1000
self.task_memory = {}
self.skip_memory_concatenation = False
def learn_batch(self, train_loader, val_loader=None):
        # 1. Combine the current training set with the replay memory
if self.skip_memory_concatenation:
new_train_loader = train_loader
else: # default
dataset_list = []
for storage in self.task_memory.values():
dataset_list.append(storage)
            dataset_list *= max(len(train_loader.dataset)//self.memory_size,1)  # Replicate the old data so old:new is roughly 1:1
dataset_list.append(train_loader.dataset)
dataset = torch.utils.data.ConcatDataset(dataset_list)
new_train_loader = torch.utils.data.DataLoader(dataset,
batch_size=train_loader.batch_size,
shuffle=True,
num_workers=train_loader.num_workers)
        # 2. Update the model as usual
super(Naive_Rehearsal, self).learn_batch(new_train_loader, val_loader)
        # 3. Randomly decide which images stay in the memory
self.task_count += 1
        # (a) Decide how many samples to keep per task
num_sample_per_task = self.memory_size // self.task_count
num_sample_per_task = min(len(train_loader.dataset),num_sample_per_task)
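        # e.g. with memory_size=1000 this is 1000 exemplars after task 1,
        # 500 per task after task 2, and 200 per task after task 5
        # (capped by the size of the current training set)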
        # (b) Shrink the existing exemplar sets to make room for the new task
for storage in self.task_memory.values():
storage.reduce(num_sample_per_task)
        # (c) Randomly choose samples from the new task and add them to the memory
randind = torch.randperm(len(train_loader.dataset))[:num_sample_per_task] # randomly sample some data
self.task_memory[self.task_count] = Storage(train_loader.dataset, randind)
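
# The combined agents below rely on cooperative inheritance: learn_batch comes
# from Naive_Rehearsal, and its super() call dispatches through the MRO to the
# regularization agent (SI/L2/EWC/MAS) rather than straight to NormalNN, so
# rehearsal and the regularization penalty are applied together.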
class Naive_Rehearsal_SI(Naive_Rehearsal, SI):
def __init__(self, agent_config):
super(Naive_Rehearsal_SI, self).__init__(agent_config)
class Naive_Rehearsal_L2(Naive_Rehearsal, L2):
def __init__(self, agent_config):
super(Naive_Rehearsal_L2, self).__init__(agent_config)
class Naive_Rehearsal_EWC(Naive_Rehearsal, EWC):
def __init__(self, agent_config):
super(Naive_Rehearsal_EWC, self).__init__(agent_config)
self.online_reg = True # Online EWC
class Naive_Rehearsal_MAS(Naive_Rehearsal, MAS):
def __init__(self, agent_config):
super(Naive_Rehearsal_MAS, self).__init__(agent_config)
class GEM(Naive_Rehearsal):
"""
@inproceedings{GradientEpisodicMemory,
title={Gradient Episodic Memory for Continual Learning},
author={Lopez-Paz, David and Ranzato, Marc'Aurelio},
booktitle={NIPS},
year={2017},
url={https://arxiv.org/abs/1706.08840}
}
"""
def __init__(self, agent_config):
super(GEM, self).__init__(agent_config)
self.params = {n: p for n, p in self.model.named_parameters() if p.requires_grad} # For convenience
self.task_grads = {}
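        # Import quadprog lazily so the dependency is only required when a
        # GEM agent is actually constructed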
self.quadprog = import_module('quadprog')
self.task_mem_cache = {}
def grad_to_vector(self):
vec = []
for n,p in self.params.items():
if p.grad is not None:
vec.append(p.grad.view(-1))
else:
                # Part of the network might have no grad; fill zeros for those terms
vec.append(p.data.clone().fill_(0).view(-1))
return torch.cat(vec)
def vector_to_grad(self, vec):
        # Overwrite the current param.grad by slicing values out of vec (the flattened grad)
pointer = 0
for n, p in self.params.items():
# The length of the parameter
num_param = p.numel()
if p.grad is not None:
# Slice the vector, reshape it, and replace the old data of the grad
p.grad.copy_(vec[pointer:pointer + num_param].view_as(p))
            # Part of the network might have no grad; skip those terms
# Increment the pointer
pointer += num_param
def project2cone2(self, gradient, memories):
"""
Solves the GEM dual QP described in the paper given a proposed
gradient "gradient", and a memory of task gradients "memories".
Overwrites "gradient" with the final projected update.
input: gradient, p-vector
        input: memories, (t x p)-matrix (one row per stored task gradient)
output: x, p-vector
Modified from: https://github.com/facebookresearch/GradientEpisodicMemory/blob/master/model/gem.py#L70
"""
margin = self.config['reg_coef']
memories_np = memories.cpu().contiguous().double().numpy()
gradient_np = gradient.cpu().contiguous().view(-1).double().numpy()
t = memories_np.shape[0]
#print(memories_np.shape, gradient_np.shape)
P = np.dot(memories_np, memories_np.transpose())
P = 0.5 * (P + P.transpose())
q = np.dot(memories_np, gradient_np) * -1
G = np.eye(t)
P = P + G * 0.001
h = np.zeros(t) + margin
v = self.quadprog.solve_qp(P, q, G, h)[0]
x = np.dot(v, memories_np) + gradient_np
new_grad = torch.Tensor(x).view(-1)
if self.gpu:
new_grad = new_grad.cuda()
return new_grad
def learn_batch(self, train_loader, val_loader=None):
# Update model as normal
super(GEM, self).learn_batch(train_loader, val_loader)
        # Cache each task's memory as a single batch so update_model can
        # compute the reference gradients without re-iterating a DataLoader
for t, mem in self.task_memory.items():
# Concatenate all data in each task
mem_loader = torch.utils.data.DataLoader(mem,
batch_size=len(mem),
shuffle=False,
num_workers=2)
assert len(mem_loader)==1,'The length of mem_loader should be 1'
for i, (mem_input, mem_target, mem_task) in enumerate(mem_loader):
if self.gpu:
mem_input = mem_input.cuda()
mem_target = mem_target.cuda()
self.task_mem_cache[t] = {'data':mem_input,'target':mem_target,'task':mem_task}
def update_model(self, inputs, targets, tasks):
# compute gradient on previous tasks
if self.task_count > 0:
for t,mem in self.task_memory.items():
self.zero_grad()
# feed the data from memory and collect the gradients
mem_out = self.forward(self.task_mem_cache[t]['data'])
mem_loss = self.criterion(mem_out, self.task_mem_cache[t]['target'], self.task_mem_cache[t]['task'])
mem_loss.backward()
# Store the grads
self.task_grads[t] = self.grad_to_vector()
# now compute the grad on the current minibatch
out = self.forward(inputs)
loss = self.criterion(out, targets, tasks)
self.optimizer.zero_grad()
loss.backward()
# check if gradient violates constraints
if self.task_count > 0:
current_grad_vec = self.grad_to_vector()
mem_grad_vec = torch.stack(list(self.task_grads.values()))
dotp = current_grad_vec * mem_grad_vec
dotp = dotp.sum(dim=1)
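            # A negative dot product means this update would, to first order,
            # increase the loss on that past task, which is the situation GEM
            # is designed to prevent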
if (dotp < 0).sum() != 0:
new_grad = self.project2cone2(current_grad_vec, mem_grad_vec)
# copy gradients back
self.vector_to_grad(new_grad)
self.optimizer.step()
return loss.detach(), out
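
# Minimal usage sketch (illustrative; the exact agent_config keys are defined
# by NormalNN in .default and are assumptions here, except 'reg_coef', which
# GEM reads as the QP margin):
#
#   agent_config = {'reg_coef': 0.5}               # plus NormalNN's own keys
#   agent = GEM(agent_config)
#   for train_loader, val_loader in task_loaders:  # one loader pair per task
#       agent.learn_batch(train_loader, val_loader)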