Skip to content

Commit 23815d6

Browse files
authored
Add image and audio converter classes (keras-team#1813)
* Add image and audio converter classes These classes will occupy the same role as tokenizers for text models. They will transform raw inputs to model inputs in a way that is not task specific. * Fix some tests * Input conversion fixes * Torch property fixes * Another fix * Address comments * Add assets on kaggle; bump preset versions * Fix last failing test
1 parent 84a6b66 commit 23815d6

File tree

66 files changed

+930
-644
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

66 files changed

+930
-644
lines changed

keras_nlp/api/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from keras_nlp.api import models
2323
from keras_nlp.api import samplers
2424
from keras_nlp.api import tokenizers
25-
from keras_nlp.api import utils
2625
from keras_nlp.src.utils.preset_utils import upload_preset
2726
from keras_nlp.src.version_utils import __version__
2827
from keras_nlp.src.version_utils import version

keras_nlp/api/layers/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
)
3737
from keras_nlp.src.layers.modeling.transformer_decoder import TransformerDecoder
3838
from keras_nlp.src.layers.modeling.transformer_encoder import TransformerEncoder
39+
from keras_nlp.src.layers.preprocessing.audio_converter import AudioConverter
40+
from keras_nlp.src.layers.preprocessing.image_converter import ImageConverter
3941
from keras_nlp.src.layers.preprocessing.masked_lm_mask_generator import (
4042
MaskedLMMaskGenerator,
4143
)
@@ -44,4 +46,13 @@
4446
)
4547
from keras_nlp.src.layers.preprocessing.random_deletion import RandomDeletion
4648
from keras_nlp.src.layers.preprocessing.random_swap import RandomSwap
49+
from keras_nlp.src.layers.preprocessing.resizing_image_converter import (
50+
ResizingImageConverter,
51+
)
4752
from keras_nlp.src.layers.preprocessing.start_end_packer import StartEndPacker
53+
from keras_nlp.src.models.pali_gemma.pali_gemma_image_converter import (
54+
PaliGemmaImageConverter,
55+
)
56+
from keras_nlp.src.models.whisper.whisper_audio_converter import (
57+
WhisperAudioConverter,
58+
)

