# prune_network.py
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

from networks import network_dense
from configs import ConfigNetworkDense as config_dense
from configs import ConfigNetworkDensePruned as config_pruned
from utils import plot_utils
from utils import pruning_utils

# load MNIST and split it into train/validation/test providers
mnist = input_data.read_data_sets("MNIST_data/")
train_data_provider = mnist.train
validation_data_provider = mnist.validation
test_data_provider = mnist.test
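# Note: config_dense describes the already-trained dense model whose
# checkpoint is loaded below, while config_pruned holds the architecture
# and training settings used for pruning and fine-tuning.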
# First, create the classifier.
classifier = network_dense.FullyConnectedClassifier(
    input_size=config_pruned.input_size,
    n_classes=config_pruned.n_classes,
    layer_sizes=config_pruned.layer_sizes,
    model_path=config_pruned.model_path,
    dropout=config_pruned.dropout,
    weight_decay=config_pruned.weight_decay,
    activation_fn=config_pruned.activation_fn,
    pruning_threshold=config_pruned.pruning_threshold)
# collect the tf weight variables and the corresponding optimizer variables
with classifier.graph.as_default():
    weight_matrices_tf = classifier.weight_matrices
    optimizer_matrices_tf = [v
                             for v in tf.global_variables()
                             for w in weight_matrices_tf
                             if w.name[:-2] in v.name
                             and 'optimizer' in v.name]
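# The name of a tf.Variable ends with an output index (e.g.
# 'dense_1/weights:0'), so w.name[:-2] strips the ':0' suffix before the
# substring match. Optimizer slot variables embed the name of the weight
# they track (for Adam, something like '.../dense_1/weights/Adam:0' and
# '.../Adam_1:0' for the two moment estimates); the exact names depend on
# the optimizer and the variable scopes used in network_dense.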
# load the previously trained dense model
# and get the values of the weights and the optimizer variables
weights, optimizer_weights = (classifier
                              .load_model(config_dense.model_path)
                              .sess.run([weight_matrices_tf,
                                         optimizer_matrices_tf]))
# plot the weight distribution before pruning
plot_utils.plot_histogram(weights,
                          'weights_distribution_before_pruning',
                          include_zeros=False)
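# The loop below relies on pruning_utils.mask_for_big_values. As a rough
# reference, a minimal magnitude-based mask can be sketched like this
# (an illustration only, not necessarily the repo's actual implementation):
def _reference_mask_for_big_values(weights, pruning_threshold):
    # keep entries whose absolute value exceeds the threshold
    return (np.abs(weights) > pruning_threshold).astype(weights.dtype)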
# For each pair (weight matrix, optimizer matrix), compute a binary
# mask that zeroes out small values. Then, based on this mask, update
# the values of the weight matrix and of the optimizer matrix.
for (weight_matrix,
     optimizer_matrix,
     tf_weight_matrix,
     tf_optimizer_matrix) in zip(weights,
                                 optimizer_weights,
                                 weight_matrices_tf,
                                 optimizer_matrices_tf):

    mask = pruning_utils.mask_for_big_values(weight_matrix,
                                             config_pruned.pruning_threshold)

    with classifier.graph.as_default():
        # update the weights ...
        classifier.sess.run(tf_weight_matrix.assign(weight_matrix * mask))
        # ... and the corresponding optimizer matrix
        classifier.sess.run(tf_optimizer_matrix.assign(optimizer_matrix * mask))
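# Masking the optimizer variables as well matters: stale momentum or
# second-moment estimates for a pruned connection can otherwise push
# its weight away from zero again on the very next update.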
# Now, let's look at the weight distribution (zero values are excluded)
weights = classifier.sess.run(weight_matrices_tf)
plot_utils.plot_histogram(weights,
                          'weights_distribution_after_pruning',
                          include_zeros=False)
accuracy, loss = classifier.evaluate(data_provider=test_data_provider,
                                     batch_size=config_pruned.batch_size)
print('Accuracy on test before fine-tuning: {accuracy}, loss on test: {loss}'.format(
    accuracy=accuracy, loss=loss))
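# Pruning alone typically costs some accuracy; the fine-tuning step
# below is meant to recover it.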
# fine-tune the classifier
classifier.fit(n_epochs=config_pruned.n_epochs,
               batch_size=config_pruned.batch_size,
               learning_rate_schedule=config_pruned.learning_rate_schedule,
               train_data_provider=train_data_provider,
               validation_data_provider=validation_data_provider,
               test_data_provider=test_data_provider)
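# Note: the classifier was constructed with pruning_threshold set, which
# presumably keeps the pruned connections at zero during fine-tuning (an
# assumption about network_dense.FullyConnectedClassifier; otherwise the
# gradient updates would gradually re-densify the masked weights).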
# plot the weight distribution again to see the difference
weights = classifier.sess.run(weight_matrices_tf)
plot_utils.plot_histogram(weights,
                          'weights_distribution_after_fine_tuning',
                          include_zeros=False)