Skip to content

Commit 2861676

Browse files
james77777778mattdangerw
authored andcommitted
Add FeaturePyramidBackbone and port weights from timm for ResNetBackbone (#1769)
* Add FeaturePyramidBackbone and update ResNetBackbone * Simplify the implementation * Fix CI * Make ResNetBackbone compatible with timm and add FeaturePyramidBackbone * Add conversion implementation * Update docstrings * Address comments
1 parent e8bef25 commit 2861676

12 files changed

+529
-103
lines changed

keras_nlp/api/models/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@
144144
FalconCausalLMPreprocessor,
145145
)
146146
from keras_nlp.src.models.falcon.falcon_tokenizer import FalconTokenizer
147+
from keras_nlp.src.models.feature_pyramid_backbone import FeaturePyramidBackbone
147148
from keras_nlp.src.models.gemma.gemma_backbone import GemmaBackbone
148149
from keras_nlp.src.models.gemma.gemma_causal_lm import GemmaCausalLM
149150
from keras_nlp.src.models.gemma.gemma_causal_lm_preprocessor import (
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright 2024 The KerasNLP Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import keras
15+
16+
from keras_nlp.src.api_export import keras_nlp_export
17+
from keras_nlp.src.models.backbone import Backbone
18+
19+
20+
@keras_nlp_export("keras_nlp.models.FeaturePyramidBackbone")
21+
class FeaturePyramidBackbone(Backbone):
22+
"""A backbone with feature pyramid outputs.
23+
24+
`FeaturePyramidBackbone` extends `Backbone` with a single `pyramid_outputs`
25+
property for accessing the feature pyramid outputs of the model. Subclassers
26+
should set the `pyramid_outputs` property during the model constructor.
27+
28+
Example:
29+
30+
```python
31+
input_data = np.random.uniform(0, 255, size=(2, 224, 224, 3))
32+
33+
# Convert to feature pyramid output format using ResNet.
34+
backbone = ResNetBackbone.from_preset("resnet50")
35+
model = keras.Model(
36+
inputs=backbone.inputs, outputs=backbone.pyramid_outputs
37+
)
38+
model(input_data) # A dict containing the keys ["P2", "P3", "P4", "P5"]
39+
```
40+
"""
41+
42+
@property
43+
def pyramid_outputs(self):
44+
"""A dict for feature pyramid outputs.
45+
46+
The key is a string represents the name of the feature output and the
47+
value is a `keras.KerasTensor`. A typical feature pyramid has multiple
48+
levels corresponding to scales such as `["P2", "P3", "P4", "P5"]`. Scale
49+
`Pn` represents a feature map `2^n` times smaller in width and height
50+
than the inputs.
51+
"""
52+
return getattr(self, "_pyramid_outputs", {})
53+
54+
@pyramid_outputs.setter
55+
def pyramid_outputs(self, value):
56+
if not isinstance(value, dict):
57+
raise TypeError(
58+
"`pyramid_outputs` must be a dictionary. "
59+
f"Received: value={value} of type {type(value)}"
60+
)
61+
for k, v in value.items():
62+
if not isinstance(k, str):
63+
raise TypeError(
64+
"The key of `pyramid_outputs` must be a string. "
65+
f"Received: key={k} of type {type(k)}"
66+
)
67+
if not isinstance(v, keras.KerasTensor):
68+
raise TypeError(
69+
"The value of `pyramid_outputs` must be a "
70+
"`keras.KerasTensor`. "
71+
f"Received: value={v} of type {type(v)}"
72+
)
73+
self._pyramid_outputs = value

0 commit comments

Comments
 (0)