
Commit 9b1cf95

Expunge include_rescaling from backbones (#1859)
Since our models include built-in preprocessing, it is much clearer for this rescaling to happen in the preprocessing layers.
1 parent d39db7b commit 9b1cf95

32 files changed: +154 -157 lines
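
The practical effect: rescaling moves from a constructor flag on each backbone to an argument on the image converter. A minimal sketch of the new call site (hypothetical values; assumes the converter is exported as `keras_hub.layers.ResizingImageConverter`):

```python
import keras_hub

# Before: SomeBackbone(..., include_rescaling=True) applied Rescaling(1 / 255.0)
# inside the backbone itself (`SomeBackbone` is a placeholder, not a real class).
# After: ask the preprocessing layer for the rescale instead.
converter = keras_hub.layers.ResizingImageConverter(
    height=224, width=224, scale=1 / 255.0
)
```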

keras_hub/src/layers/preprocessing/resizing_image_converter.py

+56 -6

@@ -12,9 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import keras
+from keras import ops

 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
+from keras_hub.src.utils.keras_utils import standardize_data_format
 from keras_hub.src.utils.tensor_utils import preprocessing_function


@@ -23,13 +25,23 @@ class ResizingImageConverter(ImageConverter):
     """An `ImageConverter` that simply resizes the input image.

     The `ResizingImageConverter` is a subclass of `ImageConverter` for models
-    that simply need to resize image tensors before using them for modeling.
-    The layer will take as input a raw image tensor (batched or unbatched) in the
-    channels last or channels first format, and output a resize tensor.
+    that need to resize (and optionally rescale) image tensors before using them
+    for modeling. The layer will take as input a raw image tensor (batched or
+    unbatched) in the channels last or channels first format, and output a
+    resized tensor.

     Args:
-        height: Integer, the height of the output shape.
-        width: Integer, the width of the output shape.
+        height: int, the height of the output shape.
+        width: int, the width of the output shape.
+        scale: float or `None`. If set, the image will be rescaled with a
+            `keras.layers.Rescaling` layer, multiplying the image by this
+            scale.
+        mean: tuple of floats per channel or `None`. If set, the image will be
+            normalized per channel by subtracting mean.
+            If set, also set `variance`.
+        variance: tuple of floats per channel or `None`. If set, the image will
+            be normalized per channel by dividing by `sqrt(variance)`.
+            If set, also set `mean`.
         crop_to_aspect_ratio: If `True`, resize the images without aspect
             ratio distortion. When the original aspect ratio differs
             from the target aspect ratio, the output image will be
@@ -64,6 +76,9 @@ def __init__(
         self,
         height,
         width,
+        scale=None,
+        mean=None,
+        variance=None,
         crop_to_aspect_ratio=True,
         interpolation="bilinear",
         data_format=None,
@@ -78,15 +93,47 @@ def __init__(
             crop_to_aspect_ratio=crop_to_aspect_ratio,
             interpolation=interpolation,
             data_format=data_format,
+            dtype=self.dtype_policy,
+            name="resizing",
         )
+        if scale is not None:
+            self.rescaling = keras.layers.Rescaling(
+                scale=scale,
+                dtype=self.dtype_policy,
+                name="rescaling",
+            )
+        else:
+            self.rescaling = None
+        if (mean is not None) != (variance is not None):
+            raise ValueError(
+                "Both `mean` and `variance` should be set or `None`. Received "
+                f"`mean={mean}`, `variance={variance}`."
+            )
+        self.scale = scale
+        self.mean = mean
+        self.variance = variance
+        self.data_format = standardize_data_format(data_format)

     def image_size(self):
         """Returns the preprocessed size of a single image."""
         return (self.resizing.height, self.resizing.width)

     @preprocessing_function
     def call(self, inputs):
-        return self.resizing(inputs)
+        x = self.resizing(inputs)
+        if self.rescaling:
+            x = self.rescaling(x)
+        if self.mean is not None:
+            # Avoid `layers.Normalization` so this works batched and unbatched.
+            channels_first = self.data_format == "channels_first"
+            if len(ops.shape(inputs)) == 3:
+                broadcast_dims = (1, 2) if channels_first else (0, 1)
+            else:
+                broadcast_dims = (0, 2, 3) if channels_first else (0, 1, 2)
+            mean = ops.expand_dims(ops.array(self.mean), broadcast_dims)
+            std = ops.expand_dims(ops.sqrt(self.variance), broadcast_dims)
+            x = (x - mean) / std
+        return x

     def get_config(self):
         config = super().get_config()
@@ -96,6 +143,9 @@ def get_config(self):
                 "width": self.resizing.width,
                 "interpolation": self.resizing.interpolation,
                 "crop_to_aspect_ratio": self.resizing.crop_to_aspect_ratio,
+                "scale": self.scale,
+                "mean": self.mean,
+                "variance": self.variance,
             }
         )
         return config
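
The new `call` above deliberately avoids `keras.layers.Normalization` so the same code path handles batched and unbatched inputs. A minimal NumPy sketch (not part of the commit) of the equivalent math, which reproduces the constants asserted in the tests below:

```python
import numpy as np

# Per-channel statistics mirroring the tests below.
scale = 1 / 255.0
mean = np.array([0.5, 0.7, 0.3])
variance = np.array([0.25, 0.1, 0.5])

x = np.ones((4, 4, 3)) * 128  # a channels-last image
x = x * scale  # what `keras.layers.Rescaling` does
# Broadcast the per-channel stats over the spatial dims, as in `call`.
x = (x - mean[None, None, :]) / np.sqrt(variance)[None, None, :]
print(x[0, 0])  # ~ [0.003922, -0.626255, 0.285616]
```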

keras_hub/src/layers/preprocessing/resizing_image_converter_test.py

+43 -8

@@ -22,22 +22,57 @@


 class ResizingImageConverterTest(TestCase):
+    def test_resize_simple(self):
+        converter = ResizingImageConverter(height=4, width=4)
+        inputs = np.ones((10, 10, 3))
+        outputs = converter(inputs)
+        self.assertAllClose(outputs, ops.ones((4, 4, 3)))
+
     def test_resize_one(self):
-        converter = ResizingImageConverter(22, 22)
-        test_image = np.random.rand(10, 10, 3) * 255
-        shape = ops.shape(converter(test_image))
-        self.assertEqual(shape, (22, 22, 3))
+        converter = ResizingImageConverter(
+            height=4,
+            width=4,
+            mean=(0.5, 0.7, 0.3),
+            variance=(0.25, 0.1, 0.5),
+            scale=1 / 255.0,
+        )
+        inputs = np.ones((10, 10, 3)) * 128
+        outputs = converter(inputs)
+        self.assertEqual(ops.shape(outputs), (4, 4, 3))
+        self.assertAllClose(outputs[:, :, 0], np.ones((4, 4)) * 0.003922)
+        self.assertAllClose(outputs[:, :, 1], np.ones((4, 4)) * -0.626255)
+        self.assertAllClose(outputs[:, :, 2], np.ones((4, 4)) * 0.285616)

     def test_resize_batch(self):
-        converter = ResizingImageConverter(12, 12)
-        test_batch = np.random.rand(4, 10, 20, 3) * 255
-        shape = ops.shape(converter(test_batch))
-        self.assertEqual(shape, (4, 12, 12, 3))
+        converter = ResizingImageConverter(
+            height=4,
+            width=4,
+            mean=(0.5, 0.7, 0.3),
+            variance=(0.25, 0.1, 0.5),
+            scale=1 / 255.0,
+        )
+        inputs = np.ones((2, 10, 10, 3)) * 128
+        outputs = converter(inputs)
+        self.assertEqual(ops.shape(outputs), (2, 4, 4, 3))
+        self.assertAllClose(outputs[:, :, :, 0], np.ones((2, 4, 4)) * 0.003922)
+        self.assertAllClose(outputs[:, :, :, 1], np.ones((2, 4, 4)) * -0.626255)
+        self.assertAllClose(outputs[:, :, :, 2], np.ones((2, 4, 4)) * 0.285616)
+
+    def test_errors(self):
+        with self.assertRaises(ValueError):
+            ResizingImageConverter(
+                height=4,
+                width=4,
+                mean=(0.5, 0.7, 0.3),
+            )

     def test_config(self):
         converter = ResizingImageConverter(
             width=12,
             height=20,
+            mean=(0.5, 0.7, 0.3),
+            variance=(0.25, 0.1, 0.5),
+            scale=1 / 255.0,
             crop_to_aspect_ratio=False,
             interpolation="nearest",
         )
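
Since `scale`, `mean`, and `variance` now appear in `get_config()`, a configured converter should survive a config round-trip. A small sketch (assuming the in-repo import path above):

```python
from keras_hub.src.layers.preprocessing.resizing_image_converter import (
    ResizingImageConverter,
)

converter = ResizingImageConverter(
    height=20,
    width=12,
    scale=1 / 255.0,
    mean=(0.5, 0.7, 0.3),
    variance=(0.25, 0.1, 0.5),
)
# The new preprocessing arguments are serialized and restored.
restored = ResizingImageConverter.from_config(converter.get_config())
assert restored.scale == converter.scale
assert restored.mean == converter.mean
```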

keras_hub/src/models/csp_darknet/csp_darknet_backbone.py

+1 -11

@@ -31,9 +31,6 @@ class CSPDarkNetBackbone(FeaturePyramidBackbone):
             level in the model.
         stackwise_depth: A list of ints, the depth for each dark level in the
             model.
-        include_rescaling: boolean. If `True`, rescale the input using
-            `Rescaling(1 / 255.0)` layer. If `False`, do nothing. Defaults to
-            `True`.
         block_type: str. One of `"basic_block"` or `"depthwise_block"`.
             Use `"depthwise_block"` for depthwise conv block
             `"basic_block"` for basic conv block.
@@ -55,7 +52,6 @@ class CSPDarkNetBackbone(FeaturePyramidBackbone):
     model = keras_hub.models.CSPDarkNetBackbone(
         stackwise_num_filters=[128, 256, 512, 1024],
         stackwise_depth=[3, 9, 9, 3],
-        include_rescaling=False,
     )
     model(input_data)
     ```
@@ -65,7 +61,6 @@ def __init__(
         self,
         stackwise_num_filters,
         stackwise_depth,
-        include_rescaling=True,
         block_type="basic_block",
         image_shape=(None, None, 3),
         **kwargs,
@@ -82,10 +77,7 @@ def __init__(
         base_channels = stackwise_num_filters[0] // 2

         image_input = layers.Input(shape=image_shape)
-        x = image_input
-        if include_rescaling:
-            x = layers.Rescaling(scale=1 / 255.0)(x)
-
+        x = image_input  # Intermediate result.
         x = apply_focus(channel_axis, name="stem_focus")(x)
         x = apply_darknet_conv_block(
             base_channels,
@@ -130,7 +122,6 @@ def __init__(
         # === Config ===
         self.stackwise_num_filters = stackwise_num_filters
         self.stackwise_depth = stackwise_depth
-        self.include_rescaling = include_rescaling
         self.block_type = block_type
         self.image_shape = image_shape
         self.pyramid_outputs = pyramid_outputs
@@ -141,7 +132,6 @@ def get_config(self):
             {
                 "stackwise_num_filters": self.stackwise_num_filters,
                 "stackwise_depth": self.stackwise_depth,
-                "include_rescaling": self.include_rescaling,
                 "block_type": self.block_type,
                 "image_shape": self.image_shape,
             }
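
For downstream code, the change is mechanical: drop `include_rescaling` from the backbone call and request the rescale from the preprocessing layer. A hedged before/after sketch (the `keras_hub.layers.ResizingImageConverter` export path is an assumption):

```python
import numpy as np
import keras_hub

# Before: CSPDarkNetBackbone(..., include_rescaling=True) rescaled raw
# [0, 255] pixels internally. Now the backbone expects preprocessed inputs.
backbone = keras_hub.models.CSPDarkNetBackbone(
    stackwise_num_filters=[128, 256, 512, 1024],
    stackwise_depth=[3, 9, 9, 3],
)
# Rescaling (and resizing) lives in the preprocessing layer instead.
converter = keras_hub.layers.ResizingImageConverter(
    height=224, width=224, scale=1 / 255.0
)
images = np.random.randint(0, 256, size=(1, 512, 512, 3)).astype("float32")
features = backbone(converter(images))
```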

keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py

-1

@@ -76,7 +76,6 @@ class CSPDarkNetImageClassifier(ImageClassifier):
     backbone = keras_hub.models.CSPDarkNetBackbone(
         stackwise_num_filters=[128, 256, 512, 1024],
         stackwise_depth=[3, 9, 9, 3],
-        include_rescaling=False,
         block_type="basic_block",
         image_shape = (224, 224, 3),
     )

keras_hub/src/models/csp_darknet/csp_darknet_image_classifier_test.py

-1

@@ -31,7 +31,6 @@ def setUp(self):
         self.backbone = CSPDarkNetBackbone(
             stackwise_num_filters=[2, 16, 16],
             stackwise_depth=[1, 3, 3, 1],
-            include_rescaling=False,
             block_type="basic_block",
             image_shape=(16, 16, 3),
         )

keras_hub/src/models/densenet/densenet_backbone.py

+1 -11

@@ -31,9 +31,6 @@ class DenseNetBackbone(FeaturePyramidBackbone):
     Args:
         stackwise_num_repeats: list of ints, number of repeated convolutional
            blocks per dense block.
-        include_rescaling: bool, whether to rescale the inputs. If set
-            to `True`, inputs will be passed through a `Rescaling(1/255.0)`
-            layer. Defaults to `True`.
         image_shape: optional shape tuple, defaults to (None, None, 3).
         compression_ratio: float, compression rate at transition layers,
             defaults to 0.5.
@@ -51,7 +48,6 @@ class DenseNetBackbone(FeaturePyramidBackbone):
     # Randomly initialized backbone with a custom config
     model = keras_hub.models.DenseNetBackbone(
         stackwise_num_repeats=[6, 12, 24, 16],
-        include_rescaling=False,
     )
     model(input_data)
     ```
@@ -60,7 +56,6 @@ class DenseNetBackbone(FeaturePyramidBackbone):
     def __init__(
         self,
         stackwise_num_repeats,
-        include_rescaling=True,
         image_shape=(None, None, 3),
         compression_ratio=0.5,
         growth_rate=32,
@@ -71,10 +66,7 @@ def __init__(
         channel_axis = -1 if data_format == "channels_last" else 1
         image_input = keras.layers.Input(shape=image_shape)

-        x = image_input
-        if include_rescaling:
-            x = keras.layers.Rescaling(1 / 255.0)(x)
-
+        x = image_input  # Intermediate result.
         x = keras.layers.Conv2D(
             64,
             7,
@@ -124,7 +116,6 @@ def __init__(

         # === Config ===
         self.stackwise_num_repeats = stackwise_num_repeats
-        self.include_rescaling = include_rescaling
         self.compression_ratio = compression_ratio
         self.growth_rate = growth_rate
         self.image_shape = image_shape
@@ -135,7 +126,6 @@ def get_config(self):
         config.update(
             {
                 "stackwise_num_repeats": self.stackwise_num_repeats,
-                "include_rescaling": self.include_rescaling,
                 "compression_ratio": self.compression_ratio,
                 "growth_rate": self.growth_rate,
                 "image_shape": self.image_shape,

keras_hub/src/models/densenet/densenet_image_classifier.py

-1

@@ -74,7 +74,6 @@ class DenseNetImageClassifier(ImageClassifier):
     backbone = keras_hub.models.DenseNetBackbone(
         stackwise_num_filters=[128, 256, 512, 1024],
         stackwise_depth=[3, 9, 9, 3],
-        include_rescaling=False,
         block_type="basic_block",
         image_shape = (224, 224, 3),
     )

keras_hub/src/models/densenet/densenet_image_classifier_test.py

-1

@@ -28,7 +28,6 @@ def setUp(self):
         self.labels = [0, 3]
         self.backbone = DenseNetBackbone(
             stackwise_num_repeats=[6, 12, 24, 16],
-            include_rescaling=True,
             compression_ratio=0.5,
             growth_rate=32,
             image_shape=(224, 224, 3),

keras_hub/src/models/efficientnet/efficientnet_backbone.py

+3 -14

@@ -67,8 +67,6 @@ class EfficientNetBackbone(FeaturePyramidBackbone):
            MBConvBlock, but instead of using a depthwise convolution and a 1x1
            output convolution blocks fused blocks use a single 3x3 convolution
            block.
-        include_rescaling: bool, whether to rescale the inputs. If set to
-            True, inputs will be passed through a `Rescaling(1/255.0)` layer.
        min_depth: integer, minimum number of filters. Can be None and ignored
            if use_depth_divisor_as_min_depth is set to True.
        include_initial_padding: bool, whether to include initial zero padding
@@ -96,7 +94,6 @@ class EfficientNetBackbone(FeaturePyramidBackbone):
         stackwise_block_types=[["fused"] * 3 + ["unfused"] * 3],
         width_coefficient=1.0,
         depth_coefficient=1.0,
-        include_rescaling=False,
     )
     images = np.ones((1, 256, 256, 3))
     outputs = efficientnet.predict(images)
@@ -116,7 +113,6 @@ def __init__(
         stackwise_squeeze_and_excite_ratios,
         stackwise_strides,
         stackwise_block_types,
-        include_rescaling=True,
         dropout=0.2,
         depth_divisor=8,
         min_depth=8,
@@ -129,14 +125,9 @@ def __init__(
         batch_norm_momentum=0.9,
         **kwargs,
     ):
-        img_input = keras.layers.Input(shape=input_shape)
-
-        x = img_input
-
-        if include_rescaling:
-            # Use common rescaling strategy across keras
-            x = keras.layers.Rescaling(scale=1.0 / 255.0)(x)
+        image_input = keras.layers.Input(shape=input_shape)

+        x = image_input  # Intermediate result.
         if include_initial_padding:
             x = keras.layers.ZeroPadding2D(
                 padding=self._correct_pad_downsample(x, 3),
@@ -282,10 +273,9 @@ def __init__(
             curr_pyramid_level += 1

         # Create model.
-        super().__init__(inputs=img_input, outputs=x, **kwargs)
+        super().__init__(inputs=image_input, outputs=x, **kwargs)

         # === Config ===
-        self.include_rescaling = include_rescaling
         self.width_coefficient = width_coefficient
         self.depth_coefficient = depth_coefficient
         self.dropout = dropout
@@ -313,7 +303,6 @@ def get_config(self):
        config = super().get_config()
        config.update(
            {
-                "include_rescaling": self.include_rescaling,
                "width_coefficient": self.width_coefficient,
                "depth_coefficient": self.depth_coefficient,
                "dropout": self.dropout,
