Skip to content

Commit daf8a1f

Browse files
sachinprasadhsmattdangerw
authored andcommitted
Add DenseNet (#1775)
* Add DenseNet * fix testcase * address comments * nit * fix lint errors * move description
1 parent 2861676 commit daf8a1f

File tree

6 files changed

+469
-0
lines changed

6 files changed

+469
-0
lines changed

keras_nlp/api/models/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,10 @@
9696
from keras_nlp.src.models.deberta_v3.deberta_v3_tokenizer import (
9797
DebertaV3Tokenizer,
9898
)
99+
from keras_nlp.src.models.densenet.densenet_backbone import DenseNetBackbone
100+
from keras_nlp.src.models.densenet.densenet_image_classifier import (
101+
DenseNetImageClassifier,
102+
)
99103
from keras_nlp.src.models.distil_bert.distil_bert_backbone import (
100104
DistilBertBackbone,
101105
)
+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2024 The KerasNLP Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
# Copyright 2024 The KerasNLP Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import keras
15+
16+
from keras_nlp.src.api_export import keras_nlp_export
17+
from keras_nlp.src.models.backbone import Backbone
18+
19+
BN_AXIS = 3
20+
BN_EPSILON = 1.001e-5
21+
22+
23+
@keras_nlp_export("keras_nlp.models.DenseNetBackbone")
24+
class DenseNetBackbone(Backbone):
25+
"""Instantiates the DenseNet architecture.
26+
27+
This class implements a DenseNet backbone as described in
28+
[Densely Connected Convolutional Networks (CVPR 2017)](
29+
https://arxiv.org/abs/1608.06993
30+
).
31+
32+
Args:
33+
stackwise_num_repeats: list of ints, number of repeated convolutional
34+
blocks per dense block.
35+
include_rescaling: bool, whether to rescale the inputs. If set
36+
to `True`, inputs will be passed through a `Rescaling(1/255.0)`
37+
layer. Defaults to `True`.
38+
input_image_shape: optional shape tuple, defaults to (224, 224, 3).
39+
compression_ratio: float, compression rate at transition layers,
40+
defaults to 0.5.
41+
growth_rate: int, number of filters added by each dense block,
42+
defaults to 32
43+
44+
Examples:
45+
```python
46+
input_data = np.ones(shape=(8, 224, 224, 3))
47+
48+
# Pretrained backbone
49+
model = keras_nlp.models.DenseNetBackbone.from_preset("densenet121_imagenet")
50+
model(input_data)
51+
52+
# Randomly initialized backbone with a custom config
53+
model = keras_nlp.models.DenseNetBackbone(
54+
stackwise_num_repeats=[6, 12, 24, 16],
55+
include_rescaling=False,
56+
)
57+
model(input_data)
58+
```
59+
"""
60+
61+
def __init__(
62+
self,
63+
stackwise_num_repeats,
64+
include_rescaling=True,
65+
input_image_shape=(224, 224, 3),
66+
compression_ratio=0.5,
67+
growth_rate=32,
68+
**kwargs,
69+
):
70+
# === Functional Model ===
71+
image_input = keras.layers.Input(shape=input_image_shape)
72+
73+
x = image_input
74+
if include_rescaling:
75+
x = keras.layers.Rescaling(1 / 255.0)(x)
76+
77+
x = keras.layers.Conv2D(
78+
64, 7, strides=2, use_bias=False, padding="same", name="conv1_conv"
79+
)(x)
80+
x = keras.layers.BatchNormalization(
81+
axis=BN_AXIS, epsilon=BN_EPSILON, name="conv1_bn"
82+
)(x)
83+
x = keras.layers.Activation("relu", name="conv1_relu")(x)
84+
x = keras.layers.MaxPooling2D(
85+
3, strides=2, padding="same", name="pool1"
86+
)(x)
87+
88+
for stack_index in range(len(stackwise_num_repeats) - 1):
89+
index = stack_index + 2
90+
x = apply_dense_block(
91+
x,
92+
stackwise_num_repeats[stack_index],
93+
growth_rate,
94+
name=f"conv{index}",
95+
)
96+
x = apply_transition_block(
97+
x, compression_ratio, name=f"pool{index}"
98+
)
99+
100+
x = apply_dense_block(
101+
x,
102+
stackwise_num_repeats[-1],
103+
growth_rate,
104+
name=f"conv{len(stackwise_num_repeats) + 1}",
105+
)
106+
107+
x = keras.layers.BatchNormalization(
108+
axis=BN_AXIS, epsilon=BN_EPSILON, name="bn"
109+
)(x)
110+
x = keras.layers.Activation("relu", name="relu")(x)
111+
112+
super().__init__(inputs=image_input, outputs=x, **kwargs)
113+
114+
# === Config ===
115+
self.stackwise_num_repeats = stackwise_num_repeats
116+
self.include_rescaling = include_rescaling
117+
self.compression_ratio = compression_ratio
118+
self.growth_rate = growth_rate
119+
self.input_image_shape = input_image_shape
120+
121+
def get_config(self):
122+
config = super().get_config()
123+
config.update(
124+
{
125+
"stackwise_num_repeats": self.stackwise_num_repeats,
126+
"include_rescaling": self.include_rescaling,
127+
"compression_ratio": self.compression_ratio,
128+
"growth_rate": self.growth_rate,
129+
"input_image_shape": self.input_image_shape,
130+
}
131+
)
132+
return config
133+
134+
135+
def apply_dense_block(x, num_repeats, growth_rate, name=None):
136+
"""A dense block.
137+
138+
Args:
139+
x: input tensor.
140+
num_repeats: int, number of repeated convolutional blocks.
141+
growth_rate: int, number of filters added by each dense block.
142+
name: string, block label.
143+
"""
144+
if name is None:
145+
name = f"dense_block_{keras.backend.get_uid('dense_block')}"
146+
147+
for i in range(num_repeats):
148+
x = apply_conv_block(x, growth_rate, name=f"{name}_block_{i}")
149+
return x
150+
151+
152+
def apply_transition_block(x, compression_ratio, name=None):
153+
"""A transition block.
154+
155+
Args:
156+
x: input tensor.
157+
compression_ratio: float, compression rate at transition layers.
158+
name: string, block label.
159+
"""
160+
if name is None:
161+
name = f"transition_block_{keras.backend.get_uid('transition_block')}"
162+
163+
x = keras.layers.BatchNormalization(
164+
axis=BN_AXIS, epsilon=BN_EPSILON, name=f"{name}_bn"
165+
)(x)
166+
x = keras.layers.Activation("relu", name=f"{name}_relu")(x)
167+
x = keras.layers.Conv2D(
168+
int(x.shape[BN_AXIS] * compression_ratio),
169+
1,
170+
use_bias=False,
171+
name=f"{name}_conv",
172+
)(x)
173+
x = keras.layers.AveragePooling2D(2, strides=2, name=f"{name}_pool")(x)
174+
return x
175+
176+
177+
def apply_conv_block(x, growth_rate, name=None):
178+
"""A building block for a dense block.
179+
180+
Args:
181+
x: input tensor.
182+
growth_rate: int, number of filters added by each dense block.
183+
name: string, block label.
184+
"""
185+
if name is None:
186+
name = f"conv_block_{keras.backend.get_uid('conv_block')}"
187+
188+
shortcut = x
189+
x = keras.layers.BatchNormalization(
190+
axis=BN_AXIS, epsilon=BN_EPSILON, name=f"{name}_0_bn"
191+
)(x)
192+
x = keras.layers.Activation("relu", name=f"{name}_0_relu")(x)
193+
x = keras.layers.Conv2D(
194+
4 * growth_rate, 1, use_bias=False, name=f"{name}_1_conv"
195+
)(x)
196+
x = keras.layers.BatchNormalization(
197+
axis=BN_AXIS, epsilon=BN_EPSILON, name=f"{name}_1_bn"
198+
)(x)
199+
x = keras.layers.Activation("relu", name=f"{name}_1_relu")(x)
200+
x = keras.layers.Conv2D(
201+
growth_rate,
202+
3,
203+
padding="same",
204+
use_bias=False,
205+
name=f"{name}_2_conv",
206+
)(x)
207+
x = keras.layers.Concatenate(axis=BN_AXIS, name=f"{name}_concat")(
208+
[shortcut, x]
209+
)
210+
return x
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright 2024 The KerasNLP Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import numpy as np
16+
import pytest
17+
18+
from keras_nlp.src.models.densenet.densenet_backbone import DenseNetBackbone
19+
from keras_nlp.src.tests.test_case import TestCase
20+
21+
22+
class DenseNetBackboneTest(TestCase):
23+
def setUp(self):
24+
self.init_kwargs = {
25+
"stackwise_num_repeats": [6, 12, 24, 16],
26+
"include_rescaling": True,
27+
"compression_ratio": 0.5,
28+
"growth_rate": 32,
29+
"input_image_shape": (224, 224, 3),
30+
}
31+
self.input_data = np.ones((2, 224, 224, 3), dtype="float32")
32+
33+
def test_backbone_basics(self):
34+
self.run_backbone_test(
35+
cls=DenseNetBackbone,
36+
init_kwargs=self.init_kwargs,
37+
input_data=self.input_data,
38+
expected_output_shape=(2, 7, 7, 1024),
39+
run_mixed_precision_check=False,
40+
)
41+
42+
@pytest.mark.large
43+
def test_saved_model(self):
44+
self.run_model_saving_test(
45+
cls=DenseNetBackbone,
46+
init_kwargs=self.init_kwargs,
47+
input_data=self.input_data,
48+
)

0 commit comments

Comments
 (0)