diff --git a/.kokoro/github/ubuntu/gpu/build.sh b/.kokoro/github/ubuntu/gpu/build.sh index f587a1f304..8ddb1c1d33 100644 --- a/.kokoro/github/ubuntu/gpu/build.sh +++ b/.kokoro/github/ubuntu/gpu/build.sh @@ -69,6 +69,7 @@ then keras_cv/models/object_detection/yolo_v8 \ keras_cv/models/object_detection_3d \ keras_cv/models/segmentation \ + keras_cv/models/feature_extractor/clip \ keras_cv/models/stable_diffusion else pytest --cache-clear --check_gpu --run_large --durations 0 \ @@ -83,5 +84,6 @@ else keras_cv/models/object_detection/yolo_v8 \ keras_cv/models/object_detection_3d \ keras_cv/models/segmentation \ + keras_cv/models/feature_extractor/clip \ keras_cv/models/stable_diffusion fi \ No newline at end of file diff --git a/keras_cv/models/feature_extractor/clip/clip_encoder.py b/keras_cv/models/feature_extractor/clip/clip_encoder.py index d1935a0577..1b2f5c9e3e 100644 --- a/keras_cv/models/feature_extractor/clip/clip_encoder.py +++ b/keras_cv/models/feature_extractor/clip/clip_encoder.py @@ -11,27 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np - from keras_cv.api_export import keras_cv_export from keras_cv.backend import keras from keras_cv.backend import ops -def get_initializer(initializer_range=0.02): - """ - Creates a `keras.initializers.TruncatedNormal` with the given range. - - Args: - initializer_range (*float*, defaults to 0.02): Standard deviation of the - initializer range. - - Returns: - `keras.initializers.TruncatedNormal`: The truncated normal initializer. 
- """ - return keras.initializers.TruncatedNormal(stddev=initializer_range) - - @keras_cv_export("keras_cv.models.feature_extractor.QuickGELU") class QuickGELU(keras.layers.Layer): def __init__(self, **kwargs): @@ -54,13 +38,6 @@ def __init__( self.proj_dim = proj_dim self.num_heads = num_heads self.num_hidden_layers = num_hidden_layers - self.fc_std = np.power(2 * self.proj_dim, -0.5) * 0.02 - - self.in_proj_std = ( - np.power(self.proj_dim, -0.5) - * (np.power(2 * self.num_hidden_layers, -0.5)) - * 0.02 - ) self.attn = CLIPAttention( self.proj_dim, self.num_heads, @@ -156,9 +133,14 @@ def __init__(self, width, num_layers, heads, **kwargs): ] def build(self, input_shape): - super().build(input_shape) for block in self.resblocks: block.build(input_shape) + self.built = True + + def compute_output_shape(self, input_shape): + for block in self.resblocks: + input_shape = block.compute_output_shape(input_shape) + return input_shape def call( self, @@ -174,9 +156,6 @@ def call( ) return x - def compute_output_shape(self, inputs_shape): - return inputs_shape - def get_config(self): config = super().get_config() config.update( @@ -213,30 +192,20 @@ def __init__( ) self.scale = self.head_dim**-0.5 - in_proj_std = ( - (self.proj_dim**-0.5) - * ((2 * self.num_hidden_layers) ** -0.5) - * 0.02 - ) - out_proj_std = (self.proj_dim**-0.5) * 0.02 self.q_proj = keras.layers.Dense( units=self.proj_dim, - kernel_initializer=get_initializer(in_proj_std), name="q_proj", ) self.k_proj = keras.layers.Dense( units=self.proj_dim, - kernel_initializer=get_initializer(in_proj_std), name="k_proj", ) self.v_proj = keras.layers.Dense( units=self.proj_dim, - kernel_initializer=get_initializer(in_proj_std), name="v_proj", ) self.out_proj = keras.layers.Dense( units=self.proj_dim, - kernel_initializer=get_initializer(out_proj_std), name="out_proj", ) diff --git a/keras_cv/models/feature_extractor/clip/clip_image_model.py b/keras_cv/models/feature_extractor/clip/clip_image_model.py index 
663b01cd62..fb4cf71ef4 100644 --- a/keras_cv/models/feature_extractor/clip/clip_image_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_image_model.py @@ -16,10 +16,8 @@ from keras_cv.backend import keras from keras_cv.backend import ops from keras_cv.models.feature_extractor.clip.clip_encoder import CLIPEncoder -from keras_cv.models.feature_extractor.clip.clip_encoder import get_initializer -@keras_cv_export("keras_cv.models.feature_extractor.CLIPPatchingAndEmbedding") class CLIPPatchingAndEmbedding(keras.layers.Layer): def __init__( self, width, patch_size, input_resolution, output_dim, **kwargs @@ -33,7 +31,6 @@ def __init__( padding="valid", use_bias=False, data_format="channels_last", - kernel_initializer=get_initializer(0.02), name="patch_embed.embedding", ) self.width = width @@ -42,9 +39,6 @@ def __init__( self.num_patches = ops.power( (self.input_resolution // self.patch_size), 2 ) - self.class_embedding_initializer = get_initializer( - ops.power(self.width, -0.5) * 0.02 - ) self.output_dim = output_dim def build(self, input_shape): @@ -52,7 +46,6 @@ def build(self, input_shape): self.conv1.build(input_shape) self.class_embedding = self.add_weight( shape=((self.width,)), - initializer=self.class_embedding_initializer, name="patch_embed.class_embedding", ) @@ -67,6 +60,13 @@ def build(self, input_shape): name="patch_embed.positional_embedding", ) + def compute_output_shape(self, input_shape): + return [ + None, + (self.input_resolution // self.patch_size) ** 2 + 1, + self.width, + ] + def call(self, x): batch_size = ops.shape(x)[0] patch_embeddings = self.conv1(x) # shape = [*, grid, grid, channel] @@ -143,12 +143,15 @@ def __init__( ) def build(self, input_shape): - super().build(input_shape) self.embeddings.build(input_shape) self.pre_norm.build([None, None, self.width]) self.encoder.build(None) self.post_norm.build([None, self.width]) - self.image_projector.build([None, None, self.width]) + self.image_projector.build([None, self.width]) + 
self.built = True + + def compute_output_shape(self, input_shape): + return [input_shape[0], self.output_dim] def call(self, image): x = self.embeddings(image) diff --git a/keras_cv/models/feature_extractor/clip/clip_model.py b/keras_cv/models/feature_extractor/clip/clip_model.py index 860739388e..bc5e29097f 100644 --- a/keras_cv/models/feature_extractor/clip/clip_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_model.py @@ -34,6 +34,41 @@ keras_nlp = None +class CLIPHead(keras.layers.Layer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def build(self, input_shape): + self.logit_scale = self.add_variable( + shape=(), + initializer=lambda *a, **kw: ops.log(1 / 0.07), + trainable=True, + dtype=self.variable_dtype, + name="logit_scale", + ) + self.built = True + + def call(self, image_embeddings, text_embeddings): + normalize_image_features = ops.sqrt( + ops.sum(ops.power(image_embeddings, 2), keepdims=True) + ) + normalize_text_features = ops.sqrt( + ops.sum(ops.power(text_embeddings, 2), keepdims=True) + ) + image_embeddings = image_embeddings / normalize_image_features + text_embeddings = text_embeddings / normalize_text_features + logit_scale = ops.exp(self.logit_scale) + image_logits = ( + ops.matmul( + image_embeddings, + ops.transpose(text_embeddings), + ) + * logit_scale + ) + text_logits = ops.transpose(image_logits) + return image_logits, text_logits + + @keras_cv_export(["keras_cv.models.CLIP"]) class CLIP(Task): """ @@ -61,25 +96,27 @@ class CLIP(Task): transformer-based text encoder. transformer_layers (int): The number of layers in the transformer-based text encoder. 
+ Example: + ```python processor = CLIPProcessor( - input_resolution=224, - "path_to_vocab.json", - "path_to_merges.txt" - ) + input_resolution=224, + "path_to_vocab.json", + "path_to_merges.txt" + ) processed_image = processor.process_images(["cat.jpg"]) - processed_text, attention_mask = processor.process_texts( - ["mountains", "cat on tortoise", "two cats"] - ) + tokens = processor( + ["mountains", "cat on tortoise", "two cats"] + ) model = CLIP.from_preset("clip-vit-base-patch16") image_logits, text_logits = model( - { - "image": processed_image, - "text": processed_text, - "attention_mask": attention_mask, - } - ) + { + "images": processed_image, + "token_ids": tokens["token_ids"], + "padding_mask": tokens["padding_mask"], + } + ) ``` """ @@ -97,12 +134,70 @@ def __init__( transformer_layers=12, **kwargs, ): - super().__init__(**kwargs) if keras_nlp is None: raise ValueError( "ClipTokenizer requires keras-nlp. Please install " "using pip `pip install -U keras-nlp && pip install -U keras`" ) + + vision_heads = vision_width // 64 + + images = keras.Input( + shape=[image_resolution, image_resolution, 3], name="images" + ) + token_ids = keras.Input( + shape=[ + context_length, + ], + name="token_ids", + ) + padding_mask = keras.Input( + shape=[ + context_length, + ], + name="padding_mask", + ) + + image_encoder = CLIPImageEncoder( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + num_layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim, + name="image_encoder", + ) + text_encoder = CLIPTextEncoder( + transformer_width=transformer_width, + transformer_layers=transformer_layers, + transformer_heads=transformer_heads, + vocab_size=vocab_size, + embed_dim=embed_dim, + context_length=context_length, + name="text_encoder", + ) + clip_head = CLIPHead(name="clip_head") + + image_embeddings = image_encoder(images) + text_embeddings = text_encoder(token_ids, attention_mask=padding_mask) + image_logits, text_logits = 
clip_head(image_embeddings, text_embeddings) + + inputs = { + "images": images, + "token_ids": token_ids, + "padding_mask": padding_mask, + } + outputs = { + "image_logits": image_logits, + "text_logits": text_logits, + } + + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + self.embed_dim = embed_dim self.image_resolution = image_resolution self.vision_layers = vision_layers @@ -113,75 +208,9 @@ def __init__( self.transformer_width = transformer_width self.transformer_heads = transformer_heads self.transformer_layers = transformer_layers - - vision_heads = self.vision_width // 64 - self.image_encoder = CLIPImageEncoder( - input_resolution=self.image_resolution, - patch_size=self.vision_patch_size, - width=self.vision_width, - num_layers=self.vision_layers, - heads=vision_heads, - output_dim=self.embed_dim, - name="image_encoder", - ) - self.text_encoder = CLIPTextEncoder( - transformer_width=self.transformer_width, - transformer_layers=self.transformer_layers, - transformer_heads=self.transformer_heads, - vocab_size=self.vocab_size, - embed_dim=self.embed_dim, - context_length=self.context_length, - name="text_encoder", - ) - - self.logit_scale = keras.Variable( - ops.ones([]) * ops.log(1 / 0.07), name="logit_scale" - ) - self.image_embeddings = None - self.text_embeddings = None - - def build(self, input_shape): - super().build(input_shape) - self.text_encoder.build([None, self.context_length]) - self.image_encoder.build( - [None, self.image_resolution, self.image_resolution, 3] - ) - - def encode_images(self, image): - return self.image_encoder(image) - - def encode_text(self, text, attention_mask=None): - return self.text_encoder(text, attention_mask=attention_mask) - - def call(self, inputs): - image, text = inputs["image"], inputs["text"] - if "attention_mask" in inputs: - attention_mask = inputs["attention_mask"] - else: - attention_mask = None - self.image_embeddings = self.encode_images(image) - self.text_embeddings = 
self.encode_text( - text, attention_mask=attention_mask - ) - normalize_image_features = ops.sqrt( - ops.sum(ops.power(self.image_embeddings, 2), keepdims=True) - ) - normalize_text_features = ops.sqrt( - ops.sum(ops.power(self.text_embeddings, 2), keepdims=True) - ) - self.image_embeddings = self.image_embeddings / normalize_image_features - self.text_embeddings = self.text_embeddings / normalize_text_features - logit_scale = ops.exp(self.logit_scale) - logits_per_image = ( - ops.matmul( - self.image_embeddings, - ops.transpose(self.text_embeddings), - ) - * logit_scale - ) - logits_per_text = ops.transpose(logits_per_image) - - return logits_per_image, logits_per_text + self.image_encoder = image_encoder + self.text_encoder = text_encoder + self.clip_head = clip_head @classproperty def presets(cls): diff --git a/keras_cv/models/feature_extractor/clip/clip_model_test.py b/keras_cv/models/feature_extractor/clip/clip_model_test.py index a5541303c9..f9126aa9f2 100644 --- a/keras_cv/models/feature_extractor/clip/clip_model_test.py +++ b/keras_cv/models/feature_extractor/clip/clip_model_test.py @@ -35,105 +35,89 @@ ) +@pytest.mark.skipif( + not keras_3(), + reason="Only works with Keras 3", +) class CLIPTest(TestCase): + @pytest.mark.large def test_clip_model_golden_values(self): model = CLIP.from_preset("clip-vit-base-patch32") processed_image = np.ones(shape=[1, 224, 224, 3]) processed_text = np.ones(shape=[3, 77]) attention_mask = np.ones(shape=[3, 77]) - image_logits, text_logits = model( + outputs = model( { - "image": processed_image, - "text": processed_text, - "attention_mask": attention_mask, + "images": processed_image, + "token_ids": processed_text, + "padding_mask": attention_mask, } ) - self.assertAllClose(image_logits, [[1.896712, 1.896712, 1.896712]]) + + # These values are NOT computed using HF as the reference model. + # Currently, the numerics of the CLIP model don't match the + # HF model exactly (for the same inputs). 
For the time being, + # these tests just confirm that unrelated changes don't affect + # the numerics. Once the fix for the numerics is in, we can remove + # this comment and the xfail below. + self.assertAllClose( + outputs["image_logits"], [[10.246354, 10.246353, 10.246354]] + ) self.assertAllClose( - text_logits, ops.transpose([[1.896712, 1.896712, 1.896712]]) + outputs["text_logits"], + ops.transpose([[10.246354, 10.246353, 10.246354]]), ) + # True reference values computed using HF: + # image_logits: [[17.8013, 17.8013, 17.8013]] + # text_logits: image_logits.T + + # xfail after assertion + pytest.xfail("KerasCV CLIP doesn't match the HF model.") + def test_clip_preprocessor(self): - processor = CLIPProcessor(224, VOCAB_PATH, MERGE_PATH) - processed_text, attention_mask = processor.process_texts( - ["mountains", "cat on tortoise"] - ) + processor = CLIPProcessor(VOCAB_PATH, MERGE_PATH) + tokens = processor(["mountains", "cat on tortoise"]) self.assertAllClose( - processed_text[:, :3], [[49406, 5873, 49407], [49406, 2368, 525]] + tokens["token_ids"][:, :3], + [[49406, 5873, 49407], [49406, 2368, 525]], ) self.assertAllClose( - attention_mask[0, :5], [True, True, True, False, False] + tokens["padding_mask"][0, :5], [True, True, True, False, False] ) def test_clip_preprocessor_tf_data(self): - processor = CLIPProcessor(224, VOCAB_PATH, MERGE_PATH) + processor = CLIPProcessor(VOCAB_PATH, MERGE_PATH) text_input = ["a bus", "a dog", "a cat"] dataset = tf_data.Dataset.from_tensor_slices(text_input) - dataset.map(processor.process_texts) + dataset.map(processor) @pytest.mark.large def test_presets(self): - # self.skipTest("TODO: Enable after Kaggle model is public") - model = CLIP.from_preset("clip-vit-base-patch16") - processed_image = np.ones(shape=[1, 224, 224, 3]) - processed_text = np.ones(shape=[3, 77]) - attention_mask = np.ones(shape=[3, 77]) - image_logits, text_logits = model( - { - "image": processed_image, - "text": processed_text, - "attention_mask": 
attention_mask, - } - ) - - @pytest.mark.large - def test_image_encoder_golden_values(self): model = CLIP.from_preset("clip-vit-base-patch32") processed_image = np.ones(shape=[1, 224, 224, 3]) processed_text = np.ones(shape=[3, 77]) attention_mask = np.ones(shape=[3, 77]) model( { - "image": processed_image, - "text": processed_text, - "attention_mask": attention_mask, + "images": processed_image, + "token_ids": processed_text, + "padding_mask": attention_mask, } ) - self.assertAllClose( - model.image_embeddings[:, :5], - [[0.023215, 0.026526, 0.008914, -0.091689, 0.021791]], - ) - - @pytest.mark.large - def test_text_encoder_golden_values(self): - model = CLIP() - processed_image = np.ones(shape=[1, 224, 224, 3]) - processed_text = np.ones(shape=[3, 77]) - attention_mask = np.ones(shape=[3, 77]) - model( - { - "image": processed_image, - "text": processed_text, - "attention_mask": attention_mask, - } - ) - self.assertAllClose( - model.text_embeddings[0, :3], - [0.007531, -0.038361, -0.035686], - ) @pytest.mark.large # Saving is slow, so mark these large. def test_saved_model(self): - model = CLIP() + model = CLIP.from_preset("clip-vit-base-patch32") processed_image = np.ones(shape=[1, 224, 224, 3]) processed_text = np.ones(shape=[3, 77]) attention_mask = np.ones(shape=[3, 77]) - model_output, _ = model( + outputs = model( { - "image": processed_image, - "text": processed_text, - "attention_mask": attention_mask, + "images": processed_image, + "token_ids": processed_text, + "padding_mask": attention_mask, } ) save_path = os.path.join(self.get_temp_dir(), "model.keras") @@ -146,11 +130,11 @@ def test_saved_model(self): # Check we got the real object back. self.assertIsInstance(restored_model, CLIP) # Check that output matches. 
- restored_output, _ = restored_model( + restored_outputs = restored_model( { - "image": processed_image, - "text": processed_text, - "attention_mask": attention_mask, + "images": processed_image, + "token_ids": processed_text, + "padding_mask": attention_mask, } ) - self.assertAllClose(model_output, restored_output) + self.assertAllClose(outputs, restored_outputs) diff --git a/keras_cv/models/feature_extractor/clip/clip_presets.py b/keras_cv/models/feature_extractor/clip/clip_presets.py index 656c9ad8ed..f148aa2a28 100644 --- a/keras_cv/models/feature_extractor/clip/clip_presets.py +++ b/keras_cv/models/feature_extractor/clip/clip_presets.py @@ -28,7 +28,7 @@ "official_name": "CLIP", "path": "clip", }, - "kaggle_handle": "kaggle://keras/clip/keras/clip-vit-base-patch16/4", + "kaggle_handle": "kaggle://keras/clip/keras/clip-vit-base-patch16/7", }, "clip-vit-base-patch32": { "metadata": { @@ -44,7 +44,7 @@ "official_name": "CLIP", "path": "clip", }, - "kaggle_handle": "kaggle://keras/clip/keras/clip-vit-base-patch32/4", + "kaggle_handle": "kaggle://keras/clip/keras/clip-vit-base-patch32/6", }, "clip-vit-large-patch14": { "metadata": { @@ -60,7 +60,7 @@ "official_name": "CLIP", "path": "clip", }, - "kaggle_handle": "kaggle://keras/clip/keras/clip-vit-large-patch14/4", + "kaggle_handle": "kaggle://keras/clip/keras/clip-vit-large-patch14/6", }, "clip-vit-large-patch14-336": { "metadata": { @@ -76,6 +76,6 @@ "official_name": "CLIP", "path": "clip", }, - "kaggle_handle": "kaggle://keras/clip/keras/clip-vit-large-patch14-336/4", # noqa: E501 + "kaggle_handle": "kaggle://keras/clip/keras/clip-vit-large-patch14-336/6", # noqa: E501 }, } diff --git a/keras_cv/models/feature_extractor/clip/clip_processor.py b/keras_cv/models/feature_extractor/clip/clip_processor.py index 16d8d24222..7714294769 100644 --- a/keras_cv/models/feature_extractor/clip/clip_processor.py +++ b/keras_cv/models/feature_extractor/clip/clip_processor.py @@ -12,9 +12,18 @@ # See the License for the 
specific language governing permissions and # limitations under the License. +import tensorflow as tf +import tree + from keras_cv.api_export import keras_cv_export +from keras_cv.backend import config from keras_cv.backend import keras -from keras_cv.backend import ops +from keras_cv.models.feature_extractor.clip.clip_processor_utils import ( + convert_inputs_to_list_of_tensor_segments, +) +from keras_cv.models.feature_extractor.clip.clip_processor_utils import ( + convert_to_backend_tensor_or_python_list, +) from keras_cv.models.feature_extractor.clip.clip_tokenizer import CLIPTokenizer try: @@ -25,10 +34,10 @@ @keras_cv_export("keras_cv.models.feature_extractor.CLIPProcessor") -class CLIPProcessor: +class CLIPProcessor(keras.layers.Layer): """ CLIPProcessor is a utility class that provides functionality for processing - images and texts in the context of the CLIP (Contrastive Language-Image + texts in the context of the CLIP (Contrastive Language-Image Pretraining) model. Args: @@ -39,31 +48,27 @@ class CLIPProcessor: should be the file path to merge rules. The merge rule file should have one merge rule per line. - Methods: - process_images(image_path: List[str]): Transforms an image located at - the specified path. - - process_texts(texts: Union[str, List[str]], context_length: int = 77): - Processes a single text or a list of texts, returning packed token - sequences. - """ - def __init__(self, input_resolution, vocabulary, merges, **kwargs): + def __init__(self, vocabulary, merges, **kwargs): + super().__init__(**kwargs) if keras_nlp is None: raise ValueError( "ClipTokenizer requires keras-nlp. 
Please install " "using pip `pip install -U keras-nlp && pip install -U keras`" ) - self.input_resolution = input_resolution self.vocabulary = vocabulary self.merges = merges - self.image_transform = self.transform_image self.tokenizer = CLIPTokenizer( vocabulary=self.vocabulary, merges=self.merges, - unsplittable_tokens=[""], ) + self._convert_input_args = False + self._allow_non_tensor_positional_args = True + + def build(self, input_shape): + # Defer packer creation to `build()` so that we can be sure tokenizer + # assets have loaded when restoring a saved model. self.packer = StartEndPacker( start_value=self.tokenizer.token_to_id("<|startoftext|>"), end_value=self.tokenizer.token_to_id("<|endoftext|>"), @@ -71,69 +76,58 @@ def __init__(self, input_resolution, vocabulary, merges, **kwargs): sequence_length=77, return_padding_mask=True, ) + self.built = True - def transform_image(self, image_path): - input_resolution = self.input_resolution - mean = ops.array([0.48145466, 0.4578275, 0.40821073]) - std = ops.array([0.26862954, 0.26130258, 0.27577711]) - - image = keras.utils.load_img(image_path) - image = keras.utils.img_to_array(image) - image = ( - ops.image.resize( - image, - (input_resolution, input_resolution), - interpolation="bicubic", - ) - / 255.0 - ) - central_fraction = input_resolution / image.shape[0] - width, height = image.shape[0], image.shape[1] - left = ops.cast((width - width * central_fraction) / 2, dtype="int32") - top = ops.cast((height - height * central_fraction) / 2, dtype="int32") - right = ops.cast((width + width * central_fraction) / 2, dtype="int32") - bottom = ops.cast( - (height + height * central_fraction) / 2, dtype="int32" - ) + def _process_texts(self, texts, context_length: int = 77): + # Ensure the layer is built + if not self.built: + self.build(None) - image = ops.slice( - image, [left, top, 0], [right - left, bottom - top, 3] - ) + texts = convert_inputs_to_list_of_tensor_segments(texts) - image = (image - mean) / std - 
return image + if len(texts) != 1: + raise ValueError( + "CLIP requires each input feature to contain only " + f"one segment, but received {len(texts)}." + ) - def process_images(self, images): - if isinstance(images, str): - images = [images] + token_ids, padding_mask = self.packer( + self.tokenizer(texts[0]), + sequence_length=context_length, + add_start_value=True, + add_end_value=True, + ) + return {"token_ids": token_ids, "padding_mask": padding_mask} - def process_image(image): - if isinstance(image, str): - return self.image_transform(image) + def call(self, texts, context_length: int = 77): + return self._process_texts(texts, context_length=context_length) - processed_images = list(map(process_image, images)) - processed_images = ops.stack(processed_images) - return processed_images + def get_build_config(self): + return None - def process_texts(self, texts, context_length: int = 77): - if isinstance(texts, str): - texts = [texts] + def __call__(self, *args, **kwargs): + # Always place on CPU for preprocessing, to avoid expensive back and + # forth copies to GPU before the trainable model. + with tf.device("cpu"): + outputs = super().__call__(*args, **kwargs) - def pack_tokens(text): - return self.packer( - self.tokenizer(text), - sequence_length=context_length, - add_start_value=True, - add_end_value=True, - ) + # Jax and Torch lack native string and ragged types. + # If we are running on those backends and not running with tf.data + # (we are outside a tf.function), we convert all ragged and string + # tensors to pythonic types. 
+ is_tf_backend = config.backend() == "tensorflow" + is_in_tf_graph = not tf.executing_eagerly() + if not is_tf_backend and not is_in_tf_graph: + outputs = tree.map_structure( + convert_to_backend_tensor_or_python_list, outputs + ) - return pack_tokens(texts) + return outputs def get_config(self): config = super().get_config() config.update( { - "input_resolution": self.input_resolution, "vocabulary": self.vocabulary, "merges": self.merges, } diff --git a/keras_cv/models/feature_extractor/clip/clip_processor_utils.py b/keras_cv/models/feature_extractor/clip/clip_processor_utils.py new file mode 100644 index 0000000000..919cb316d0 --- /dev/null +++ b/keras_cv/models/feature_extractor/clip/clip_processor_utils.py @@ -0,0 +1,110 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import tensorflow as tf + +from keras_cv.backend import ops + + +def _decode_strings_to_utf8(inputs): + """Recursively decodes to list of strings with 'utf-8' encoding.""" + if isinstance(inputs, bytes): + # Handles the case when the input is a scalar string. + return inputs.decode("utf-8", errors="ignore") + else: + # Recursively iterate when input is a list. + return [_decode_strings_to_utf8(x) for x in inputs] + + +def tensor_to_list(inputs): + """Converts a tensor to nested lists. + + Args: + inputs: Input tensor, or dict/list/tuple of input tensors. 
+ """ + if not isinstance(inputs, (tf.RaggedTensor, tf.Tensor)): + inputs = tf.convert_to_tensor(inputs) + if isinstance(inputs, tf.RaggedTensor): + list_outputs = inputs.to_list() + elif isinstance(inputs, tf.Tensor): + list_outputs = inputs.numpy() + if inputs.shape.rank != 0: + list_outputs = list_outputs.tolist() + if inputs.dtype == tf.string: + list_outputs = _decode_strings_to_utf8(list_outputs) + return list_outputs + + +def convert_to_backend_tensor_or_python_list(x): + """ + Convert a tensor to the backend friendly representation of the data. + + This wraps `ops.convert_to_tensor` to account for the fact that torch and + jax both lack native types for ragged and string data. + + If we encounter one of these types in torch or jax, we will instead convert + the tensor to simple pythonic types (lists of strings). + """ + if isinstance(x, tf.RaggedTensor) or getattr(x, "dtype", None) == tf.string: + return tensor_to_list(x) + return ops.convert_to_tensor(x) + + +def convert_inputs_to_list_of_tensor_segments(x): + """Converts user inputs to a list of tensor segments. + + For models and layers which accept lists of string tensors to pack together, + this method converts user inputs to a uniform format in a way that can be + considered canonical for the library. + + We handle the following: + + - A single string will be converted to a tensor and wrapped in a list. + - A list of strings will be converted to a tensor and wrapped in a list. + - A single tensor will be wrapped in a list. + - A list of tensors will be passed through unaltered. + + All other inputs will result in an error. This effectively means that users + who would like to pack multiple segments together should convert those + segments to tensors before calling the layer. This removes any ambiguity + in the input for those cases. + """ + # Check the input type. 
+ is_string = isinstance(x, (str, bytes)) + is_tensor = hasattr(x, "__array__") + is_string_list = ( + isinstance(x, (list, tuple)) and x and isinstance(x[0], (str, bytes)) + ) + is_tensor_list = ( + isinstance(x, (list, tuple)) and x and hasattr(x[0], "__array__") + ) + + if is_string or is_string_list: + # Automatically convert raw strings or string lists to tensors. + # Wrap this input as a single (possibly batched) segment. + x = [tf.convert_to_tensor(x)] + elif is_tensor: + # Automatically wrap a single tensor as a single segment. + x = [x] + elif is_tensor_list: + # Pass lists of tensors through unaltered. + x = x + else: + # Error for all other input. + raise ValueError( + f"Unsupported input for `x`. `x` should be a string, a list of " + "strings, or a list of tensors. If passing multiple segments " + "which should packed together, please convert your inputs to a " + f"list of tensors. Received `x={x}`" + ) + return x diff --git a/keras_cv/models/feature_extractor/clip/clip_text_model.py b/keras_cv/models/feature_extractor/clip/clip_text_model.py index 26ab92cb0a..1259761d58 100644 --- a/keras_cv/models/feature_extractor/clip/clip_text_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_text_model.py @@ -62,12 +62,15 @@ def __init__( ) def build(self, input_shape): - super().build(input_shape) self.token_embedding.build(input_shape) self.positional_embedding.build([1, self.context_length]) self.encoder.build(None) self.ln_final.build([None, None, self.transformer_width]) - self.text_projector.build([None, None, self.transformer_width]) + self.text_projector.build([None, self.transformer_width]) + self.built = True + + def compute_output_shape(self, input_shape): + return [input_shape[0], self.embed_dim] def call(self, inputs, attention_mask=None): token_embedding = self.token_embedding(inputs) @@ -76,7 +79,7 @@ def call(self, inputs, attention_mask=None): ) position_embedding = self.positional_embedding(position_ids) position_embedding = ops.tile( - 
position_embedding, repeats=(inputs.shape[0], 1, 1) + position_embedding, repeats=(ops.shape(inputs)[0], 1, 1) ) causal_attention_mask = ops.ones( (self.context_length, self.context_length) diff --git a/keras_cv/tools/checkpoint_conversion/clip_weights_conversion.ipynb b/keras_cv/tools/checkpoint_conversion/clip_weights_conversion.ipynb index ff3bb4c991..b8f556343e 100644 --- a/keras_cv/tools/checkpoint_conversion/clip_weights_conversion.ipynb +++ b/keras_cv/tools/checkpoint_conversion/clip_weights_conversion.ipynb @@ -3,150 +3,64 @@ { "cell_type": "markdown", "metadata": { - "id": "0DhV6hzOMY0W" + "id": "mdGT8Em4Mc4b" }, "source": [ - "# Setup" + "# Import" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "cRzYR-oFgxt1", - "outputId": "80b8db20-da09-43bd-9b70-fad93b1e1ca1" + "id": "0mtj1abS2cVf" }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", - " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", - " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m950.8/950.8 kB\u001b[0m \u001b[31m6.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h Building wheel for keras-cv (pyproject.toml) ... 
\u001b[?25l\u001b[?25hdone\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m465.2/465.2 kB\u001b[0m \u001b[31m3.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.2/5.2 MB\u001b[0m \u001b[31m36.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.0/1.0 MB\u001b[0m \u001b[31m2.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", - "tensorflow 2.15.0 requires keras<2.16,>=2.15.0, but you have keras 3.0.2 which is incompatible.\u001b[0m\u001b[31m\n", - "\u001b[0m" - ] - } - ], + "outputs": [], "source": [ - "!pip install -q git+https://github.com/divyashreepathihalli/keras-cv.git@CLIP_refactor\n", - "!pip install -q keras-nlp\n", - "!pip install -q tf-keras\n", - "!pip install -q tensorflow-text\n", - "!pip install -q keras==3.0.2" + "import os\n", + "\n", + "os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "nuFgha2jTshi", - "outputId": "63d4160e-42b3-4f6b-e672-ba30c9402d25" + "id": "GDvJmQuug4-x" }, "outputs": [ { - "name": "stdout", + "name": "stderr", "output_type": "stream", "text": [ - "--2024-02-21 20:54:06-- https://i.imgur.com/8H7XCH0.jpg\n", - "Resolving i.imgur.com (i.imgur.com)... 146.75.76.193\n", - "Connecting to i.imgur.com (i.imgur.com)|146.75.76.193|:443... connected.\n", - "HTTP request sent, awaiting response... 
200 OK\n", - "Length: 44544 (44K) [image/jpeg]\n", - "Saving to: ‘cat.jpg’\n", - "\n", - "\rcat.jpg 0%[ ] 0 --.-KB/s \rcat.jpg 100%[===================>] 43.50K --.-KB/s in 0.01s \n", - "\n", - "2024-02-21 20:54:06 (4.16 MB/s) - ‘cat.jpg’ saved [44544/44544]\n", - "\n", - "--2024-02-21 20:54:06-- http://images.cocodataset.org/val2017/000000039769.jpg\n", - "Resolving images.cocodataset.org (images.cocodataset.org)... 52.217.206.137, 16.182.42.89, 54.231.201.177, ...\n", - "Connecting to images.cocodataset.org (images.cocodataset.org)|52.217.206.137|:80... connected.\n", - "HTTP request sent, awaiting response... 200 OK\n", - "Length: 173131 (169K) [image/jpeg]\n", - "Saving to: ‘two_cats.jpg’\n", - "\n", - "two_cats.jpg 100%[===================>] 169.07K --.-KB/s in 0.09s \n", - "\n", - "2024-02-21 20:54:07 (1.77 MB/s) - ‘two_cats.jpg’ saved [173131/173131]\n", - "\n", - "--2024-02-21 20:54:07-- https://i.imgur.com/PpgZzP4.jpeg\n", - "Resolving i.imgur.com (i.imgur.com)... 146.75.76.193\n", - "Connecting to i.imgur.com (i.imgur.com)|146.75.76.193|:443... connected.\n", - "HTTP request sent, awaiting response... 
200 OK\n", - "Length: 1610285 (1.5M) [image/jpeg]\n", - "Saving to: ‘mountain.jpg’\n", - "\n", - "mountain.jpg 100%[===================>] 1.54M --.-KB/s in 0.06s \n", - "\n", - "2024-02-21 20:54:07 (27.6 MB/s) - ‘mountain.jpg’ saved [1610285/1610285]\n", - "\n" + "2024-04-08 22:25:42.921695: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.\n", + "2024-04-08 22:25:42.928541: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.\n", + "2024-04-08 22:25:43.041846: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", + "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", + "2024-04-08 22:25:44.645915: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" ] } ], "source": [ - "!wget https://i.imgur.com/8H7XCH0.jpg -O cat.jpg\n", - "!wget http://images.cocodataset.org/val2017/000000039769.jpg -O two_cats.jpg\n", - "!wget https://i.imgur.com/PpgZzP4.jpeg -O mountain.jpg" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "mdGT8Em4Mc4b" - }, - "source": [ - "# Import" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "0mtj1abS2cVf" - }, - "outputs": [], - "source": [ - "import os\n", + "import json\n", + "from datetime import datetime\n", "\n", - "os.environ[\"KERAS_BACKEND\"] = \"torch\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "GDvJmQuug4-x" - }, - "outputs": [], - "source": [ - "from keras_cv.models.feature_extractor.clip import CLIPProcessor\n", "import keras\n", + "from keras import ops\n", + "\n", + "import keras_cv\n", + "from keras_cv.models.feature_extractor.clip import CLIPProcessor\n", "from keras_cv.models 
import CLIP" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": { "id": "X3kkmK6h_gFH" }, "outputs": [], "source": [ - "# @title Select which model weights you would like to convert\n", "MODEL_CONFIGS = {\n", " \"CLIP_B32\": {\n", " \"embed_dim\": 512,\n", @@ -203,7 +117,7 @@ " \"CLIP_L14\": \"openai/clip-vit-large-patch14\",\n", " \"CLIP_L14_336\": \"openai/clip-vit-large-patch14-336\",\n", "}\n", - "config_name = \"CLIP_L14_336\" # @param [\"CLIP_B16\", \"CLIP_B32\", \"CLIP_L14\", \"CLIP_L14_336\"]\n", + "config_name = \"CLIP_B16\"\n", "config_name_hf = model_map_hf[config_name]" ] }, @@ -218,7 +132,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": { "id": "urhuhwq0Dczo" }, @@ -250,7 +164,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -276,170 +190,51 @@ { "data": { "text/html": [ - "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓\n", - "┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩\n", - "│ image_encoder (CLIPImageEncoder) │ ? │ 0 (unbuilt) │\n", - "├────────────────────────────────────┼───────────────────────────────┼─────────────┤\n", - "│ text_encoder (CLIPTextEncoder) │ ? │ 0 (unbuilt) │\n", - "└────────────────────────────────────┴───────────────────────────────┴─────────────┘\n", - "\n" - ], - "text/plain": [ - "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩\n", - "│ image_encoder (\u001b[38;5;33mCLIPImageEncoder\u001b[0m) │ ? │ \u001b[38;5;34m0\u001b[0m (unbuilt) │\n", - "├────────────────────────────────────┼───────────────────────────────┼─────────────┤\n", - "│ text_encoder (\u001b[38;5;33mCLIPTextEncoder\u001b[0m) │ ? │ \u001b[38;5;34m0\u001b[0m (unbuilt) │\n", - "└────────────────────────────────────┴───────────────────────────────┴─────────────┘\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
Total params: 1 (4.00 B)\n", - "\n" - ], - "text/plain": [ - "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m1\u001b[0m (4.00 B)\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
Trainable params: 1 (4.00 B)\n", - "\n" - ], - "text/plain": [ - "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m1\u001b[0m (4.00 B)\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
Non-trainable params: 0 (0.00 B)\n", - "\n" - ], - "text/plain": [ - "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "model.summary()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "buXKlNfGTenW" - }, - "outputs": [], - "source": [ - "processor = CLIPProcessor(\n", - " MODEL_CONFIGS[config_name][\"image_resolution\"], \"vocab.json\", \"merges.txt\"\n", - ")\n", - "image = processor.process_images([\"two_cats.jpg\"])\n", - "text_input = [\"mountains\", \"cat on tortoise\", \"two cats\"]\n", - "text, attention_mask = processor.process_texts(text_input)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "BHSpMv0PT5SX" - }, - "outputs": [], - "source": [ - "image_logits, text_logits = model(image, text, attention_mask)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "JPn0gACJjKy5", - "outputId": "cbc7313a-4ddd-4021-9e84-fa668987849d" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[3.7318, 3.7792, 3.7633]], grad_fn=
Model: \"clip\"\n",
- "\n"
- ],
- "text/plain": [
- "\u001b[1mModel: \"clip\"\u001b[0m\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓\n", - "┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩\n", - "│ image_encoder (CLIPImageEncoder) │ ? │ 304,293,888 │\n", - "├────────────────────────────────────┼───────────────────────────────┼─────────────┤\n", - "│ text_encoder (CLIPTextEncoder) │ ? │ 123,650,304 │\n", - "└────────────────────────────────────┴───────────────────────────────┴─────────────┘\n", + "┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n", + "│ images (InputLayer) │ (None, 224, 224, │ 0 │ - │\n", + "│ │ 3) │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ token_ids │ (None, 77) │ 0 │ - │\n", + "│ (InputLayer) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ padding_mask │ (None, 77) │ 0 │ - │\n", + "│ (InputLayer) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ image_encoder │ (None, 512) │ 86,192,640 │ images[0][0] │\n", + "│ (CLIPImageEncoder) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ text_encoder │ (None, 512) │ 63,428,096 │ token_ids[0][0], │\n", + "│ (CLIPTextEncoder) │ │ │ padding_mask[0][… │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ clip_head │ [(None, None), │ 1 │ image_encoder[0]… │\n", + "│ (CLIPHead) │ (None, None)] │ │ text_encoder[0][… │\n", + "└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n", "\n" ], "text/plain": [ - "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓\n", - "┃\u001b[1m 
\u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩\n", - "│ image_encoder (\u001b[38;5;33mCLIPImageEncoder\u001b[0m) │ ? │ \u001b[38;5;34m304,293,888\u001b[0m │\n", - "├────────────────────────────────────┼───────────────────────────────┼─────────────┤\n", - "│ text_encoder (\u001b[38;5;33mCLIPTextEncoder\u001b[0m) │ ? │ \u001b[38;5;34m123,650,304\u001b[0m │\n", - "└────────────────────────────────────┴───────────────────────────────┴─────────────┘\n" + "┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n", + "│ images (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m224\u001b[0m, \u001b[38;5;34m224\u001b[0m, │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "│ │ \u001b[38;5;34m3\u001b[0m) │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ token_ids │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m77\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ padding_mask │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m77\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ image_encoder │ 
(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m86,192,640\u001b[0m │ images[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mCLIPImageEncoder\u001b[0m) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ text_encoder │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m63,428,096\u001b[0m │ token_ids[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mCLIPTextEncoder\u001b[0m) │ │ │ padding_mask[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ clip_head │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m), │ \u001b[38;5;34m1\u001b[0m │ image_encoder[\u001b[38;5;34m0\u001b[0m]… │\n", + "│ (\u001b[38;5;33mCLIPHead\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m)] │ │ text_encoder[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n", + "└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n" ] }, "metadata": {}, @@ -448,11 +243,11 @@ { "data": { "text/html": [ - "Total params: 427,944,193 (1.59 GB)\n", + "Total params: 149,620,737 (570.76 MB)\n", "\n" ], "text/plain": [ - "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m427,944,193\u001b[0m (1.59 GB)\n" + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m149,620,737\u001b[0m (570.76 MB)\n" ] }, "metadata": {}, @@ -461,11 +256,11 @@ { "data": { "text/html": [ - "Trainable params: 427,944,193 (1.59 GB)\n", + "Trainable params: 149,620,737 (570.76 MB)\n", "\n" ], "text/plain": [ - "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m427,944,193\u001b[0m (1.59 GB)\n" + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m149,620,737\u001b[0m (570.76 MB)\n" ] }, "metadata": {}, @@ -500,7 +295,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": { "id": 
"3W2prd6C0pxe" }, @@ -508,6 +303,7 @@ "source": [ "from PIL import Image\n", "import requests\n", + "import torch\n", "\n", "from transformers import CLIPProcessor as CP\n", "from transformers import CLIPModel as CM" @@ -515,7 +311,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -614,192 +410,12 @@ "id": "EntuvOq1MhwU", "outputId": "cbd7cd77-6d8f-4a76-dae0-24530c12eeb6" }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:88: UserWarning: \n", - "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", - "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", - "You will be able to reuse this secret in all of your notebooks.\n", - "Please note that authentication is recommended but still optional to access public models or datasets.\n", - " warnings.warn(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "46636db47838400cb7407fc2ab0720eb", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "config.json: 0%| | 0.00/4.76k [00:00, ?B/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "f359ba8ef0cf4841b40acafcd770480c", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "pytorch_model.bin: 0%| | 0.00/1.71G [00:00, ?B/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.10/dist-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. 
This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n", - " return self.fget.__get__(instance, owner)()\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "3af19b8b653c4b21a65f7e96dd463aac", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "preprocessor_config.json: 0%| | 0.00/316 [00:00, ?B/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "5c21455d9faa4112ba6f18819f7ef038", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "tokenizer_config.json: 0%| | 0.00/844 [00:00, ?B/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "418143d2ad92458094259dfca0a747cc", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "vocab.json: 0%| | 0.00/862k [00:00, ?B/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "c06b5a6588eb42189210d1c20ccba87a", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "merges.txt: 0%| | 0.00/525k [00:00, ?B/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "79020bd42626472a85bf9047d014830f", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "tokenizer.json: 0%| | 0.00/2.22M [00:00, ?B/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "ec7bc6e82f2042b8b29a6f21e6db1709", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "special_tokens_map.json: 0%| | 0.00/389 [00:00, ?B/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - 
} - ], + "outputs": [], "source": [ "model_hf = CM.from_pretrained(config_name_hf)\n", "processor_hf = CP.from_pretrained(config_name_hf)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Ep8DRTkv3AwS", - "outputId": "6e3e802c-3db6-48ac-e3ab-4f52416449a8" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[11.7865, 31.2010, 11.9718]], grad_fn=)" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import torch\n", - "import numpy as np\n", - "\n", - "photo = {\n", - " \"cat\": \"https://i.imgur.com/8H7XCH0.jpg\",\n", - " \"two_cats\": \"http://images.cocodataset.org/val2017/000000039769.jpg\",\n", - " \"mountain\": \"https://i.imgur.com/PpgZzP4.jpeg\",\n", - "}\n", - "url = photo[\"cat\"]\n", - "image_hf = Image.open(requests.get(url, stream=True).raw)\n", - "text_inputs = [\"mountains\", \"cat on tortoise\", \"two dogs\"]\n", - "inputs = processor_hf(\n", - " text=text_inputs, images=image_hf, return_tensors=\"pt\", padding=True\n", - ")\n", - "outputs = model_hf(**inputs)\n", - "logits_per_image = (\n", - " outputs.logits_per_image\n", - ") # this is the image-text similarity score\n", - "probs = logits_per_image.softmax(\n", - " dim=1\n", - ") # we can take the softmax to get the label probabilitiesprobs\n", - "logits_per_image" - ] - }, { "cell_type": "markdown", "metadata": { @@ -811,7 +427,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": { "id": "wPa0cVnY3cBC" }, @@ -827,18 +443,20 @@ "id": "TUCpKltRG4Gd" }, "source": [ - "##vision encoder" + "## Vision Encoder" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": { "id": "tn_U02N7U2VN" }, "outputs": [], "source": [ - "model.logit_scale.assign(hf_wts.pop(\"logit_scale\").numpy())\n", + "model.get_layer(\"clip_head\").logit_scale.assign(\n", + " 
hf_wts.pop(\"logit_scale\").numpy()\n", + ")\n", "model.get_layer(\"image_encoder\").get_layer(\n", " \"clip_patch_embedding\"\n", ").class_embedding.assign(\n", @@ -875,7 +493,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": { "id": "YRXC2HZC3FjG" }, @@ -988,7 +606,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": { "id": "_1AD7TcbdWEC" }, @@ -1013,7 +631,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": { "id": "IQFquy9R75G8" }, @@ -1104,7 +722,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -1119,7 +737,7 @@ "odict_keys([])" ] }, - "execution_count": 17, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -1135,19 +753,178 @@ "id": "wlfDdO-mid62" }, "source": [ - "# save weights" + "# Save Weights" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": { "id": "QscCUUZFiqBV" }, "outputs": [], "source": [ - "model.save_weights(\"model.weights.h5\")" + "os.makedirs(config_name, exist_ok=True)\n", + "model.save_weights(os.path.join(config_name, \"model.weights.h5\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "config = {\n", + " \"module\": \"keras_cv.models.feature_extractor.clip.clip_model\",\n", + " \"class_name\": \"CLIP\",\n", + " \"config\": model.get_config(),\n", + " \"registered_name\": \"keras_cv>CLIP\",\n", + " \"weights\": \"model.weights.h5\",\n", + "}\n", + "\n", + "with open(os.path.join(config_name, \"config.json\"), \"w\") as config_file:\n", + " json.dump(config, config_file)\n", + "\n", + "metadata = {\n", + " \"keras_version\": keras.__version__,\n", + " \"keras_cv_version\": keras_cv.__version__,\n", + " \"parameter_count\": model.count_params(),\n", + " \"date_saved\": 
datetime.utcnow().strftime(\"%Y-%m-%d@%H:%M:%S\"),\n", + "}\n", + "\n", + "with open(os.path.join(config_name, \"metadata.json\"), \"w\") as metadata_file:\n", + " json.dump(metadata, metadata_file)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Verify numerics" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "# url = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\n", + "# image = Image.open(requests.get(url, stream=True).raw)\n", + "\n", + "# inputs = processor_hf(text=[\"a photo of a cat\", \"a photo of a dog\"], images=image, return_tensors=\"pt\", padding=True)\n", + "\n", + "# outputs = model_hf(**inputs)\n", + "# logits_per_image = outputs.logits_per_image # this is the image-text similarity score\n", + "# probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "# import matplotlib.pyplot as plt\n", + "\n", + "# plt.imshow(image)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "# probs" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "# VOCAB_PATH = keras.utils.get_file(None, \"https://storage.googleapis.com/keras-cv/models/clip/vocab.json\")\n", + "# MERGE_PATH = keras.utils.get_file(None, \"https://storage.googleapis.com/keras-cv/models/clip/merges.txt\")" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "# processor = CLIPProcessor(224, VOCAB_PATH, MERGE_PATH)\n", + "# text_processed = processor([\"a photo of a cat\", \"a photo of a dog\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "# image_processed = 
ops.convert_to_tensor(inputs['pixel_values'].detach().cpu().permute(0, 2, 3, 1).numpy())" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "# outputs = model({\n", + "# \"images\": image_processed,\n", + "# **text_processed\n", + "# })" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "# ops.softmax(outputs[\"image_logits\"], axis=1) # we can take the softmax to get the label probabilities" ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "# model.load_weights(\"model.weights.h5\")" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "# outputs = model({\n", + "# \"images\": image_processed,\n", + "# **text_processed\n", + "# })" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "# ops.softmax(outputs[\"image_logits\"], axis=1) # we can take the softmax to get the label probabilities" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -1158,11 +935,21 @@ "provenance": [] }, "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", + "language": "python", "name": "python3" }, "language_info": { - "name": "python" + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.8" }, "widgets": { "application/vnd.jupyter.widget-state+json": { @@ -3906,5 +3693,5 @@ } }, "nbformat": 4, - "nbformat_minor": 0 + "nbformat_minor": 4 }