diff --git a/.kokoro/github/ubuntu/gpu/build.sh b/.kokoro/github/ubuntu/gpu/build.sh index f587a1f304..8ddb1c1d33 100644 --- a/.kokoro/github/ubuntu/gpu/build.sh +++ b/.kokoro/github/ubuntu/gpu/build.sh @@ -69,6 +69,7 @@ then keras_cv/models/object_detection/yolo_v8 \ keras_cv/models/object_detection_3d \ keras_cv/models/segmentation \ + keras_cv/models/feature_extractor/clip \ keras_cv/models/stable_diffusion else pytest --cache-clear --check_gpu --run_large --durations 0 \ @@ -83,5 +84,6 @@ else keras_cv/models/object_detection/yolo_v8 \ keras_cv/models/object_detection_3d \ keras_cv/models/segmentation \ + keras_cv/models/feature_extractor/clip \ keras_cv/models/stable_diffusion fi \ No newline at end of file diff --git a/keras_cv/models/feature_extractor/clip/clip_encoder.py b/keras_cv/models/feature_extractor/clip/clip_encoder.py index d1935a0577..1b2f5c9e3e 100644 --- a/keras_cv/models/feature_extractor/clip/clip_encoder.py +++ b/keras_cv/models/feature_extractor/clip/clip_encoder.py @@ -11,27 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np - from keras_cv.api_export import keras_cv_export from keras_cv.backend import keras from keras_cv.backend import ops -def get_initializer(initializer_range=0.02): - """ - Creates a `keras.initializers.TruncatedNormal` with the given range. - - Args: - initializer_range (*float*, defaults to 0.02): Standard deviation of the - initializer range. - - Returns: - `keras.initializers.TruncatedNormal`: The truncated normal initializer. - """ - return keras.initializers.TruncatedNormal(stddev=initializer_range) - - @keras_cv_export("keras_cv.models.feature_extractor.QuickGELU") class QuickGELU(keras.layers.Layer): def __init__(self, **kwargs): @@ -54,13 +38,6 @@ def __init__( self.proj_dim = proj_dim self.num_heads = num_heads self.num_hidden_layers = num_hidden_layers - self.fc_std = np.power(2 * self.proj_dim, -0.5) * 0.02 - - self.in_proj_std = ( - np.power(self.proj_dim, -0.5) - * (np.power(2 * self.num_hidden_layers, -0.5)) - * 0.02 - ) self.attn = CLIPAttention( self.proj_dim, self.num_heads, @@ -156,9 +133,14 @@ def __init__(self, width, num_layers, heads, **kwargs): ] def build(self, input_shape): - super().build(input_shape) for block in self.resblocks: block.build(input_shape) + self.built = True + + def compute_output_shape(self, input_shape): + for block in self.resblocks: + input_shape = block.compute_output_shape(input_shape) + return input_shape def call( self, @@ -174,9 +156,6 @@ def call( ) return x - def compute_output_shape(self, inputs_shape): - return inputs_shape - def get_config(self): config = super().get_config() config.update( @@ -213,30 +192,20 @@ def __init__( ) self.scale = self.head_dim**-0.5 - in_proj_std = ( - (self.proj_dim**-0.5) - * ((2 * self.num_hidden_layers) ** -0.5) - * 0.02 - ) - out_proj_std = (self.proj_dim**-0.5) * 0.02 self.q_proj = keras.layers.Dense( units=self.proj_dim, - kernel_initializer=get_initializer(in_proj_std), name="q_proj", ) self.k_proj = keras.layers.Dense( units=self.proj_dim, - kernel_initializer=get_initializer(in_proj_std), name="k_proj", ) self.v_proj = keras.layers.Dense( units=self.proj_dim, - kernel_initializer=get_initializer(in_proj_std), name="v_proj", ) self.out_proj = keras.layers.Dense( units=self.proj_dim, - kernel_initializer=get_initializer(out_proj_std), name="out_proj", ) diff --git a/keras_cv/models/feature_extractor/clip/clip_image_model.py b/keras_cv/models/feature_extractor/clip/clip_image_model.py index 663b01cd62..fb4cf71ef4 100644 --- a/keras_cv/models/feature_extractor/clip/clip_image_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_image_model.py @@ -16,10 +16,8 @@ from keras_cv.backend import keras from keras_cv.backend import ops from keras_cv.models.feature_extractor.clip.clip_encoder import CLIPEncoder -from keras_cv.models.feature_extractor.clip.clip_encoder import get_initializer -@keras_cv_export("keras_cv.models.feature_extractor.CLIPPatchingAndEmbedding") class CLIPPatchingAndEmbedding(keras.layers.Layer): def __init__( self, width, patch_size, input_resolution, output_dim, **kwargs @@ -33,7 +31,6 @@ def __init__( padding="valid", use_bias=False, data_format="channels_last", - kernel_initializer=get_initializer(0.02), name="patch_embed.embedding", ) self.width = width @@ -42,9 +39,6 @@ def __init__( self.num_patches = ops.power( (self.input_resolution // self.patch_size), 2 ) - self.class_embedding_initializer = get_initializer( - ops.power(self.width, -0.5) * 0.02 - ) self.output_dim = output_dim def build(self, input_shape): @@ -52,7 +46,6 @@ def build(self, input_shape): self.conv1.build(input_shape) self.class_embedding = self.add_weight( shape=((self.width,)), - initializer=self.class_embedding_initializer, name="patch_embed.class_embedding", ) @@ -67,6 +60,13 @@ def build(self, input_shape): name="patch_embed.positional_embedding", ) + def compute_output_shape(self, input_shape): + return [ + None, + (self.input_resolution // self.patch_size) ** 2 + 1, + self.width, + ] + def call(self, x): batch_size = ops.shape(x)[0] patch_embeddings = self.conv1(x) # shape = [*, grid, grid, channel] @@ -143,12 +143,15 @@ def __init__( ) def build(self, input_shape): - super().build(input_shape) self.embeddings.build(input_shape) self.pre_norm.build([None, None, self.width]) self.encoder.build(None) self.post_norm.build([None, self.width]) - self.image_projector.build([None, None, self.width]) + self.image_projector.build([None, self.width]) + self.built = True + + def compute_output_shape(self, input_shape): + return [input_shape[0], self.output_dim] def call(self, image): x = self.embeddings(image) diff --git a/keras_cv/models/feature_extractor/clip/clip_model.py b/keras_cv/models/feature_extractor/clip/clip_model.py index 860739388e..bc5e29097f 100644 --- a/keras_cv/models/feature_extractor/clip/clip_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_model.py @@ -34,6 +34,41 @@ keras_nlp = None +class CLIPHead(keras.layers.Layer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def build(self, input_shape): + self.logit_scale = self.add_variable( + shape=(), + initializer=lambda *a, **kw: ops.log(1 / 0.07), + trainable=True, + dtype=self.variable_dtype, + name="logit_scale", + ) + self.built = True + + def call(self, image_embeddings, text_embeddings): + normalize_image_features = ops.sqrt( + ops.sum(ops.power(image_embeddings, 2), keepdims=True) + ) + normalize_text_features = ops.sqrt( + ops.sum(ops.power(text_embeddings, 2), keepdims=True) + ) + image_embeddings = image_embeddings / normalize_image_features + text_embeddings = text_embeddings / normalize_text_features + logit_scale = ops.exp(self.logit_scale) + image_logits = ( + ops.matmul( + image_embeddings, + ops.transpose(text_embeddings), + ) + * logit_scale + ) + text_logits = ops.transpose(image_logits) + return image_logits, text_logits + + @keras_cv_export(["keras_cv.models.CLIP"]) class CLIP(Task): """ @@ -61,25 +96,27 @@ class CLIP(Task): transformer-based text encoder. transformer_layers (int): The number of layers in the transformer-based text encoder. + Example: + ```python processor = CLIPProcessor( - input_resolution=224, - "path_to_vocab.json", - "path_to_merges.txt" - ) + input_resolution=224, + "path_to_vocab.json", + "path_to_merges.txt" + ) processed_image = processor.process_images(["cat.jpg"]) - processed_text, attention_mask = processor.process_texts( - ["mountains", "cat on tortoise", "two cats"] - ) + tokens = processor( + ["mountains", "cat on tortoise", "two cats"] + ) model = CLIP.from_preset("clip-vit-base-patch16") image_logits, text_logits = model( - { - "image": processed_image, - "text": processed_text, - "attention_mask": attention_mask, - } - ) + { + "images": processed_image, + "token_ids": tokens["token_ids"], + "padding_mask": tokens["padding_mask"], + } + ) ``` """ @@ -97,12 +134,70 @@ def __init__( transformer_layers=12, **kwargs, ): - super().__init__(**kwargs) if keras_nlp is None: raise ValueError( "ClipTokenizer requires keras-nlp. Please install " "using pip `pip install -U keras-nlp && pip install -U keras`" ) + + vision_heads = vision_width // 64 + + images = keras.Input( + shape=[image_resolution, image_resolution, 3], name="images" + ) + token_ids = keras.Input( + shape=[ + context_length, + ], + name="token_ids", + ) + padding_mask = keras.Input( + shape=[ + context_length, + ], + name="padding_mask", + ) + + image_encoder = CLIPImageEncoder( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + num_layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim, + name="image_encoder", + ) + text_encoder = CLIPTextEncoder( + transformer_width=transformer_width, + transformer_layers=transformer_layers, + transformer_heads=transformer_heads, + vocab_size=vocab_size, + embed_dim=embed_dim, + context_length=context_length, + name="text_encoder", + ) + clip_head = CLIPHead(name="clip_head") + + image_embeddings = image_encoder(images) + text_embeddings = text_encoder(token_ids, attention_mask=padding_mask) + image_logits, text_logits = clip_head(image_embeddings, text_embeddings) + + inputs = { + "images": images, + "token_ids": token_ids, + "padding_mask": padding_mask, + } + outputs = { + "image_logits": image_logits, + "text_logits": text_logits, + } + + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + self.embed_dim = embed_dim self.image_resolution = image_resolution self.vision_layers = vision_layers @@ -113,75 +208,9 @@ def __init__( self.transformer_width = transformer_width self.transformer_heads = transformer_heads self.transformer_layers = transformer_layers - - vision_heads = self.vision_width // 64 - self.image_encoder = CLIPImageEncoder( - input_resolution=self.image_resolution, - patch_size=self.vision_patch_size, - width=self.vision_width, - num_layers=self.vision_layers, - heads=vision_heads, - output_dim=self.embed_dim, - name="image_encoder", - ) - self.text_encoder = CLIPTextEncoder( - transformer_width=self.transformer_width, - transformer_layers=self.transformer_layers, - transformer_heads=self.transformer_heads, - vocab_size=self.vocab_size, - embed_dim=self.embed_dim, - context_length=self.context_length, - name="text_encoder", - ) - - self.logit_scale = keras.Variable( - ops.ones([]) * ops.log(1 / 0.07), name="logit_scale" - ) - self.image_embeddings = None - self.text_embeddings = None - - def build(self, input_shape): - super().build(input_shape) - self.text_encoder.build([None, self.context_length]) - self.image_encoder.build( - [None, self.image_resolution, self.image_resolution, 3] - ) - - def encode_images(self, image): - return self.image_encoder(image) - - def encode_text(self, text, attention_mask=None): - return self.text_encoder(text, attention_mask=attention_mask) - - def call(self, inputs): - image, text = inputs["image"], inputs["text"] - if "attention_mask" in inputs: - attention_mask = inputs["attention_mask"] - else: - attention_mask = None - self.image_embeddings = self.encode_images(image) - self.text_embeddings = self.encode_text( - text, attention_mask=attention_mask - ) - normalize_image_features = ops.sqrt( - ops.sum(ops.power(self.image_embeddings, 2), keepdims=True) - ) - normalize_text_features = ops.sqrt( - ops.sum(ops.power(self.text_embeddings, 2), keepdims=True) - ) - self.image_embeddings = self.image_embeddings / normalize_image_features - self.text_embeddings = self.text_embeddings / normalize_text_features - logit_scale = ops.exp(self.logit_scale) - logits_per_image = ( - ops.matmul( - self.image_embeddings, - ops.transpose(self.text_embeddings), - ) - * logit_scale - ) - logits_per_text = ops.transpose(logits_per_image) - - return logits_per_image, logits_per_text + self.image_encoder = image_encoder + self.text_encoder = text_encoder + self.clip_head = clip_head @classproperty def presets(cls): diff --git a/keras_cv/models/feature_extractor/clip/clip_model_test.py b/keras_cv/models/feature_extractor/clip/clip_model_test.py index a5541303c9..f9126aa9f2 100644 --- a/keras_cv/models/feature_extractor/clip/clip_model_test.py +++ b/keras_cv/models/feature_extractor/clip/clip_model_test.py @@ -35,105 +35,89 @@ ) +@pytest.mark.skipif( + not keras_3(), + reason="Only works with Keras 3", +) class CLIPTest(TestCase): + @pytest.mark.large def test_clip_model_golden_values(self): model = CLIP.from_preset("clip-vit-base-patch32") processed_image = np.ones(shape=[1, 224, 224, 3]) processed_text = np.ones(shape=[3, 77]) attention_mask = np.ones(shape=[3, 77]) - image_logits, text_logits = model( + outputs = model( { - "image": processed_image, - "text": processed_text, - "attention_mask": attention_mask, + "images": processed_image, + "token_ids": processed_text, + "padding_mask": attention_mask, } ) - self.assertAllClose(image_logits, [[1.896712, 1.896712, 1.896712]]) + + # These values are NOT computing using HF as the reference model. + # Currently, the numerics of the CLIP model don't match the + # HF model exactly (for the same inputs). For the time being, + # these tests just confirm that unrelated changed don't affect + # the numerics. Once the fix for the numerics is in, we can remove + # this comment and the xfail below. + self.assertAllClose( + outputs["image_logits"], [[10.246354, 10.246353, 10.246354]] + ) self.assertAllClose( - text_logits, ops.transpose([[1.896712, 1.896712, 1.896712]]) + outputs["text_logits"], + ops.transpose([[10.246354, 10.246353, 10.246354]]), ) + # True reference values computed using HF: + # image_logits: [[17.8013, 17.8013, 17.8013]] + # text_logits: image_logits.T + + # xfail after assertion + pytest.xfail("KerasCV CLIP doesn't match the HF model.") + def test_clip_preprocessor(self): - processor = CLIPProcessor(224, VOCAB_PATH, MERGE_PATH) - processed_text, attention_mask = processor.process_texts( - ["mountains", "cat on tortoise"] - ) + processor = CLIPProcessor(VOCAB_PATH, MERGE_PATH) + tokens = processor(["mountains", "cat on tortoise"]) self.assertAllClose( - processed_text[:, :3], [[49406, 5873, 49407], [49406, 2368, 525]] + tokens["token_ids"][:, :3], + [[49406, 5873, 49407], [49406, 2368, 525]], ) self.assertAllClose( - attention_mask[0, :5], [True, True, True, False, False] + tokens["padding_mask"][0, :5], [True, True, True, False, False] ) def test_clip_preprocessor_tf_data(self): - processor = CLIPProcessor(224, VOCAB_PATH, MERGE_PATH) + processor = CLIPProcessor(VOCAB_PATH, MERGE_PATH) text_input = ["a bus", "a dog", "a cat"] dataset = tf_data.Dataset.from_tensor_slices(text_input) - dataset.map(processor.process_texts) + dataset.map(processor) @pytest.mark.large def test_presets(self): - # self.skipTest("TODO: Enable after Kaggle model is public") - model = CLIP.from_preset("clip-vit-base-patch16") - processed_image = np.ones(shape=[1, 224, 224, 3]) - processed_text = np.ones(shape=[3, 77]) - attention_mask = np.ones(shape=[3, 77]) - image_logits, text_logits = model( - { - "image": processed_image, - "text": processed_text, - "attention_mask": attention_mask, - } - ) - - @pytest.mark.large - def test_image_encoder_golden_values(self): model = CLIP.from_preset("clip-vit-base-patch32") processed_image = np.ones(shape=[1, 224, 224, 3]) processed_text = np.ones(shape=[3, 77]) attention_mask = np.ones(shape=[3, 77]) model( { - "image": processed_image, - "text": processed_text, - "attention_mask": attention_mask, + "images": processed_image, + "token_ids": processed_text, + "padding_mask": attention_mask, } ) - self.assertAllClose( - model.image_embeddings[:, :5], - [[0.023215, 0.026526, 0.008914, -0.091689, 0.021791]], - ) - - @pytest.mark.large - def test_text_encoder_golden_values(self): - model = CLIP() - processed_image = np.ones(shape=[1, 224, 224, 3]) - processed_text = np.ones(shape=[3, 77]) - attention_mask = np.ones(shape=[3, 77]) - model( - { - "image": processed_image, - "text": processed_text, - "attention_mask": attention_mask, - } - ) - self.assertAllClose( - model.text_embeddings[0, :3], - [0.007531, -0.038361, -0.035686], - ) @pytest.mark.large # Saving is slow, so mark these large. def test_saved_model(self): - model = CLIP() + model = CLIP.from_preset("clip-vit-base-patch32") processed_image = np.ones(shape=[1, 224, 224, 3]) processed_text = np.ones(shape=[3, 77]) attention_mask = np.ones(shape=[3, 77]) - model_output, _ = model( + outputs = model( { - "image": processed_image, - "text": processed_text, - "attention_mask": attention_mask, + "images": processed_image, + "token_ids": processed_text, + "padding_mask": attention_mask, } ) save_path = os.path.join(self.get_temp_dir(), "model.keras") @@ -146,11 +130,11 @@ def test_saved_model(self): # Check we got the real object back. self.assertIsInstance(restored_model, CLIP) # Check that output matches. - restored_output, _ = restored_model( + restored_outputs = restored_model( { - "image": processed_image, - "text": processed_text, - "attention_mask": attention_mask, + "images": processed_image, + "token_ids": processed_text, + "padding_mask": attention_mask, } ) - self.assertAllClose(model_output, restored_output) + self.assertAllClose(outputs, restored_outputs) diff --git a/keras_cv/models/feature_extractor/clip/clip_presets.py b/keras_cv/models/feature_extractor/clip/clip_presets.py index 656c9ad8ed..f148aa2a28 100644 --- a/keras_cv/models/feature_extractor/clip/clip_presets.py +++ b/keras_cv/models/feature_extractor/clip/clip_presets.py @@ -28,7 +28,7 @@ "official_name": "CLIP", "path": "clip", }, - "kaggle_handle": "kaggle://keras/clip/keras/clip-vit-base-patch16/4", + "kaggle_handle": "kaggle://keras/clip/keras/clip-vit-base-patch16/7", }, "clip-vit-base-patch32": { "metadata": { @@ -44,7 +44,7 @@ "official_name": "CLIP", "path": "clip", }, - "kaggle_handle": "kaggle://keras/clip/keras/clip-vit-base-patch32/4", + "kaggle_handle": "kaggle://keras/clip/keras/clip-vit-base-patch32/6", }, "clip-vit-large-patch14": { "metadata": { @@ -60,7 +60,7 @@ "official_name": "CLIP", "path": "clip", }, - "kaggle_handle": "kaggle://keras/clip/keras/clip-vit-large-patch14/4", + "kaggle_handle": "kaggle://keras/clip/keras/clip-vit-large-patch14/6", }, "clip-vit-large-patch14-336": { "metadata": { @@ -76,6 +76,6 @@ "official_name": "CLIP", "path": "clip", }, - "kaggle_handle": "kaggle://keras/clip/keras/clip-vit-large-patch14-336/4", # noqa: E501 + "kaggle_handle": "kaggle://keras/clip/keras/clip-vit-large-patch14-336/6", # noqa: E501 }, } diff --git a/keras_cv/models/feature_extractor/clip/clip_processor.py b/keras_cv/models/feature_extractor/clip/clip_processor.py index 16d8d24222..7714294769 100644 --- a/keras_cv/models/feature_extractor/clip/clip_processor.py +++ b/keras_cv/models/feature_extractor/clip/clip_processor.py @@ -12,9 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +import tensorflow as tf +import tree + from keras_cv.api_export import keras_cv_export +from keras_cv.backend import config from keras_cv.backend import keras -from keras_cv.backend import ops +from keras_cv.models.feature_extractor.clip.clip_processor_utils import ( + convert_inputs_to_list_of_tensor_segments, +) +from keras_cv.models.feature_extractor.clip.clip_processor_utils import ( + convert_to_backend_tensor_or_python_list, +) from keras_cv.models.feature_extractor.clip.clip_tokenizer import CLIPTokenizer try: @@ -25,10 +34,10 @@ @keras_cv_export("keras_cv.models.feature_extractor.CLIPProcessor") -class CLIPProcessor: +class CLIPProcessor(keras.layers.Layer): """ CLIPProcessor is a utility class that provides functionality for processing - images and texts in the context of the CLIP (Contrastive Language-Image + texts in the context of the CLIP (Contrastive Language-Image Pretraining) model. Args: @@ -39,31 +48,27 @@ class CLIPProcessor: should be the file path to merge rules. The merge rule file should have one merge rule per line. - Methods: - process_images(image_path: List[str]): Transforms an image located at - the specified path. - - process_texts(texts: Union[str, List[str]], context_length: int = 77): - Processes a single text or a list of texts, returning packed token - sequences. - """ - def __init__(self, input_resolution, vocabulary, merges, **kwargs): + def __init__(self, vocabulary, merges, **kwargs): + super().__init__(**kwargs) if keras_nlp is None: raise ValueError( "ClipTokenizer requires keras-nlp. Please install " "using pip `pip install -U keras-nlp && pip install -U keras`" ) - self.input_resolution = input_resolution self.vocabulary = vocabulary self.merges = merges - self.image_transform = self.transform_image self.tokenizer = CLIPTokenizer( vocabulary=self.vocabulary, merges=self.merges, - unsplittable_tokens=[""], ) + self._convert_input_args = False + self._allow_non_tensor_positional_args = True + + def build(self, input_shape): + # Defer packer creation to `build()` so that we can be sure tokenizer + # assets have loaded when restoring a saved model. self.packer = StartEndPacker( start_value=self.tokenizer.token_to_id("<|startoftext|>"), end_value=self.tokenizer.token_to_id("<|endoftext|>"), @@ -71,69 +76,58 @@ def __init__(self, input_resolution, vocabulary, merges, **kwargs): sequence_length=77, return_padding_mask=True, ) + self.built = True - def transform_image(self, image_path): - input_resolution = self.input_resolution - mean = ops.array([0.48145466, 0.4578275, 0.40821073]) - std = ops.array([0.26862954, 0.26130258, 0.27577711]) - - image = keras.utils.load_img(image_path) - image = keras.utils.img_to_array(image) - image = ( - ops.image.resize( - image, - (input_resolution, input_resolution), - interpolation="bicubic", - ) - / 255.0 - ) - central_fraction = input_resolution / image.shape[0] - width, height = image.shape[0], image.shape[1] - left = ops.cast((width - width * central_fraction) / 2, dtype="int32") - top = ops.cast((height - height * central_fraction) / 2, dtype="int32") - right = ops.cast((width + width * central_fraction) / 2, dtype="int32") - bottom = ops.cast( - (height + height * central_fraction) / 2, dtype="int32" - ) + def _process_texts(self, texts, context_length: int = 77): + # Ensure the layer is built + if not self.built: + self.build(None) - image = ops.slice( - image, [left, top, 0], [right - left, bottom - top, 3] - ) + texts = convert_inputs_to_list_of_tensor_segments(texts) - image = (image - mean) / std - return image + if len(texts) != 1: + raise ValueError( + "CLIP requires each input feature to contain only " + f"one segment, but received {len(texts)}." + ) - def process_images(self, images): - if isinstance(images, str): - images = [images] + token_ids, padding_mask = self.packer( + self.tokenizer(texts[0]), + sequence_length=context_length, + add_start_value=True, + add_end_value=True, + ) + return {"token_ids": token_ids, "padding_mask": padding_mask} - def process_image(image): - if isinstance(image, str): - return self.image_transform(image) + def call(self, texts, context_length: int = 77): + return self._process_texts(texts, context_length=context_length) - processed_images = list(map(process_image, images)) - processed_images = ops.stack(processed_images) - return processed_images + def get_build_config(self): + return None - def process_texts(self, texts, context_length: int = 77): - if isinstance(texts, str): - texts = [texts] + def __call__(self, *args, **kwargs): + # Always place on CPU for preprocessing, to avoid expensive back and + # forth copies to GPU before the trainable model. + with tf.device("cpu"): + outputs = super().__call__(*args, **kwargs) - def pack_tokens(text): - return self.packer( - self.tokenizer(text), - sequence_length=context_length, - add_start_value=True, - add_end_value=True, - ) + # Jax and Torch lack native string and ragged types. + # If we are running on those backends and not running with tf.data + # (we are outside a tf.function), we covert all ragged and string + # tensor to pythonic types. + is_tf_backend = config.backend() == "tensorflow" + is_in_tf_graph = not tf.executing_eagerly() + if not is_tf_backend and not is_in_tf_graph: + outputs = tree.map_structure( + convert_to_backend_tensor_or_python_list, outputs + ) - return pack_tokens(texts) + return outputs def get_config(self): config = super().get_config() config.update( { - "input_resolution": self.input_resolution, "vocabulary": self.vocabulary, "merges": self.merges, } diff --git a/keras_cv/models/feature_extractor/clip/clip_processor_utils.py b/keras_cv/models/feature_extractor/clip/clip_processor_utils.py new file mode 100644 index 0000000000..919cb316d0 --- /dev/null +++ b/keras_cv/models/feature_extractor/clip/clip_processor_utils.py @@ -0,0 +1,110 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import tensorflow as tf + +from keras_cv.backend import ops + + +def _decode_strings_to_utf8(inputs): + """Recursively decodes to list of strings with 'utf-8' encoding.""" + if isinstance(inputs, bytes): + # Handles the case when the input is a scalar string. + return inputs.decode("utf-8", errors="ignore") + else: + # Recursively iterate when input is a list. + return [_decode_strings_to_utf8(x) for x in inputs] + + +def tensor_to_list(inputs): + """Converts a tensor to nested lists. + + Args: + inputs: Input tensor, or dict/list/tuple of input tensors. + """ + if not isinstance(inputs, (tf.RaggedTensor, tf.Tensor)): + inputs = tf.convert_to_tensor(inputs) + if isinstance(inputs, tf.RaggedTensor): + list_outputs = inputs.to_list() + elif isinstance(inputs, tf.Tensor): + list_outputs = inputs.numpy() + if inputs.shape.rank != 0: + list_outputs = list_outputs.tolist() + if inputs.dtype == tf.string: + list_outputs = _decode_strings_to_utf8(list_outputs) + return list_outputs + + +def convert_to_backend_tensor_or_python_list(x): + """ + Convert a tensor to the backend friendly representation of the data. + + This wraps `ops.convert_to_tensor` to account for the fact that torch and + jax both lack native types for ragged and string data. + + If we encounter one of these types in torch or jax, we will instead covert + the tensor to simple pythonic types (lists of strings). + """ + if isinstance(x, tf.RaggedTensor) or getattr(x, "dtype", None) == tf.string: + return tensor_to_list(x) + return ops.convert_to_tensor(x) + + +def convert_inputs_to_list_of_tensor_segments(x): + """Converts user inputs to a list of a tensor segments. + + For models and layers which accept lists of string tensors to pack together, + this method converts user inputs to a uniform format in a way that can be + considered canonical for the library. + + We handle the following: + + - A single string will be converted to a tensor and wrapped in a list. + - A list of strings will be converted to a tensor and wrapped in a list. + - A single tensor will be wrapped in a list. + - A list of tensors will be passed through unaltered. + + All other inputs will result in an error. This effectively means that users + who would like to pack multiple segments together should convert those + segments to tensors before calling the layer. This removes any ambiguity + in the input for those cases. + """ + # Check the input type. + is_string = isinstance(x, (str, bytes)) + is_tensor = hasattr(x, "__array__") + is_string_list = ( + isinstance(x, (list, tuple)) and x and isinstance(x[0], (str, bytes)) + ) + is_tensor_list = ( + isinstance(x, (list, tuple)) and x and hasattr(x[0], "__array__") + ) + + if is_string or is_string_list: + # Automatically convert raw strings or string lists to tensors. + # Wrap this input as a single (possibly batched) segment. + x = [tf.convert_to_tensor(x)] + elif is_tensor: + # Automatically wrap a single tensor as a single segment. + x = [x] + elif is_tensor_list: + # Pass lists of tensors though unaltered. + x = x + else: + # Error for all other input. + raise ValueError( + f"Unsupported input for `x`. `x` should be a string, a list of " + "strings, or a list of tensors. If passing multiple segments " + "which should packed together, please convert your inputs to a " + f"list of tensors. Received `x={x}`" + ) + return x diff --git a/keras_cv/models/feature_extractor/clip/clip_text_model.py b/keras_cv/models/feature_extractor/clip/clip_text_model.py index 26ab92cb0a..1259761d58 100644 --- a/keras_cv/models/feature_extractor/clip/clip_text_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_text_model.py @@ -62,12 +62,15 @@ def __init__( ) def build(self, input_shape): - super().build(input_shape) self.token_embedding.build(input_shape) self.positional_embedding.build([1, self.context_length]) self.encoder.build(None) self.ln_final.build([None, None, self.transformer_width]) - self.text_projector.build([None, None, self.transformer_width]) + self.text_projector.build([None, self.transformer_width]) + self.built = True + + def compute_output_shape(self, input_shape): + return [input_shape[0], self.embed_dim] def call(self, inputs, attention_mask=None): token_embedding = self.token_embedding(inputs) @@ -76,7 +79,7 @@ def call(self, inputs, attention_mask=None): ) position_embedding = self.positional_embedding(position_ids) position_embedding = ops.tile( - position_embedding, repeats=(inputs.shape[0], 1, 1) + position_embedding, repeats=(ops.shape(inputs)[0], 1, 1) ) causal_attention_mask = ops.ones( (self.context_length, self.context_length) diff --git a/keras_cv/tools/checkpoint_conversion/clip_weights_conversion.ipynb b/keras_cv/tools/checkpoint_conversion/clip_weights_conversion.ipynb index ff3bb4c991..b8f556343e 100644 --- a/keras_cv/tools/checkpoint_conversion/clip_weights_conversion.ipynb +++ b/keras_cv/tools/checkpoint_conversion/clip_weights_conversion.ipynb @@ -3,150 +3,64 @@ { "cell_type": "markdown", "metadata": { - "id": "0DhV6hzOMY0W" + "id": "mdGT8Em4Mc4b" }, "source": [ - "# Setup" + "# Import" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "cRzYR-oFgxt1", - "outputId": "80b8db20-da09-43bd-9b70-fad93b1e1ca1" + "id": "0mtj1abS2cVf" }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", - " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", - " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m950.8/950.8 kB\u001b[0m \u001b[31m6.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h Building wheel for keras-cv (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m465.2/465.2 kB\u001b[0m \u001b[31m3.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.2/5.2 MB\u001b[0m \u001b[31m36.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.0/1.0 MB\u001b[0m \u001b[31m2.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", - "tensorflow 2.15.0 requires keras<2.16,>=2.15.0, but you have keras 3.0.2 which is incompatible.\u001b[0m\u001b[31m\n", - "\u001b[0m" - ] - } - ], + "outputs": [], "source": [ - "!pip install -q git+https://github.com/divyashreepathihalli/keras-cv.git@CLIP_refactor\n", - "!pip install -q keras-nlp\n", - "!pip install -q tf-keras\n", - "!pip install -q tensorflow-text\n", - "!pip install -q keras==3.0.2" + "import os\n", + "\n", + "os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "nuFgha2jTshi", - "outputId": "63d4160e-42b3-4f6b-e672-ba30c9402d25" + "id": "GDvJmQuug4-x" }, "outputs": [ { - "name": "stdout", + "name": "stderr", "output_type": "stream", "text": [ - "--2024-02-21 20:54:06-- https://i.imgur.com/8H7XCH0.jpg\n", - "Resolving i.imgur.com (i.imgur.com)... 146.75.76.193\n", - "Connecting to i.imgur.com (i.imgur.com)|146.75.76.193|:443... connected.\n", - "HTTP request sent, awaiting response... 200 OK\n", - "Length: 44544 (44K) [image/jpeg]\n", - "Saving to: ‘cat.jpg’\n", - "\n", - "\rcat.jpg 0%[ ] 0 --.-KB/s \rcat.jpg 100%[===================>] 43.50K --.-KB/s in 0.01s \n", - "\n", - "2024-02-21 20:54:06 (4.16 MB/s) - ‘cat.jpg’ saved [44544/44544]\n", - "\n", - "--2024-02-21 20:54:06-- http://images.cocodataset.org/val2017/000000039769.jpg\n", - "Resolving images.cocodataset.org (images.cocodataset.org)... 52.217.206.137, 16.182.42.89, 54.231.201.177, ...\n", - "Connecting to images.cocodataset.org (images.cocodataset.org)|52.217.206.137|:80... connected.\n", - "HTTP request sent, awaiting response... 200 OK\n", - "Length: 173131 (169K) [image/jpeg]\n", - "Saving to: ‘two_cats.jpg’\n", - "\n", - "two_cats.jpg 100%[===================>] 169.07K --.-KB/s in 0.09s \n", - "\n", - "2024-02-21 20:54:07 (1.77 MB/s) - ‘two_cats.jpg’ saved [173131/173131]\n", - "\n", - "--2024-02-21 20:54:07-- https://i.imgur.com/PpgZzP4.jpeg\n", - "Resolving i.imgur.com (i.imgur.com)... 146.75.76.193\n", - "Connecting to i.imgur.com (i.imgur.com)|146.75.76.193|:443... connected.\n", - "HTTP request sent, awaiting response... 200 OK\n", - "Length: 1610285 (1.5M) [image/jpeg]\n", - "Saving to: ‘mountain.jpg’\n", - "\n", - "mountain.jpg 100%[===================>] 1.54M --.-KB/s in 0.06s \n", - "\n", - "2024-02-21 20:54:07 (27.6 MB/s) - ‘mountain.jpg’ saved [1610285/1610285]\n", - "\n" + "2024-04-08 22:25:42.921695: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.\n", + "2024-04-08 22:25:42.928541: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.\n", + "2024-04-08 22:25:43.041846: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", + "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", + "2024-04-08 22:25:44.645915: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" ] } ], "source": [ - "!wget https://i.imgur.com/8H7XCH0.jpg -O cat.jpg\n", - "!wget http://images.cocodataset.org/val2017/000000039769.jpg -O two_cats.jpg\n", - "!wget https://i.imgur.com/PpgZzP4.jpeg -O mountain.jpg" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "mdGT8Em4Mc4b" - }, - "source": [ - "# Import" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "0mtj1abS2cVf" - }, - "outputs": [], - "source": [ - "import os\n", + "import json\n", + "from datetime import datetime\n", "\n", - "os.environ[\"KERAS_BACKEND\"] = \"torch\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "GDvJmQuug4-x" - }, - "outputs": [], - "source": [ - "from keras_cv.models.feature_extractor.clip import CLIPProcessor\n", "import keras\n", + "from keras import ops\n", + "\n", + "import keras_cv\n", + "from keras_cv.models.feature_extractor.clip import CLIPProcessor\n", "from keras_cv.models import CLIP" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": { "id": "X3kkmK6h_gFH" }, "outputs": [], "source": [ - "# @title Select which model weights you would like to convert\n", "MODEL_CONFIGS = {\n", " \"CLIP_B32\": {\n", " \"embed_dim\": 512,\n", @@ -203,7 +117,7 @@ " \"CLIP_L14\": \"openai/clip-vit-large-patch14\",\n", " \"CLIP_L14_336\": \"openai/clip-vit-large-patch14-336\",\n", "}\n", - "config_name = \"CLIP_L14_336\" # @param [\"CLIP_B16\", \"CLIP_B32\", \"CLIP_L14\", \"CLIP_L14_336\"]\n", + "config_name = \"CLIP_B16\"\n", "config_name_hf = model_map_hf[config_name]" ] }, @@ -218,7 +132,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": { "id": "urhuhwq0Dczo" }, @@ -250,7 +164,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -276,170 +190,51 @@ { "data": { "text/html": [ - "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓\n",
-       "┃ Layer (type)                        Output Shape                       Param # ┃\n",
-       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩\n",
-       "│ image_encoder (CLIPImageEncoder)   │ ?                             │ 0 (unbuilt) │\n",
-       "├────────────────────────────────────┼───────────────────────────────┼─────────────┤\n",
-       "│ text_encoder (CLIPTextEncoder)     │ ?                             │ 0 (unbuilt) │\n",
-       "└────────────────────────────────────┴───────────────────────────────┴─────────────┘\n",
-       "
\n" - ], - "text/plain": [ - "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩\n", - "│ image_encoder (\u001b[38;5;33mCLIPImageEncoder\u001b[0m) │ ? │ \u001b[38;5;34m0\u001b[0m (unbuilt) │\n", - "├────────────────────────────────────┼───────────────────────────────┼─────────────┤\n", - "│ text_encoder (\u001b[38;5;33mCLIPTextEncoder\u001b[0m) │ ? │ \u001b[38;5;34m0\u001b[0m (unbuilt) │\n", - "└────────────────────────────────────┴───────────────────────────────┴─────────────┘\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
 Total params: 1 (4.00 B)\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m1\u001b[0m (4.00 B)\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
 Trainable params: 1 (4.00 B)\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m1\u001b[0m (4.00 B)\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
 Non-trainable params: 0 (0.00 B)\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "model.summary()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "buXKlNfGTenW" - }, - "outputs": [], - "source": [ - "processor = CLIPProcessor(\n", - " MODEL_CONFIGS[config_name][\"image_resolution\"], \"vocab.json\", \"merges.txt\"\n", - ")\n", - "image = processor.process_images([\"two_cats.jpg\"])\n", - "text_input = [\"mountains\", \"cat on tortoise\", \"two cats\"]\n", - "text, attention_mask = processor.process_texts(text_input)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "BHSpMv0PT5SX" - }, - "outputs": [], - "source": [ - "image_logits, text_logits = model(image, text, attention_mask)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "JPn0gACJjKy5", - "outputId": "cbc7313a-4ddd-4021-9e84-fa668987849d" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[3.7318, 3.7792, 3.7633]], grad_fn=)" - ] - }, - "execution_count": 32, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "image_logits" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 193 - }, - "id": "GgNBvYCTtmA3", - "outputId": "a667a9e5-397e-4299-fdc1-8899621112ad" - }, - "outputs": [ - { - "data": { - "text/html": [ - "
Model: \"clip\"\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1mModel: \"clip\"\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓\n",
-       "┃ Layer (type)                        Output Shape                       Param # ┃\n",
-       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩\n",
-       "│ image_encoder (CLIPImageEncoder)   │ ?                             │ 304,293,888 │\n",
-       "├────────────────────────────────────┼───────────────────────────────┼─────────────┤\n",
-       "│ text_encoder (CLIPTextEncoder)     │ ?                             │ 123,650,304 │\n",
-       "└────────────────────────────────────┴───────────────────────────────┴─────────────┘\n",
+       "
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃ Layer (type)         Output Shape          Param #  Connected to      ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
+       "│ images (InputLayer) │ (None, 224, 224,  │          0 │ -                 │\n",
+       "│                     │ 3)                │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ token_ids           │ (None, 77)        │          0 │ -                 │\n",
+       "│ (InputLayer)        │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ padding_mask        │ (None, 77)        │          0 │ -                 │\n",
+       "│ (InputLayer)        │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ image_encoder       │ (None, 512)       │ 86,192,640 │ images[0][0]      │\n",
+       "│ (CLIPImageEncoder)  │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ text_encoder        │ (None, 512)       │ 63,428,096 │ token_ids[0][0],  │\n",
+       "│ (CLIPTextEncoder)   │                   │            │ padding_mask[0][ │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ clip_head           │ [(None, None),    │          1 │ image_encoder[0]… │\n",
+       "│ (CLIPHead)          │ (None, None)]     │            │ text_encoder[0][ │\n",
+       "└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n",
        "
\n" ], "text/plain": [ - "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩\n", - "│ image_encoder (\u001b[38;5;33mCLIPImageEncoder\u001b[0m) │ ? │ \u001b[38;5;34m304,293,888\u001b[0m │\n", - "├────────────────────────────────────┼───────────────────────────────┼─────────────┤\n", - "│ text_encoder (\u001b[38;5;33mCLIPTextEncoder\u001b[0m) │ ? │ \u001b[38;5;34m123,650,304\u001b[0m │\n", - "└────────────────────────────────────┴───────────────────────────────┴─────────────┘\n" + "┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n", + "│ images (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m224\u001b[0m, \u001b[38;5;34m224\u001b[0m, │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "│ │ \u001b[38;5;34m3\u001b[0m) │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ token_ids │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m77\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ padding_mask │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m77\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ image_encoder │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m86,192,640\u001b[0m │ images[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mCLIPImageEncoder\u001b[0m) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ text_encoder │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m63,428,096\u001b[0m │ token_ids[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mCLIPTextEncoder\u001b[0m) │ │ │ padding_mask[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ clip_head │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m), │ \u001b[38;5;34m1\u001b[0m │ image_encoder[\u001b[38;5;34m0\u001b[0m]… │\n", + "│ (\u001b[38;5;33mCLIPHead\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m)] │ │ text_encoder[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n", + "└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n" ] }, "metadata": {}, @@ -448,11 +243,11 @@ { "data": { "text/html": [ - "
 Total params: 427,944,193 (1.59 GB)\n",
+       "
 Total params: 149,620,737 (570.76 MB)\n",
        "
\n" ], "text/plain": [ - "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m427,944,193\u001b[0m (1.59 GB)\n" + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m149,620,737\u001b[0m (570.76 MB)\n" ] }, "metadata": {}, @@ -461,11 +256,11 @@ { "data": { "text/html": [ - "
 Trainable params: 427,944,193 (1.59 GB)\n",
+       "
 Trainable params: 149,620,737 (570.76 MB)\n",
        "
\n" ], "text/plain": [ - "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m427,944,193\u001b[0m (1.59 GB)\n" + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m149,620,737\u001b[0m (570.76 MB)\n" ] }, "metadata": {}, @@ -500,7 +295,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": { "id": "3W2prd6C0pxe" }, @@ -508,6 +303,7 @@ "source": [ "from PIL import Image\n", "import requests\n", + "import torch\n", "\n", "from transformers import CLIPProcessor as CP\n", "from transformers import CLIPModel as CM" @@ -515,7 +311,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -614,192 +410,12 @@ "id": "EntuvOq1MhwU", "outputId": "cbd7cd77-6d8f-4a76-dae0-24530c12eeb6" }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:88: UserWarning: \n", - "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", - "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", - "You will be able to reuse this secret in all of your notebooks.\n", - "Please note that authentication is recommended but still optional to access public models or datasets.\n", - " warnings.warn(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "46636db47838400cb7407fc2ab0720eb", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "config.json: 0%| | 0.00/4.76k [00:00)" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import torch\n", - "import numpy as np\n", - "\n", - "photo = {\n", - " \"cat\": \"https://i.imgur.com/8H7XCH0.jpg\",\n", - " \"two_cats\": \"http://images.cocodataset.org/val2017/000000039769.jpg\",\n", - " \"mountain\": \"https://i.imgur.com/PpgZzP4.jpeg\",\n", - "}\n", - "url = photo[\"cat\"]\n", - "image_hf = Image.open(requests.get(url, stream=True).raw)\n", - "text_inputs = [\"mountains\", \"cat on tortoise\", \"two dogs\"]\n", - "inputs = processor_hf(\n", - " text=text_inputs, images=image_hf, return_tensors=\"pt\", padding=True\n", - ")\n", - "outputs = model_hf(**inputs)\n", - "logits_per_image = (\n", - " outputs.logits_per_image\n", - ") # this is the image-text similarity score\n", - "probs = logits_per_image.softmax(\n", - " dim=1\n", - ") # we can take the softmax to get the label probabilitiesprobs\n", - "logits_per_image" - ] - }, { "cell_type": "markdown", "metadata": { @@ -811,7 +427,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": { "id": "wPa0cVnY3cBC" }, @@ -827,18 +443,20 @@ "id": "TUCpKltRG4Gd" }, "source": [ - "##vision encoder" + "## Vision Encoder" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": { "id": "tn_U02N7U2VN" }, "outputs": [], "source": [ - "model.logit_scale.assign(hf_wts.pop(\"logit_scale\").numpy())\n", + "model.get_layer(\"clip_head\").logit_scale.assign(\n", + " hf_wts.pop(\"logit_scale\").numpy()\n", + ")\n", "model.get_layer(\"image_encoder\").get_layer(\n", " \"clip_patch_embedding\"\n", ").class_embedding.assign(\n", @@ -875,7 +493,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": { "id": "YRXC2HZC3FjG" }, @@ -988,7 +606,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": { "id": "_1AD7TcbdWEC" }, @@ -1013,7 +631,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": { "id": "IQFquy9R75G8" }, @@ -1104,7 +722,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -1119,7 +737,7 @@ "odict_keys([])" ] }, - "execution_count": 17, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -1135,19 +753,178 @@ "id": "wlfDdO-mid62" }, "source": [ - "# save weights" + "# Save Weights" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": { "id": "QscCUUZFiqBV" }, "outputs": [], "source": [ - "model.save_weights(\"model.weights.h5\")" + "os.makedirs(config_name, exist_ok=True)\n", + "model.save_weights(os.path.join(config_name, \"model.weights.h5\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "config = {\n", + " \"module\": \"keras_cv.models.feature_extractor.clip.clip_model\",\n", + " \"class_name\": \"CLIP\",\n", + " \"config\": model.get_config(),\n", + " \"registered_name\": \"keras_cv>CLIP\",\n", + " \"weights\": \"model.weights.h5\",\n", + "}\n", + "\n", + "with open(os.path.join(config_name, \"config.json\"), \"w\") as config_file:\n", + " json.dump(config, config_file)\n", + "\n", + "metadata = {\n", + " \"keras_version\": keras.__version__,\n", + " \"keras_cv_version\": keras_cv.__version__,\n", + " \"parameter_count\": model.count_params(),\n", + " \"date_saved\": datetime.utcnow().strftime(\"%Y-%m-%d@%H:%M:%S\"),\n", + "}\n", + "\n", + "with open(os.path.join(config_name, \"metadata.json\"), \"w\") as metadata_file:\n", + " json.dump(metadata, metadata_file)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Verify numerics" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "# url = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\n", + "# image = Image.open(requests.get(url, stream=True).raw)\n", + "\n", + "# inputs = processor_hf(text=[\"a photo of a cat\", \"a photo of a dog\"], images=image, return_tensors=\"pt\", padding=True)\n", + "\n", + "# outputs = model_hf(**inputs)\n", + "# logits_per_image = outputs.logits_per_image # this is the image-text similarity score\n", + "# probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "# import matplotlib.pyplot as plt\n", + "\n", + "# plt.imshow(image)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "# probs" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "# VOCAB_PATH = keras.utils.get_file(None, \"https://storage.googleapis.com/keras-cv/models/clip/vocab.json\")\n", + "# MERGE_PATH = keras.utils.get_file(None, \"https://storage.googleapis.com/keras-cv/models/clip/merges.txt\")" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "# processor = CLIPProcessor(224, VOCAB_PATH, MERGE_PATH)\n", + "# text_processed = processor([\"a photo of a cat\", \"a photo of a dog\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "# image_processed = ops.convert_to_tensor(inputs['pixel_values'].detach().cpu().permute(0, 2, 3, 1).numpy())" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "# outputs = model({\n", + "# \"images\": image_processed,\n", + "# **text_processed\n", + "# })" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "# ops.softmax(outputs[\"image_logits\"], axis=1) # we can take the softmax to get the label probabilities" ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "# model.load_weights(\"model.weights.h5\")" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "# outputs = model({\n", + "# \"images\": image_processed,\n", + "# **text_processed\n", + "# })" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "# ops.softmax(outputs[\"image_logits\"], axis=1) # we can take the softmax to get the label probabilities" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -1158,11 +935,21 @@ "provenance": [] }, "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", + "language": "python", "name": "python3" }, "language_info": { - "name": "python" + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.8" }, "widgets": { "application/vnd.jupyter.widget-state+json": { @@ -3906,5 +3693,5 @@ } }, "nbformat": 4, - "nbformat_minor": 0 + "nbformat_minor": 4 }