12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
import keras
15
+ from keras import ops
15
16
16
17
from keras_hub .src .api_export import keras_hub_export
17
18
from keras_hub .src .layers .preprocessing .image_converter import ImageConverter
19
+ from keras_hub .src .utils .keras_utils import standardize_data_format
18
20
from keras_hub .src .utils .tensor_utils import preprocessing_function
19
21
20
22
@@ -23,13 +25,23 @@ class ResizingImageConverter(ImageConverter):
23
25
"""An `ImageConverter` that simply resizes the input image.
24
26
25
27
The `ResizingImageConverter` is a subclass of `ImageConverter` for models
26
- that simply need to resize image tensors before using them for modeling.
27
- The layer will take as input a raw image tensor (batched or unbatched) in the
28
- channels last or channels first format, and output a resize tensor.
28
+ that need to resize (and optionally rescale) image tensors before using them
29
+ for modeling. The layer will take as input a raw image tensor (batched or
30
+ unbatched) in the channels last or channels first format, and output a
31
+ resize tensor.
29
32
30
33
Args:
31
- height: Integer, the height of the output shape.
32
- width: Integer, the width of the output shape.
34
+ height: int, the height of the output shape.
35
+ width: int, the width of the output shape.
36
+ scale: float or `None`. If set, the image we be rescaled with a
37
+ `keras.layers.Rescaling` layer, multiplying the image by this
38
+ scale.
39
+ mean: tuples of floats per channel or `None`. If set, the image will be
40
+ normalized per channel by subtracting mean.
41
+ If set, also set `variance`.
42
+ variance: tuples of floats per channel or `None`. If set, the image will
43
+ be normalized per channel by dividing by `sqrt(variance)`.
44
+ If set, also set `mean`.
33
45
crop_to_aspect_ratio: If `True`, resize the images without aspect
34
46
ratio distortion. When the original aspect ratio differs
35
47
from the target aspect ratio, the output image will be
@@ -64,6 +76,9 @@ def __init__(
64
76
self ,
65
77
height ,
66
78
width ,
79
+ scale = None ,
80
+ mean = None ,
81
+ variance = None ,
67
82
crop_to_aspect_ratio = True ,
68
83
interpolation = "bilinear" ,
69
84
data_format = None ,
@@ -78,15 +93,47 @@ def __init__(
78
93
crop_to_aspect_ratio = crop_to_aspect_ratio ,
79
94
interpolation = interpolation ,
80
95
data_format = data_format ,
96
+ dtype = self .dtype_policy ,
97
+ name = "resizing" ,
81
98
)
99
+ if scale is not None :
100
+ self .rescaling = keras .layers .Rescaling (
101
+ scale = scale ,
102
+ dtype = self .dtype_policy ,
103
+ name = "rescaling" ,
104
+ )
105
+ else :
106
+ self .rescaling = None
107
+ if (mean is not None ) != (variance is not None ):
108
+ raise ValueError (
109
+ "Both `mean` and `variance` should be set or `None`. Received "
110
+ f"`mean={ mean } `, `variance={ variance } `."
111
+ )
112
+ self .scale = scale
113
+ self .mean = mean
114
+ self .variance = variance
115
+ self .data_format = standardize_data_format (data_format )
82
116
83
117
def image_size (self ):
84
118
"""Returns the preprocessed size of a single image."""
85
119
return (self .resizing .height , self .resizing .width )
86
120
87
121
@preprocessing_function
88
122
def call (self , inputs ):
89
- return self .resizing (inputs )
123
+ x = self .resizing (inputs )
124
+ if self .rescaling :
125
+ x = self .rescaling (x )
126
+ if self .mean is not None :
127
+ # Avoid `layers.Normalization` so this works batched and unbatched.
128
+ channels_first = self .data_format == "channels_first"
129
+ if len (ops .shape (inputs )) == 3 :
130
+ broadcast_dims = (1 , 2 ) if channels_first else (0 , 1 )
131
+ else :
132
+ broadcast_dims = (0 , 2 , 3 ) if channels_first else (0 , 1 , 2 )
133
+ mean = ops .expand_dims (ops .array (self .mean ), broadcast_dims )
134
+ std = ops .expand_dims (ops .sqrt (self .variance ), broadcast_dims )
135
+ x = (x - mean ) / std
136
+ return x
90
137
91
138
def get_config (self ):
92
139
config = super ().get_config ()
@@ -96,6 +143,9 @@ def get_config(self):
96
143
"width" : self .resizing .width ,
97
144
"interpolation" : self .resizing .interpolation ,
98
145
"crop_to_aspect_ratio" : self .resizing .crop_to_aspect_ratio ,
146
+ "scale" : self .scale ,
147
+ "mean" : self .mean ,
148
+ "variance" : self .variance ,
99
149
}
100
150
)
101
151
return config
0 commit comments