Skip to content

Commit e49361b

Browse files
WindQAQ authored and facaiy committed
Speed up mean_filter2d with depthwise_conv2d (#235)
* speed up mean_filter2d with depthwise_conv2d
* cast the output back to the original dtype
* refactor test cases
* avoid loss of precision
* add test case with channels of None
* add doc of _tile_image
* use ones instead of random data
* add test case with 4x4 filter
* add doc related to padding
1 parent be2a68d commit e49361b

File tree

3 files changed

+244
-128
lines changed

3 files changed

+244
-128
lines changed

tensorflow_addons/image/filters.py

+93-64
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
from __future__ import print_function
1919

2020
import tensorflow as tf
21+
from tensorflow_addons.utils import keras_utils
2122

2223

2324
@tf.function
@@ -34,82 +35,110 @@ def func2():
3435
return tf.cond(tf.math.greater(ma, one), func2, func1)
3536

3637

37-
@tf.function
38-
def mean_filter2d(image, filter_shape=(3, 3), name=None):
39-
"""This method performs Mean Filtering on image. Filter shape can be user
40-
given.
38+
def _pad(image, filter_shape, mode="CONSTANT", constant_values=0):
39+
"""Explicitly pad a 4-D image.
40+
41+
Equivalent to the implicit padding method offered in `tf.nn.conv2d` and
42+
`tf.nn.depthwise_conv2d`, but supports non-zero, reflect and symmetric
43+
padding mode. For the even-sized filter, it pads one more value to the
44+
right or the bottom side.
4145
42-
This method takes both kind of images where pixel values lie between 0 to
43-
255 and where it lies between 0.0 and 1.0
4446
Args:
45-
image: A 3D `Tensor` of type `float32` or 'int32' or 'float64' or
46-
'int64 and of shape`[rows, columns, channels]`
47+
image: A 4-D `Tensor` of shape `[batch_size, height, width, channels]`.
48+
filter_shape: A `tuple`/`list` of 2 integers, specifying the height
49+
and width of the 2-D filter.
50+
mode: A `string`, one of "REFLECT", "CONSTANT", or "SYMMETRIC".
51+
The type of padding algorithm to use, which is compatible with
52+
`mode` argument in `tf.pad`. For more details, please refer to
53+
https://www.tensorflow.org/api_docs/python/tf/pad.
54+
constant_values: A `scalar`, the pad value to use in "CONSTANT"
55+
padding mode.
56+
"""
57+
assert mode in ["CONSTANT", "REFLECT", "SYMMETRIC"]
58+
filter_height, filter_width = filter_shape
59+
pad_top = (filter_height - 1) // 2
60+
pad_bottom = filter_height - 1 - pad_top
61+
pad_left = (filter_width - 1) // 2
62+
pad_right = filter_width - 1 - pad_left
63+
paddings = [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]
64+
return tf.pad(image, paddings, mode=mode, constant_values=constant_values)
4765

48-
filter_shape: Optional Argument. A tuple of 2 integers (R,C).
49-
R is the first value is the number of rows in the filter and
50-
C is the second value in the filter is the number of columns
51-
in the filter. This creates a filter of shape (R,C) or RxC
52-
filter. Default value = (3,3)
5366

54-
Returns:
55-
A 3D mean filtered image tensor of shape [rows,columns,channels] and
56-
type 'int32'. Pixel value of returned tensor ranges between 0 to 255
57-
"""
67+
@tf.function
68+
def mean_filter2d(image,
69+
filter_shape=(3, 3),
70+
padding="REFLECT",
71+
constant_values=0,
72+
name=None):
73+
"""Perform mean filtering on image(s).
5874
75+
Args:
76+
image: Either a 3-D `Tensor` of shape `[height, width, channels]`,
77+
or a 4-D `Tensor` of shape `[batch_size, height, width, channels]`.
78+
filter_shape: An `integer` or `tuple`/`list` of 2 integers, specifying
79+
the height and width of the 2-D mean filter. Can be a single integer
80+
to specify the same value for all spatial dimensions.
81+
padding: A `string`, one of "REFLECT", "CONSTANT", or "SYMMETRIC".
82+
The type of padding algorithm to use, which is compatible with
83+
`mode` argument in `tf.pad`. For more details, please refer to
84+
https://www.tensorflow.org/api_docs/python/tf/pad.
85+
constant_values: A `scalar`, the pad value to use in "CONSTANT"
86+
padding mode.
87+
name: A name for this operation (optional).
88+
Returns:
89+
3-D or 4-D `Tensor` of the same dtype as input.
90+
Raises:
91+
ValueError: If `image` is not 3 or 4-dimensional,
92+
if `padding` is other than "REFLECT", "CONSTANT" or "SYMMETRIC",
93+
or if `filter_shape` is invalid.
94+
"""
5995
with tf.name_scope(name or "mean_filter2d"):
60-
if not isinstance(filter_shape, tuple):
61-
raise TypeError('Filter shape must be a tuple')
62-
if len(filter_shape) != 2:
63-
raise ValueError('Filter shape must be a tuple of 2 integers. '
64-
'Got %s values in tuple' % len(filter_shape))
65-
filter_shapex = filter_shape[0]
66-
filter_shapey = filter_shape[1]
67-
if not isinstance(filter_shapex, int) or not isinstance(
68-
filter_shapey, int):
69-
raise TypeError('Size of the filter must be Integers')
70-
(row, col, ch) = (image.shape[0], image.shape[1], image.shape[2])
71-
if row != None and col != None and ch != None:
72-
(row, col, ch) = (int(row), int(col), int(ch))
73-
else:
74-
raise TypeError(
75-
'All the Dimensions of the input image tensor must be \
76-
Integers.')
77-
if row < filter_shapex or col < filter_shapey:
96+
image = tf.convert_to_tensor(image, name="image")
97+
98+
rank = image.shape.rank
99+
if rank != 3 and rank != 4:
100+
raise ValueError("image should be either 3 or 4-dimensional.")
101+
102+
if padding not in ["REFLECT", "CONSTANT", "SYMMETRIC"]:
78103
raise ValueError(
79-
'Number of Pixels in each dimension of the image should be \
80-
more than the filter size. Got filter_shape (%sx' %
81-
filter_shape[0] + '%s).' % filter_shape[1] +
82-
' Image Shape (%s)' % image.shape)
83-
if filter_shapex % 2 == 0 or filter_shapey % 2 == 0:
84-
raise ValueError('Filter size should be odd. Got filter_shape (%sx'
85-
% filter_shape[0] + '%s)' % filter_shape[1])
86-
image = tf.cast(image, tf.float32)
87-
tf_i = tf.reshape(image, [row * col * ch])
88-
ma = tf.math.reduce_max(tf_i)
89-
image = _normalize(image, ma)
104+
"padding should be one of \"REFLECT\", \"CONSTANT\", or "
105+
"\"SYMMETRIC\".")
90106

91-
# k and l is the Zero-padding size
107+
filter_shape = keras_utils.conv_utils.normalize_tuple(
108+
filter_shape, 2, "filter_shape")
92109

93-
listi = []
94-
for a in range(ch):
95-
img = image[:, :, a:a + 1]
96-
img = tf.reshape(img, [1, row, col, 1])
97-
slic = tf.image.extract_patches(
98-
img, [1, filter_shapex, filter_shapey, 1], [1, 1, 1, 1],
99-
[1, 1, 1, 1],
100-
padding='SAME')
101-
li = tf.reduce_mean(slic, axis=-1)
102-
li = tf.reshape(li, [row, col, 1])
103-
listi.append(li)
104-
y = tf.concat(listi[0], 2)
110+
# Expand to a 4-D tensor
111+
if rank == 3:
112+
image = tf.expand_dims(image, axis=0)
105113

106-
for i in range(len(listi) - 1):
107-
y = tf.concat([y, listi[i + 1]], 2)
114+
# Keep the precision if it's float;
115+
# otherwise, convert to float32 for computing.
116+
orig_dtype = image.dtype
117+
if not image.dtype.is_floating:
118+
image = tf.dtypes.cast(image, tf.dtypes.float32)
108119

109-
y *= 255
110-
y = tf.cast(y, tf.int32)
120+
# Explicitly pad the image
121+
image = _pad(
122+
image, filter_shape, mode=padding, constant_values=constant_values)
111123

112-
return y
124+
# Filter of shape (filter_height, filter_width, in_channels, 1)
125+
# has the value of 1 for each element.
126+
area = tf.constant(
127+
filter_shape[0] * filter_shape[1], dtype=image.dtype)
128+
filter_shape = filter_shape + (tf.shape(image)[-1], 1)
129+
kernel = tf.ones(shape=filter_shape, dtype=image.dtype)
130+
131+
output = tf.nn.depthwise_conv2d(
132+
image, kernel, strides=(1, 1, 1, 1), padding="VALID")
133+
134+
output /= area
135+
136+
# Squeeze out the first axis to make sure
137+
# output has the same dimension with image.
138+
if rank == 3:
139+
output = tf.squeeze(output, axis=0)
140+
141+
return tf.dtypes.cast(output, orig_dtype)
113142

114143

115144
@tf.function

0 commit comments

Comments
 (0)