Skip to content

Commit a9c31cb

Browse files
committed
Simplify the implementation
1 parent 4a705a1 commit a9c31cb

File tree

5 files changed

+38
-64
lines changed

5 files changed

+38
-64
lines changed

keras_nlp/api/models/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,6 @@
106106
)
107107
from keras_nlp.src.models.falcon.falcon_preprocessor import FalconPreprocessor
108108
from keras_nlp.src.models.falcon.falcon_tokenizer import FalconTokenizer
109-
from keras_nlp.src.models.feature_pyramid_backbone import FeaturePyramidBackbone
110109
from keras_nlp.src.models.gemma.gemma_backbone import GemmaBackbone
111110
from keras_nlp.src.models.gemma.gemma_causal_lm import GemmaCausalLM
112111
from keras_nlp.src.models.gemma.gemma_causal_lm_preprocessor import (

keras_nlp/src/models/backbone.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,39 @@ def token_embedding(self):
108108
def token_embedding(self, value):
109109
self._token_embedding = value
110110

111+
@property
112+
def pyramid_outputs(self):
113+
"""A dict for feature pyramid outputs.
114+
115+
The key is a string represents the name of the feature output and the
116+
value is a `keras.KerasTensor`. A typical feature pyramid has multiple
117+
levels corresponding to scales such as `["P2", "P3", "P4", "P5"]`. Scale
118+
`Pn` represents a feature map `2^n` times smaller in width and height
119+
than the inputs.
120+
"""
121+
return getattr(self, "_pyramid_outputs", {})
122+
123+
@pyramid_outputs.setter
124+
def pyramid_outputs(self, value):
125+
if not isinstance(value, dict):
126+
raise TypeError(
127+
"`pyramid_outputs` must be a dictionary. "
128+
f"Received: value={value} of type {type(value)}"
129+
)
130+
for k, v in value.items():
131+
if not isinstance(k, str):
132+
raise TypeError(
133+
"The key of `pyramid_outputs` must be a string. "
134+
f"Received: key={k} of type {type(k)}"
135+
)
136+
if not isinstance(v, keras.KerasTensor):
137+
raise TypeError(
138+
"The value of `pyramid_outputs` must be a "
139+
"`keras.KerasTensor`. "
140+
f"Received: value={v} of type {type(v)}"
141+
)
142+
self._pyramid_outputs = value
143+
111144
def quantize(self, mode, **kwargs):
112145
assert_quantization_support()
113146
return super().quantize(mode, **kwargs)

keras_nlp/src/models/feature_pyramid_backbone.py

Lines changed: 0 additions & 57 deletions
This file was deleted.

keras_nlp/src/models/resnet/resnet_backbone.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,12 @@
1515
from keras import layers
1616

1717
from keras_nlp.src.api_export import keras_nlp_export
18-
from keras_nlp.src.models.feature_pyramid_backbone import FeaturePyramidBackbone
19-
from keras_nlp.src.utils.keras_utils import get_tensor_name
18+
from keras_nlp.src.models.backbone import Backbone
2019
from keras_nlp.src.utils.keras_utils import standardize_data_format
2120

2221

2322
@keras_nlp_export("keras_nlp.models.ResNetBackbone")
24-
class ResNetBackbone(FeaturePyramidBackbone):
23+
class ResNetBackbone(Backbone):
2524
"""ResNet and ResNetV2 core network with hyperparameters.
2625
2726
This class implements a ResNet backbone as described in [Deep Residual
@@ -181,7 +180,7 @@ def __init__(
181180
dtype=dtype,
182181
name=f"{version}_stack{stack_index}",
183182
)
184-
pyramid_outputs[f"P{stack_index + 2}"] = get_tensor_name(x)
183+
pyramid_outputs[f"P{stack_index + 2}"] = x
185184

186185
if use_pre_activation:
187186
x = layers.BatchNormalization(

keras_nlp/src/models/resnet/resnet_feature_pyramid_backbone.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,13 +128,13 @@ def __init__(
128128
outputs = {}
129129
for k in output_keys:
130130
try:
131-
name = self.pyramid_outputs[k]
131+
output = self.pyramid_outputs[k]
132132
except KeyError:
133133
raise KeyError(
134134
f"'{k}' not in self.pyramid_outputs. The available keys "
135135
f"are: {list(self.pyramid_outputs.keys())}"
136136
)
137-
outputs[k] = self.get_layer(name).output
137+
outputs[k] = output
138138

139139
super(ResNetBackbone, self).__init__(
140140
inputs=self.inputs, outputs=outputs, dtype=dtype, **kwargs

0 commit comments

Comments
 (0)