keras_nlp/api/models/__init__.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -228,13 +228,7 @@
228228
from keras_nlp.src.models.text_classifier_preprocessor import (
229229
TextClassifierPreprocessor,
230230
)
231-
from keras_nlp.src.models.whisper.whisper_audio_feature_extractor import (
232-
WhisperAudioFeatureExtractor,
233-
)
234231
from keras_nlp.src.models.whisper.whisper_backbone import WhisperBackbone
235-
from keras_nlp.src.models.whisper.whisper_preprocessor import (
236-
WhisperPreprocessor,
237-
)
238232
from keras_nlp.src.models.whisper.whisper_tokenizer import WhisperTokenizer
239233
from keras_nlp.src.models.xlm_roberta.xlm_roberta_backbone import (
240234
XLMRobertaBackbone,
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
# Copyright 2024 The KerasNLP Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from keras_nlp.src.api_export import keras_nlp_export
15+
from keras_nlp.src.layers.preprocessing.preprocessing_layer import (
16+
PreprocessingLayer,
17+
)
18+
from keras_nlp.src.utils.preset_utils import AUDIO_CONVERTER_CONFIG_FILE
19+
from keras_nlp.src.utils.preset_utils import find_subclass
20+
from keras_nlp.src.utils.preset_utils import get_preset_loader
21+
from keras_nlp.src.utils.preset_utils import list_presets
22+
from keras_nlp.src.utils.preset_utils import list_subclasses
23+
from keras_nlp.src.utils.preset_utils import save_serialized_object
24+
from keras_nlp.src.utils.python_utils import classproperty
25+
26+
27+
@keras_nlp_export("keras_nlp.layers.AudioConverter")
28+
class AudioConverter(PreprocessingLayer):
29+
"""Convert raw audio for models that support audio input.
30+
31+
This class converts from raw audio tensors of any length, to preprocessed
32+
audio for pretrained model inputs. It is meant to be a convenient way to
33+
write custom preprocessing code that is not model specific. This layer
34+
should be instantiated via the `from_preset()` constructor, which will
35+
create the correct subclass of this layer for the model preset.
36+
37+
The layer will take as input a raw audio tensor with shape `(batch_size,
38+
num_samples)`, and output a preprocessed audio input for modeling. The exact
39+
structure of the preprocessed input will vary per model. Preprocessing
40+
will often include computing a spectogram of the raw audio signal.
41+
42+
Examples:
43+
```python
44+
# Load an audio converter from a preset.
45+
converter = keras_nlp.layers.AudioConverter.from_preset("whisper_base_en")
46+
# Convert some raw audio input.
47+
converter(np.ones(2, 1_000))
48+
```
49+
"""
50+
51+
backbone_cls = None
52+
53+
@classproperty
54+
def presets(cls):
55+
"""List built-in presets for a `Task` subclass."""
56+
presets = list_presets(cls)
57+
for subclass in list_subclasses(cls):
58+
presets.update(subclass.presets)
59+
return presets
60+
61+
@classmethod
62+
def from_preset(
63+
cls,
64+
preset,
65+
**kwargs,
66+
):
67+
"""Instantiate a `keras_nlp.layers.AudioConverter` from a model preset.
68+
69+
A preset is a directory of configs, weights and other file assets used
70+
to save and load a pre-trained model. The `preset` can be passed as
71+
one of:
72+
73+
1. a built-in preset identifier like `'whisper_base_en'`
74+
2. a Kaggle Models handle like
75+
`'kaggle://user/whisper/keras/whisper_base_en'`
76+
3. a Hugging Face handle like `'hf://user/whisper_base_en'`
77+
4. a path to a local preset directory like `'./whisper_base_en'`
78+
79+
You can run `cls.presets.keys()` to list all built-in presets available
80+
on the class.
81+
82+
This constructor can be called in one of two ways. Either from the base
83+
class like `keras_nlp.models.AudioConverter.from_preset()`, or from a
84+
model class like `keras_nlp.models.WhisperAudioConverter.from_preset()`.
85+
If calling from the base class, the subclass of the returning object
86+
will be inferred from the config in the preset directory.
87+
88+
Args:
89+
preset: string. A built-in preset identifier, a Kaggle Models
90+
handle, a Hugging Face handle, or a path to a local directory.
91+
load_weights: bool. If `True`, the weights will be loaded into the
92+
model architecture. If `False`, the weights will be randomly
93+
initialized.
94+
95+
Examples:
96+
```python
97+
# Load an audio converter from a preset.
98+
converter = keras_nlp.layers.AudioConverter.from_preset(
99+
"whisper_base_en"
100+
)
101+
# Convert some raw mono channel audio input.
102+
converter(np.ones(2, 1_000))
103+
```
104+
"""
105+
loader = get_preset_loader(preset)
106+
backbone_cls = loader.check_backbone_class()
107+
if cls.backbone_cls != backbone_cls:
108+
cls = find_subclass(preset, cls, backbone_cls)
109+
return loader.load_audio_converter(cls, **kwargs)
110+
111+
def save_to_preset(self, preset_dir):
112+
"""Save audio converter to a preset directory.
113+
114+
Args:
115+
preset_dir: The path to the local model preset directory.
116+
"""
117+
save_serialized_object(
118+
self,
119+
preset_dir,
120+
config_file=AUDIO_CONVERTER_CONFIG_FILE,
121+
)
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright 2024 The KerasNLP Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import pathlib
17+
18+
import numpy as np
19+
import pytest
20+
21+
from keras_nlp.src.layers.preprocessing.audio_converter import AudioConverter
22+
from keras_nlp.src.models.backbone import Backbone
23+
from keras_nlp.src.models.whisper.whisper_audio_converter import (
24+
WhisperAudioConverter,
25+
)
26+
from keras_nlp.src.tests.test_case import TestCase
27+
28+
29+
class AudioConverterTest(TestCase):
30+
def test_preset_accessors(self):
31+
pali_gemma_presets = set(WhisperAudioConverter.presets.keys())
32+
all_presets = set(AudioConverter.presets.keys())
33+
self.assertContainsSubset(pali_gemma_presets, all_presets)
34+
35+
@pytest.mark.large
36+
def test_from_preset(self):
37+
self.assertIsInstance(
38+
AudioConverter.from_preset("whisper_tiny_en"),
39+
WhisperAudioConverter,
40+
)
41+
42+
@pytest.mark.large
43+
def test_from_preset_errors(self):
44+
with self.assertRaises(ValueError):
45+
AudioConverter.from_preset("bert_tiny_en_uncased")
46+
with self.assertRaises(ValueError):
47+
# No loading on a non-keras model.
48+
AudioConverter.from_preset("hf://spacy/en_core_web_sm")
49+
50+
@pytest.mark.large
51+
def test_save_to_preset(self):
52+
save_dir = self.get_temp_dir()
53+
converter = AudioConverter.from_preset(
54+
"whisper_tiny_en",
55+
num_mels=40,
56+
)
57+
converter.save_to_preset(save_dir)
58+
# Save a backbone so the preset is valid.
59+
backbone = Backbone.from_preset("whisper_tiny_en", load_weights=False)
60+
backbone.save_to_preset(save_dir)
61+
62+
# Check existence of files.
63+
path = pathlib.Path(save_dir)
64+
self.assertTrue(os.path.exists(path / "audio_converter.json"))
65+
66+
# Check loading.
67+
restored = AudioConverter.from_preset(save_dir)
68+
test_audio = np.random.rand(1_000)
69+
self.assertAllClose(restored(test_audio), converter(test_audio))
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
# Copyright 2024 The KerasNLP Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from keras_nlp.src.api_export import keras_nlp_export
15+
from keras_nlp.src.layers.preprocessing.preprocessing_layer import (
16+
PreprocessingLayer,
17+
)
18+
from keras_nlp.src.utils.preset_utils import IMAGE_CONVERTER_CONFIG_FILE
19+
from keras_nlp.src.utils.preset_utils import find_subclass
20+
from keras_nlp.src.utils.preset_utils import get_preset_loader
21+
from keras_nlp.src.utils.preset_utils import list_presets
22+
from keras_nlp.src.utils.preset_utils import list_subclasses
23+
from keras_nlp.src.utils.preset_utils import save_serialized_object
24+
from keras_nlp.src.utils.python_utils import classproperty
25+
26+
27+
@keras_nlp_export("keras_nlp.layers.ImageConverter")
28+
class ImageConverter(PreprocessingLayer):
29+
"""Convert raw image for models that support image input.
30+
31+
This class converts from raw images of any size, to preprocessed
32+
images for pretrained model inputs. It is meant to be a convenient way to
33+
write custom preprocessing code that is not model specific. This layer
34+
should be instantiated via the `from_preset()` constructor, which will
35+
create the correct subclass of this layer for the model preset.
36+
37+
The layer will take as input a raw image tensor in the channels last or
38+
channels first format, and output a preprocessed image input for modeling.
39+
The exact structure of the output will vary per model, though in most cases
40+
this layer will simply resize the image to the size needed by the model
41+
input.
42+
43+
Examples:
44+
```python
45+
# Resize images for `"pali_gemma_3b_224"`.
46+
converter = keras_nlp.layers.ImageConverter.from_preset("pali_gemma_3b_224")
47+
converter(np.ones(2, 512, 512, 3)) # Output shape: (2, 224, 224, 3)
48+
# Resize images for `"pali_gemma_3b_448"`.
49+
converter = keras_nlp.layers.ImageConverter.from_preset("pali_gemma_3b_448")
50+
converter(np.ones(2, 512, 512, 3)) # Output shape: (2, 448, 448, 3)
51+
```
52+
"""
53+
54+
backbone_cls = None
55+
56+
@classproperty
57+
def presets(cls):
58+
"""List built-in presets for a `Task` subclass."""
59+
presets = list_presets(cls)
60+
for subclass in list_subclasses(cls):
61+
presets.update(subclass.presets)
62+
return presets
63+
64+
@classmethod
65+
def from_preset(
66+
cls,
67+
preset,
68+
**kwargs,
69+
):
70+
"""Instantiate a `keras_nlp.layers.ImageConverter` from a model preset.
71+
72+
A preset is a directory of configs, weights and other file assets used
73+
to save and load a pre-trained model. The `preset` can be passed as
74+
one of:
75+
76+
1. a built-in preset identifier like `'pali_gemma_3b_224'`
77+
2. a Kaggle Models handle like
78+
`'kaggle://user/paligemma/keras/pali_gemma_3b_224'`
79+
3. a Hugging Face handle like `'hf://user/pali_gemma_3b_224'`
80+
4. a path to a local preset directory like `'./pali_gemma_3b_224'`
81+
82+
You can run `cls.presets.keys()` to list all built-in presets available
83+
on the class.
84+
85+
This constructor can be called in one of two ways. Either from the base
86+
class like `keras_nlp.models.ImageConverter.from_preset()`, or from a
87+
model class like
88+
`keras_nlp.models.PaliGemmaImageConverter.from_preset()`. If calling
89+
from the base class, the subclass of the returning object will be
90+
inferred from the config in the preset directory.
91+
92+
Args:
93+
preset: string. A built-in preset identifier, a Kaggle Models
94+
handle, a Hugging Face handle, or a path to a local directory.
95+
load_weights: bool. If `True`, the weights will be loaded into the
96+
model architecture. If `False`, the weights will be randomly
97+
initialized.
98+
99+
Examples:
100+
```python
101+
# Resize images for `"pali_gemma_3b_224"`.
102+
converter = keras_nlp.layers.ImageConverter.from_preset(
103+
"pali_gemma_3b_224"
104+
)
105+
converter(np.ones(2, 512, 512, 3)) # Output shape: (2, 224, 224, 3)
106+
# Override arguments on the base class.
107+
converter = keras_nlp.layers.ImageConverter.from_preset(
108+
"pali_gemma_3b_448",
109+
crop_to_aspect_ratio=False,
110+
)
111+
converter(np.ones(2, 512, 512, 3)) # (2, 448, 448, 3)
112+
```
113+
"""
114+
loader = get_preset_loader(preset)
115+
backbone_cls = loader.check_backbone_class()
116+
if cls.backbone_cls != backbone_cls:
117+
cls = find_subclass(preset, cls, backbone_cls)
118+
return loader.load_image_converter(cls, **kwargs)
119+
120+
def save_to_preset(self, preset_dir):
121+
"""Save image converter to a preset directory.
122+
123+
Args:
124+
preset_dir: The path to the local model preset directory.
125+
"""
126+
save_serialized_object(
127+
self,
128+
preset_dir,
129+
config_file=IMAGE_CONVERTER_CONFIG_FILE,
130+
)

0 commit comments

Comments
 (0)