attention_module.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


def se_block(residual, name, ratio=8):
    """Contains the implementation of the Squeeze-and-Excitation (SE) block.

    As described in https://arxiv.org/abs/1709.01507.
    """
    kernel_initializer = tf.contrib.layers.variance_scaling_initializer()
    bias_initializer = tf.constant_initializer(value=0.0)
    with tf.variable_scope(name):
        channel = residual.get_shape()[-1]
        # Squeeze: global average pooling over the spatial dimensions.
        squeeze = tf.reduce_mean(residual, axis=[1, 2], keepdims=True)
        assert squeeze.get_shape()[1:] == (1, 1, channel)
        # Excitation, step 1: bottleneck FC layer shrinks the channel
        # dimension by `ratio`.
        excitation = tf.layers.dense(inputs=squeeze,
                                     units=channel // ratio,
                                     activation=tf.nn.relu,
                                     kernel_initializer=kernel_initializer,
                                     bias_initializer=bias_initializer,
                                     name='bottleneck_fc')
        assert excitation.get_shape()[1:] == (1, 1, channel // ratio)
        # Excitation, step 2: restore the channel dimension; the sigmoid
        # yields per-channel gates in (0, 1).
        excitation = tf.layers.dense(inputs=excitation,
                                     units=channel,
                                     activation=tf.nn.sigmoid,
                                     kernel_initializer=kernel_initializer,
                                     bias_initializer=bias_initializer,
                                     name='recover_fc')
        assert excitation.get_shape()[1:] == (1, 1, channel)
        # Scale: reweight the input feature map channel-wise.
        scale = residual * excitation
    return scale
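
# A minimal usage sketch (illustrative, not part of the original file): under
# TensorFlow 1.x, se_block takes an NHWC feature map and returns a tensor of
# the same shape. The placeholder shape and scope name here are assumptions.
#
#   feat = tf.placeholder(tf.float32, [None, 32, 32, 64])
#   feat = se_block(feat, name='se_demo', ratio=8)  # -> [None, 32, 32, 64]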

def cbam_block(input_feature, name, ratio=8):
    """Contains the implementation of the Convolutional Block Attention
    Module (CBAM) block.

    As described in https://arxiv.org/abs/1807.06521.
    """
    with tf.variable_scope(name):
        # Apply channel attention first, then spatial attention, matching
        # the sequential arrangement used in the paper.
        attention_feature = channel_attention(input_feature, 'ch_at', ratio)
        attention_feature = spatial_attention(attention_feature, 'sp_at')
    return attention_feature

def channel_attention(input_feature, name, ratio=8):
    """Channel attention sub-module of CBAM: a shared two-layer MLP applied
    to average-pooled and max-pooled descriptors of the input."""
    kernel_initializer = tf.contrib.layers.variance_scaling_initializer()
    bias_initializer = tf.constant_initializer(value=0.0)
    with tf.variable_scope(name):
        channel = input_feature.get_shape()[-1]
        # Average-pooled descriptor, passed through the MLP.
        avg_pool = tf.reduce_mean(input_feature, axis=[1, 2], keepdims=True)
        assert avg_pool.get_shape()[1:] == (1, 1, channel)
        avg_pool = tf.layers.dense(inputs=avg_pool,
                                   units=channel // ratio,
                                   activation=tf.nn.relu,
                                   kernel_initializer=kernel_initializer,
                                   bias_initializer=bias_initializer,
                                   name='mlp_0',
                                   reuse=None)
        assert avg_pool.get_shape()[1:] == (1, 1, channel // ratio)
        avg_pool = tf.layers.dense(inputs=avg_pool,
                                   units=channel,
                                   kernel_initializer=kernel_initializer,
                                   bias_initializer=bias_initializer,
                                   name='mlp_1',
                                   reuse=None)
        assert avg_pool.get_shape()[1:] == (1, 1, channel)
        # Max-pooled descriptor, passed through the same MLP: reuse=True
        # shares the 'mlp_0' / 'mlp_1' weights with the average branch.
        max_pool = tf.reduce_max(input_feature, axis=[1, 2], keepdims=True)
        assert max_pool.get_shape()[1:] == (1, 1, channel)
        max_pool = tf.layers.dense(inputs=max_pool,
                                   units=channel // ratio,
                                   activation=tf.nn.relu,
                                   name='mlp_0',
                                   reuse=True)
        assert max_pool.get_shape()[1:] == (1, 1, channel // ratio)
        max_pool = tf.layers.dense(inputs=max_pool,
                                   units=channel,
                                   name='mlp_1',
                                   reuse=True)
        assert max_pool.get_shape()[1:] == (1, 1, channel)
        # Sum the two branches and squash into (0, 1) channel gates.
        scale = tf.sigmoid(avg_pool + max_pool, 'sigmoid')
    return input_feature * scale
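
# For reference, the gate computed above matches eq. (2) of the CBAM paper:
#   M_c(F) = sigmoid(MLP(AvgPool(F)) + MLP(MaxPool(F)))
# and the spatial gate below matches eq. (3):
#   M_s(F) = sigmoid(f7x7([AvgPool(F); MaxPool(F)]))
# where f7x7 is a 7x7 convolution over the two channel-pooled maps.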

def spatial_attention(input_feature, name):
    """Spatial attention sub-module of CBAM: channel-wise average and max
    pooling, concatenated and convolved into a one-channel spatial gate."""
    kernel_size = 7
    kernel_initializer = tf.contrib.layers.variance_scaling_initializer()
    with tf.variable_scope(name):
        # Pool across the channel axis to get two H x W x 1 maps.
        avg_pool = tf.reduce_mean(input_feature, axis=[3], keepdims=True)
        assert avg_pool.get_shape()[-1] == 1
        max_pool = tf.reduce_max(input_feature, axis=[3], keepdims=True)
        assert max_pool.get_shape()[-1] == 1
        concat = tf.concat([avg_pool, max_pool], 3)
        assert concat.get_shape()[-1] == 2
        # A 7x7 convolution collapses the two pooled maps into a single
        # spatial attention map.
        concat = tf.layers.conv2d(concat,
                                  filters=1,
                                  kernel_size=[kernel_size, kernel_size],
                                  strides=[1, 1],
                                  padding="same",
                                  activation=None,
                                  kernel_initializer=kernel_initializer,
                                  use_bias=False,
                                  name='conv')
        assert concat.get_shape()[-1] == 1
        concat = tf.sigmoid(concat, 'sigmoid')
    return input_feature * concat
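

if __name__ == '__main__':
    # Minimal smoke test (an illustrative addition, assuming TensorFlow 1.x
    # with tf.contrib available; shapes and scope names are arbitrary): feed
    # a random NHWC tensor through both blocks and confirm that the outputs
    # keep the input shape.
    import numpy as np

    inputs = tf.placeholder(tf.float32, [None, 16, 16, 32])
    se_out = se_block(inputs, name='se_test', ratio=8)
    cbam_out = cbam_block(inputs, name='cbam_test', ratio=8)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        x = np.random.rand(2, 16, 16, 32).astype(np.float32)
        se_val, cbam_val = sess.run([se_out, cbam_out],
                                    feed_dict={inputs: x})
        print('se_block output shape:', se_val.shape)      # (2, 16, 16, 32)
        print('cbam_block output shape:', cbam_val.shape)  # (2, 16, 16, 32)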