# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Literal, Tuple, Union

import torch
from torch import Tensor

from torchmetrics.functional.multimodal.clip_score import _get_clip_model_and_processor
from torchmetrics.utilities.checks import _SKIP_SLOW_DOCTEST, _try_proceed_with_timeout
from torchmetrics.utilities.imports import _PIQ_GREATER_EQUAL_0_8, _TRANSFORMERS_GREATER_EQUAL_4_10

if _TRANSFORMERS_GREATER_EQUAL_4_10:
    from transformers import CLIPModel as _CLIPModel
    from transformers import CLIPProcessor as _CLIPProcessor

    def _download_clip() -> None:
        _CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
        _CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")

    if _SKIP_SLOW_DOCTEST and not _try_proceed_with_timeout(_download_clip):
        __doctest_skip__ = ["clip_image_quality_assessment"]

else:
    __doctest_skip__ = ["clip_image_quality_assessment"]
    _CLIPModel = None
    _CLIPProcessor = None

if not _PIQ_GREATER_EQUAL_0_8:
    __doctest_skip__ = ["clip_image_quality_assessment"]
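
# Built-in prompt pairs: each key is a user-facing prompt name and maps to a
# (positive, negative) caption pair used to build the anchor text embeddings.
# The key spellings (including "colorfullness" and "beutiful") are the exact
# strings accepted by the `prompts` argument, so they are kept verbatim here.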
_PROMPTS: Dict[str, Tuple[str, str]] = {
    "quality": ("Good photo.", "Bad photo."),
    "brightness": ("Bright photo.", "Dark photo."),
    "noisiness": ("Clean photo.", "Noisy photo."),
    "colorfullness": ("Colorful photo.", "Dull photo."),
    "sharpness": ("Sharp photo.", "Blurry photo."),
    "contrast": ("High contrast photo.", "Low contrast photo."),
    "complexity": ("Complex photo.", "Simple photo."),
    "natural": ("Natural photo.", "Synthetic photo."),
    "happy": ("Happy photo.", "Sad photo."),
    "scary": ("Scary photo.", "Peaceful photo."),
    "new": ("New photo.", "Old photo."),
    "warm": ("Warm photo.", "Cold photo."),
    "real": ("Real photo.", "Abstract photo."),
    "beutiful": ("Beautiful photo.", "Ugly photo."),
    "lonely": ("Lonely photo.", "Sociable photo."),
    "relaxing": ("Relaxing photo.", "Stressful photo."),
}


def _get_clip_iqa_model_and_processor(
    model_name_or_path: Literal[
        "clip_iqa",
        "openai/clip-vit-base-patch16",
        "openai/clip-vit-base-patch32",
        "openai/clip-vit-large-patch14-336",
        "openai/clip-vit-large-patch14",
    ]
) -> Tuple[_CLIPModel, _CLIPProcessor]:
    """Extract the CLIP model and processor from the model name or path."""
    if model_name_or_path == "clip_iqa":
        if not _PIQ_GREATER_EQUAL_0_8:
            raise ValueError(
                "For metric `clip_iqa` to work with argument `model_name_or_path` set to default value `'clip_iqa'`,"
                " package `piq` version v0.8.0 or later must be installed. Either install with `pip install piq` or"
                " `pip install torchmetrics[multimodal]`."
            )

        import piq

        model = piq.clip_iqa.clip.load().eval()
        # any model checkpoint can be used here because the tokenizer is the same for all
        processor = _CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
        return model, processor
    return _get_clip_model_and_processor(model_name_or_path)
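
# Illustrative usage sketch (assumes the optional dependencies are installed; not executed here):
#   model, processor = _get_clip_iqa_model_and_processor("clip_iqa")  # piq backbone, default
#   model, processor = _get_clip_iqa_model_and_processor("openai/clip-vit-base-patch16")  # Hugging Face checkpoint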


def _clip_iqa_format_prompts(prompts: Tuple[Union[str, Tuple[str, str]]] = ("quality",)) -> Tuple[List[str], List[str]]:
    """Converts the provided keywords into a list of prompts for the model to calculate the anchor vectors.

    Args:
        prompts: A string, tuple of strings or nested tuple of strings. If a single string is provided, it must be one
            of the available prompts (see above). Else the input is expected to be a tuple, where each element can
            be one of two things: either a string or a tuple of strings. If a string is provided, it must be one of
            the available prompts (see above). If a tuple is provided, it must be of length 2 and the first string
            must be a positive prompt and the second string must be a negative prompt.

    Returns:
        Tuple containing a list of prompts and a list of the names of the prompts. The first list is double the length
        of the second list.

    Examples::

        >>> # single prompt
        >>> _clip_iqa_format_prompts(("quality",))
        (['Good photo.', 'Bad photo.'], ['quality'])
        >>> # multiple prompts
        >>> _clip_iqa_format_prompts(("quality", "brightness"))
        (['Good photo.', 'Bad photo.', 'Bright photo.', 'Dark photo.'], ['quality', 'brightness'])
        >>> # Custom prompts
        >>> _clip_iqa_format_prompts(("quality", ("Super good photo.", "Super bad photo.")))
        (['Good photo.', 'Bad photo.', 'Super good photo.', 'Super bad photo.'], ['quality', 'user_defined_0'])

    """
    if not isinstance(prompts, tuple):
        raise ValueError("Argument `prompts` must be a tuple containing strings or tuples of strings")

    prompts_names: List[str] = []
    prompts_list: List[str] = []
    count = 0
    for p in prompts:
        if not isinstance(p, (str, tuple)):
            raise ValueError("Argument `prompts` must be a tuple containing strings or tuples of strings")
        if isinstance(p, str):
            if p not in _PROMPTS:
                raise ValueError(
                    f"All elements of `prompts` must be one of {_PROMPTS.keys()} if not custom tuple prompts, got {p}."
                )
            prompts_names.append(p)
            prompts_list.extend(_PROMPTS[p])
        if isinstance(p, tuple) and len(p) != 2:
            raise ValueError("If a tuple is provided in argument `prompts`, it must be of length 2")
        if isinstance(p, tuple):
            prompts_names.append(f"user_defined_{count}")
            prompts_list.extend(p)
            count += 1

    return prompts_list, prompts_names


def _clip_iqa_get_anchor_vectors(
    model_name_or_path: str,
    model: _CLIPModel,
    processor: _CLIPProcessor,
    prompts_list: List[str],
    device: Union[str, torch.device],
) -> Tensor:
    """Calculates the anchor vectors for the CLIP IQA metric.

    Args:
        model_name_or_path: string indicating the version of the CLIP model to use.
        model: The CLIP model
        processor: The CLIP processor
        prompts_list: A list of prompts
        device: The device to use for the calculation

    """
    if model_name_or_path == "clip_iqa":
        text_processed = processor(text=prompts_list)
        anchors_text = torch.zeros(
            len(prompts_list), processor.tokenizer.model_max_length, dtype=torch.long, device=device
        )
        for i, tp in enumerate(text_processed["input_ids"]):
            anchors_text[i, : len(tp)] = torch.tensor(tp, dtype=torch.long, device=device)

        anchors = model.encode_text(anchors_text).float()
    else:
        text_processed = processor(text=prompts_list, return_tensors="pt", padding=True)
        anchors = model.get_text_features(
            text_processed["input_ids"].to(device), text_processed["attention_mask"].to(device)
        )
    return anchors / anchors.norm(p=2, dim=-1, keepdim=True)
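
# The anchors are L2-normalised, so the matrix product in `_clip_iqa_compute` is a cosine
# similarity. For P prompt pairs the anchor tensor has shape (2 * P, embed_dim), with the
# positive and negative prompt embeddings interleaved in the order produced by
# `_clip_iqa_format_prompts`.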


def _clip_iqa_update(
    model_name_or_path: str,
    images: Tensor,
    model: _CLIPModel,
    processor: _CLIPProcessor,
    data_range: Union[int, float],
    device: Union[str, torch.device],
) -> Tensor:
    """Update function for CLIP IQA."""
    images = images / float(data_range)
    if model_name_or_path == "clip_iqa":
        # default mean and std from clip paper, see:
        # https://github.com/huggingface/transformers/blob/main/src/transformers/utils/constants.py
        default_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=device).view(1, 3, 1, 1)
        default_std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=device).view(1, 3, 1, 1)
        images = (images - default_mean) / default_std
        img_features = model.encode_image(images.float(), pos_embedding=False).float()
    else:
        processed_input = processor(images=[i.cpu() for i in images], return_tensors="pt", padding=True)
        img_features = model.get_image_features(processed_input["pixel_values"].to(device))
    return img_features / img_features.norm(p=2, dim=-1, keepdim=True)
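
# The returned image features have shape (N, embed_dim) and unit norm. Note the two code
# paths: the `piq` backbone expects tensors normalised with the CLIP mean/std above, whereas
# the Hugging Face path delegates resizing and normalisation to the processor.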


def _clip_iqa_compute(
    img_features: Tensor,
    anchors: Tensor,
    prompts_names: List[str],
    format_as_dict: bool = True,
) -> Union[Tensor, Dict[str, Tensor]]:
    """Final computation of CLIP IQA."""
    logits_per_image = 100 * img_features @ anchors.t()
    probs = logits_per_image.reshape(logits_per_image.shape[0], -1, 2).softmax(-1)[:, :, 0]
    if len(prompts_names) == 1:
        return probs.squeeze()
    if format_as_dict:
        return {p: probs[:, i] for i, p in enumerate(prompts_names)}
    return probs
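
# Shape walk-through (illustrative): with N = 2 images and the prompt pairs
# ("quality", "brightness"), `img_features` is (2, embed_dim) and `anchors` is (4, embed_dim),
# so `logits_per_image` is (2, 4). Reshaping to (2, 2, 2) groups each (positive, negative)
# pair, the softmax over the last dim turns the scaled cosine similarities into a probability
# per pair, and `[:, :, 0]` keeps the probability assigned to the positive prompt.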


def clip_image_quality_assessment(
    images: Tensor,
    model_name_or_path: Literal[
        "clip_iqa",
        "openai/clip-vit-base-patch16",
        "openai/clip-vit-base-patch32",
        "openai/clip-vit-large-patch14-336",
        "openai/clip-vit-large-patch14",
    ] = "clip_iqa",
    data_range: Union[int, float] = 1.0,
    prompts: Tuple[Union[str, Tuple[str, str]]] = ("quality",),
) -> Union[Tensor, Dict[str, Tensor]]:
    """Calculates `CLIP-IQA`_, which can be used to measure the visual content of images.

    The metric is based on the `CLIP`_ model, which is a neural network trained on a variety of (image, text) pairs to
    be able to generate a vector representation of the image and the text that is similar if the image and text are
    semantically similar.

    The metric works by calculating the cosine similarity between user provided images and pre-defined prompts. The
    prompts always come in pairs of "positive" and "negative" such as "Good photo." and "Bad photo.". By calculating
    the similarity between image embeddings and both the "positive" and "negative" prompt, the metric can determine
    which prompt the image is more similar to. The metric then returns the probability that the image is more similar
    to the first prompt than the second prompt.

    Built-in prompts are:
        * quality: "Good photo." vs "Bad photo."
        * brightness: "Bright photo." vs "Dark photo."
        * noisiness: "Clean photo." vs "Noisy photo."
        * colorfullness: "Colorful photo." vs "Dull photo."
        * sharpness: "Sharp photo." vs "Blurry photo."
        * contrast: "High contrast photo." vs "Low contrast photo."
        * complexity: "Complex photo." vs "Simple photo."
        * natural: "Natural photo." vs "Synthetic photo."
        * happy: "Happy photo." vs "Sad photo."
        * scary: "Scary photo." vs "Peaceful photo."
        * new: "New photo." vs "Old photo."
        * warm: "Warm photo." vs "Cold photo."
        * real: "Real photo." vs "Abstract photo."
        * beutiful: "Beautiful photo." vs "Ugly photo."
        * lonely: "Lonely photo." vs "Sociable photo."
        * relaxing: "Relaxing photo." vs "Stressful photo."

    Args:
        images: Either a single ``[N, C, H, W]`` tensor or a list of ``[C, H, W]`` tensors
        model_name_or_path: string indicating the version of the CLIP model to use. By default this argument is set to
            ``clip_iqa`` which corresponds to the model used in the original paper. Other available models are
            `"openai/clip-vit-base-patch16"`, `"openai/clip-vit-base-patch32"`, `"openai/clip-vit-large-patch14-336"`
            and `"openai/clip-vit-large-patch14"`
        data_range: The maximum value of the input tensor. For example, if the input images are in range [0, 255],
            data_range should be 255. The images are normalized by this value.
        prompts: A string, tuple of strings or nested tuple of strings. If a single string is provided, it must be one
            of the available prompts (see above). Else the input is expected to be a tuple, where each element can
            be one of two things: either a string or a tuple of strings. If a string is provided, it must be one of
            the available prompts (see above). If a tuple is provided, it must be of length 2 and the first string
            must be a positive prompt and the second string must be a negative prompt.

    .. note:: If using the default `clip_iqa` model, the package `piq` must be installed. Either install with
        `pip install piq` or `pip install torchmetrics[multimodal]`.

    Returns:
        A tensor of shape ``(N,)`` if a single prompt is provided. If a list of prompts is provided, a dictionary
        with the prompts as keys and tensors of shape ``(N,)`` as values.

    Raises:
        ModuleNotFoundError:
            If transformers package is not installed or version is lower than 4.10.0
        ValueError:
            If not all images have format [C, H, W]
        ValueError:
            If prompts is a tuple and it is not of length 2
        ValueError:
            If prompts is a string and it is not one of the available prompts
        ValueError:
            If prompts is a list of strings and not all strings are one of the available prompts

    Example::
        Single prompt:

        >>> from torchmetrics.functional.multimodal import clip_image_quality_assessment
        >>> import torch
        >>> _ = torch.manual_seed(42)
        >>> imgs = torch.randint(255, (2, 3, 224, 224)).float()
        >>> clip_image_quality_assessment(imgs, prompts=("quality",))
        tensor([0.8894, 0.8902])

    Example::
        Multiple prompts:

        >>> from torchmetrics.functional.multimodal import clip_image_quality_assessment
        >>> import torch
        >>> _ = torch.manual_seed(42)
        >>> imgs = torch.randint(255, (2, 3, 224, 224)).float()
        >>> clip_image_quality_assessment(imgs, prompts=("quality", "brightness"))
        {'quality': tensor([0.8894, 0.8902]), 'brightness': tensor([0.5507, 0.5208])}

    Example::
        Custom prompts. Must always be a tuple of length 2, with a positive and negative prompt.

        >>> from torchmetrics.functional.multimodal import clip_image_quality_assessment
        >>> import torch
        >>> _ = torch.manual_seed(42)
        >>> imgs = torch.randint(255, (2, 3, 224, 224)).float()
        >>> clip_image_quality_assessment(imgs, prompts=(("Super good photo.", "Super bad photo."), "brightness"))
        {'user_defined_0': tensor([0.9652, 0.9629]), 'brightness': tensor([0.5507, 0.5208])}

    """
    prompts_list, prompts_names = _clip_iqa_format_prompts(prompts)

    model, processor = _get_clip_iqa_model_and_processor(model_name_or_path)
    device = images.device
    model = model.to(device)

    with torch.inference_mode():
        anchors = _clip_iqa_get_anchor_vectors(model_name_or_path, model, processor, prompts_list, device)
        img_features = _clip_iqa_update(model_name_or_path, images, model, processor, data_range, device)
        return _clip_iqa_compute(img_features, anchors, prompts_names)
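

if __name__ == "__main__":
    # Minimal usage sketch, assuming `piq>=0.8` and `transformers>=4.10` are installed and the
    # default CLIP weights can be downloaded. Guarded so it only runs when executed directly.
    _imgs = torch.randint(255, (2, 3, 224, 224)).float()
    # single prompt -> tensor of shape (2,)
    print(clip_image_quality_assessment(_imgs, data_range=255, prompts=("quality",)))
    # multiple prompts -> dict mapping each prompt name to a tensor of shape (2,)
    print(clip_image_quality_assessment(_imgs, data_range=255, prompts=("quality", "brightness")))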