
Commit 29f3289

Authored by SkafteNicki, stancld, and pre-commit-ci[bot]
New metric: CLIP IQA (#1931)
Co-authored-by: Daniel Stancl <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 9fdd57c commit 29f3289

File tree

15 files changed: +888 −19 lines

CHANGELOG.md (3 additions, 0 deletions)

@@ -38,6 +38,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added warning to `ClipScore` if long captions are detected and truncate ([#2001](https://github.com/Lightning-AI/torchmetrics/pull/2001))
 
 
+- Added `CLIPImageQualityAssessment` to multimodal package ([#1931](https://github.com/Lightning-AI/torchmetrics/pull/1931))
+
+
 ### Changed
 
 -

docs/source/links.rst (2 additions, 0 deletions)

@@ -144,6 +144,8 @@
 .. _Seamless Scene Segmentation paper: https://arxiv.org/abs/1905.01220
 .. _Fleiss kappa: https://en.wikipedia.org/wiki/Fleiss%27_kappa
 .. _VIF: https://ieeexplore.ieee.org/abstract/document/1576816
+.. _CLIP-IQA: https://arxiv.org/abs/2207.12396
+.. _CLIP: https://arxiv.org/abs/2103.00020
 .. _PPL : https://arxiv.org/pdf/1812.04948
 .. _CIOU: https://arxiv.org/abs/2005.03572
 .. _DIOU: https://arxiv.org/abs/1911.08287v1
New documentation page for CLIP IQA (24 additions, 0 deletions)

.. customcarditem::
   :header: CLIP IQA
   :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg
   :tags: Image

.. include:: ../links.rst

########################################
CLIP Image Quality Assessment (CLIP-IQA)
########################################

Module Interface
________________

.. autoclass:: torchmetrics.multimodal.CLIPImageQualityAssessment
    :noindex:
    :exclude-members: update, compute


Functional Interface
____________________

.. autofunction:: torchmetrics.functional.multimodal.clip_image_quality_assessment
    :noindex:
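For orientation, a minimal usage sketch of the module interface documented above. The constructor arguments shown here are assumed to mirror the functional `clip_image_quality_assessment` added in this commit, and the default `clip_iqa` checkpoint additionally needs `piq` installed; treat this as a sketch rather than the definitive API.

```python
import torch
from torchmetrics.multimodal import CLIPImageQualityAssessment

# Minimal sketch, assuming the module metric mirrors the functional arguments.
metric = CLIPImageQualityAssessment(prompts=("quality",))
imgs = torch.rand(2, 3, 224, 224)   # values in [0, 1], matching the default data_range=1.0
scores = metric(imgs)               # standard torchmetrics forward: update + compute
```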

requirements/multimodal.txt (1 addition, 0 deletions)

@@ -2,3 +2,4 @@
 # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
 
 transformers >=4.10.0, <4.30.3
+piq <=0.8.0

src/torchmetrics/functional/multimodal/__init__.py (2 additions, 1 deletion)

@@ -14,6 +14,7 @@
 from torchmetrics.utilities.imports import _TRANSFORMERS_GREATER_EQUAL_4_10
 
 if _TRANSFORMERS_GREATER_EQUAL_4_10:
+    from torchmetrics.functional.multimodal.clip_iqa import clip_image_quality_assessment
     from torchmetrics.functional.multimodal.clip_score import clip_score
 
-__all__ = ["clip_score"]
+__all__ = ["clip_score", "clip_image_quality_assessment"]
src/torchmetrics/functional/multimodal/clip_iqa.py (new file: 330 additions, 0 deletions)

# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Literal, Tuple, Union

import torch
from torch import Tensor

from torchmetrics.functional.multimodal.clip_score import _get_clip_model_and_processor
from torchmetrics.utilities.checks import _SKIP_SLOW_DOCTEST, _try_proceed_with_timeout
from torchmetrics.utilities.imports import _PIQ_GREATER_EQUAL_0_8, _TRANSFORMERS_GREATER_EQUAL_4_10

if _TRANSFORMERS_GREATER_EQUAL_4_10:
    from transformers import CLIPModel as _CLIPModel
    from transformers import CLIPProcessor as _CLIPProcessor

    def _download_clip() -> None:
        _CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
        _CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")

    if _SKIP_SLOW_DOCTEST and not _try_proceed_with_timeout(_download_clip):
        __doctest_skip__ = ["clip_score"]

else:
    __doctest_skip__ = ["clip_image_quality_assessment"]
    _CLIPModel = None
    _CLIPProcessor = None

if not _PIQ_GREATER_EQUAL_0_8:
    __doctest_skip__ = ["clip_image_quality_assessment"]

_PROMPTS: Dict[str, Tuple[str, str]] = {
    "quality": ("Good photo.", "Bad photo."),
    "brightness": ("Bright photo.", "Dark photo."),
    "noisiness": ("Clean photo.", "Noisy photo."),
    "colorfullness": ("Colorful photo.", "Dull photo."),
    "sharpness": ("Sharp photo.", "Blurry photo."),
    "contrast": ("High contrast photo.", "Low contrast photo."),
    "complexity": ("Complex photo.", "Simple photo."),
    "natural": ("Natural photo.", "Synthetic photo."),
    "happy": ("Happy photo.", "Sad photo."),
    "scary": ("Scary photo.", "Peaceful photo."),
    "new": ("New photo.", "Old photo."),
    "warm": ("Warm photo.", "Cold photo."),
    "real": ("Real photo.", "Abstract photo."),
    "beutiful": ("Beautiful photo.", "Ugly photo."),
    "lonely": ("Lonely photo.", "Sociable photo."),
    "relaxing": ("Relaxing photo.", "Stressful photo."),
}


def _get_clip_iqa_model_and_processor(
    model_name_or_path: Literal[
        "clip_iqa",
        "openai/clip-vit-base-patch16",
        "openai/clip-vit-base-patch32",
        "openai/clip-vit-large-patch14-336",
        "openai/clip-vit-large-patch14",
    ]
) -> Tuple[_CLIPModel, _CLIPProcessor]:
    """Extract the CLIP model and processor from the model name or path."""
    if model_name_or_path == "clip_iqa":
        if not _PIQ_GREATER_EQUAL_0_8:
            raise ValueError(
                "For metric `clip_iqa` to work with argument `model_name_or_path` set to default value `'clip_iqa'`"
                ", package `piq` version v0.8.0 or later must be installed. Either install with `pip install piq` or"
                " `pip install torchmetrics[multimodal]`"
            )

        import piq

        model = piq.clip_iqa.clip.load().eval()
        # any model checkpoint can be used here because the tokenizer is the same for all
        processor = _CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
        return model, processor
    return _get_clip_model_and_processor(model_name_or_path)


def _clip_iqa_format_prompts(prompts: Tuple[Union[str, Tuple[str, str]]] = ("quality",)) -> Tuple[List[str], List[str]]:
    """Convert the provided keywords into a list of prompts for the model to calculate the anchor vectors.

    Args:
        prompts: A string, tuple of strings or nested tuple of strings. If a single string is provided, it must be
            one of the available prompts (see above). Else the input is expected to be a tuple, where each element
            can be one of two things: either a string or a tuple of strings. If a string is provided, it must be one
            of the available prompts (see above). If a tuple is provided, it must be of length 2 and the first string
            must be a positive prompt and the second string must be a negative prompt.

    Returns:
        Tuple containing a list of prompts and a list of the names of the prompts. The first list is double the
        length of the second list.

    Examples::

        >>> # single prompt
        >>> _clip_iqa_format_prompts(("quality",))
        (['Good photo.', 'Bad photo.'], ['quality'])
        >>> # multiple prompts
        >>> _clip_iqa_format_prompts(("quality", "brightness"))
        (['Good photo.', 'Bad photo.', 'Bright photo.', 'Dark photo.'], ['quality', 'brightness'])
        >>> # Custom prompts
        >>> _clip_iqa_format_prompts(("quality", ("Super good photo.", "Super bad photo.")))
        (['Good photo.', 'Bad photo.', 'Super good photo.', 'Super bad photo.'], ['quality', 'user_defined_0'])

    """
    if not isinstance(prompts, tuple):
        raise ValueError("Argument `prompts` must be a tuple containing strings or tuples of strings")

    prompts_names: List[str] = []
    prompts_list: List[str] = []
    count = 0
    for p in prompts:
        if not isinstance(p, (str, tuple)):
            raise ValueError("Argument `prompts` must be a tuple containing strings or tuples of strings")
        if isinstance(p, str):
            if p not in _PROMPTS:
                raise ValueError(
                    f"All elements of `prompts` must be one of {_PROMPTS.keys()} if not custom tuple prompts, got {p}."
                )
            prompts_names.append(p)
            prompts_list.extend(_PROMPTS[p])
        if isinstance(p, tuple) and len(p) != 2:
            raise ValueError("If a tuple is provided in argument `prompts`, it must be of length 2")
        if isinstance(p, tuple):
            prompts_names.append(f"user_defined_{count}")
            prompts_list.extend(p)
            count += 1

    return prompts_list, prompts_names


def _clip_iqa_get_anchor_vectors(
    model_name_or_path: str,
    model: _CLIPModel,
    processor: _CLIPProcessor,
    prompts_list: List[str],
    device: Union[str, torch.device],
) -> Tensor:
    """Calculate the anchor vectors for the CLIP IQA metric.

    Args:
        model_name_or_path: string indicating the version of the CLIP model to use.
        model: The CLIP model
        processor: The CLIP processor
        prompts_list: A list of prompts
        device: The device to use for the calculation

    """
    if model_name_or_path == "clip_iqa":
        text_processed = processor(text=prompts_list)
        anchors_text = torch.zeros(
            len(prompts_list), processor.tokenizer.model_max_length, dtype=torch.long, device=device
        )
        for i, tp in enumerate(text_processed["input_ids"]):
            anchors_text[i, : len(tp)] = torch.tensor(tp, dtype=torch.long, device=device)

        anchors = model.encode_text(anchors_text).float()
    else:
        text_processed = processor(text=prompts_list, return_tensors="pt", padding=True)
        anchors = model.get_text_features(
            text_processed["input_ids"].to(device), text_processed["attention_mask"].to(device)
        )
    return anchors / anchors.norm(p=2, dim=-1, keepdim=True)
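The L2 normalisation on the last line is what makes the later matrix product a cosine similarity; a toy illustration with random vectors standing in for the real CLIP embeddings:

```python
import torch

# Toy stand-ins for CLIP text/image embeddings. After L2 normalisation the
# dot product of any image feature with any anchor is a cosine similarity.
anchors = torch.randn(4, 512)
anchors = anchors / anchors.norm(p=2, dim=-1, keepdim=True)
img_features = torch.randn(2, 512)
img_features = img_features / img_features.norm(p=2, dim=-1, keepdim=True)
cosine = img_features @ anchors.t()   # shape (2, 4), values in [-1, 1]
```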

def _clip_iqa_update(
    model_name_or_path: str,
    images: Tensor,
    model: _CLIPModel,
    processor: _CLIPProcessor,
    data_range: Union[int, float],
    device: Union[str, torch.device],
) -> Tensor:
    """Update function for CLIP IQA."""
    images = images / float(data_range)
    if model_name_or_path == "clip_iqa":
        # default mean and std from clip paper, see:
        # https://github.com/huggingface/transformers/blob/main/src/transformers/utils/constants.py
        default_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=device).view(1, 3, 1, 1)
        default_std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=device).view(1, 3, 1, 1)
        images = (images - default_mean) / default_std
        img_features = model.encode_image(images.float(), pos_embedding=False).float()
    else:
        processed_input = processor(images=[i.cpu() for i in images], return_tensors="pt", padding=True)
        img_features = model.get_image_features(processed_input["pixel_values"].to(device))
    return img_features / img_features.norm(p=2, dim=-1, keepdim=True)


def _clip_iqa_compute(
    img_features: Tensor,
    anchors: Tensor,
    prompts_names: List[str],
    format_as_dict: bool = True,
) -> Union[Tensor, Dict[str, Tensor]]:
    """Final computation of CLIP IQA."""
    logits_per_image = 100 * img_features @ anchors.t()
    probs = logits_per_image.reshape(logits_per_image.shape[0], -1, 2).softmax(-1)[:, :, 0]
    if len(prompts_names) == 1:
        return probs.squeeze()
    if format_as_dict:
        return {p: probs[:, i] for i, p in enumerate(prompts_names)}
    return probs
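To make the final step concrete: each prompt pair contributes two logits, and the value returned per pair is the softmax probability assigned to the positive prompt. A hand-picked example:

```python
import torch

# One image, two prompt pairs -> four logits. The pair (8, 2) strongly favours
# the positive prompt, while the pair (5, 5) is undecided.
logits_per_image = torch.tensor([[8.0, 2.0, 5.0, 5.0]])
probs = logits_per_image.reshape(logits_per_image.shape[0], -1, 2).softmax(-1)[:, :, 0]
# probs is approximately tensor([[0.9975, 0.5000]])
```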

def clip_image_quality_assessment(
    images: Tensor,
    model_name_or_path: Literal[
        "clip_iqa",
        "openai/clip-vit-base-patch16",
        "openai/clip-vit-base-patch32",
        "openai/clip-vit-large-patch14-336",
        "openai/clip-vit-large-patch14",
    ] = "clip_iqa",
    data_range: Union[int, float] = 1.0,
    prompts: Tuple[Union[str, Tuple[str, str]]] = ("quality",),
) -> Union[Tensor, Dict[str, Tensor]]:
    """Calculates `CLIP-IQA`_, which can be used to measure the visual content of images.

    The metric is based on the `CLIP`_ model, which is a neural network trained on a variety of (image, text) pairs
    to be able to generate a vector representation of the image and the text that is similar if the image and text
    are semantically similar.

    The metric works by calculating the cosine similarity between user-provided images and pre-defined prompts. The
    prompts always come in pairs of "positive" and "negative" such as "Good photo." and "Bad photo.". By calculating
    the similarity between image embeddings and both the "positive" and "negative" prompt, the metric can determine
    which prompt the image is more similar to. The metric then returns the probability that the image is more
    similar to the first prompt than the second prompt.

    Built-in prompts are:
        * quality: "Good photo." vs "Bad photo."
        * brightness: "Bright photo." vs "Dark photo."
        * noisiness: "Clean photo." vs "Noisy photo."
        * colorfullness: "Colorful photo." vs "Dull photo."
        * sharpness: "Sharp photo." vs "Blurry photo."
        * contrast: "High contrast photo." vs "Low contrast photo."
        * complexity: "Complex photo." vs "Simple photo."
        * natural: "Natural photo." vs "Synthetic photo."
        * happy: "Happy photo." vs "Sad photo."
        * scary: "Scary photo." vs "Peaceful photo."
        * new: "New photo." vs "Old photo."
        * warm: "Warm photo." vs "Cold photo."
        * real: "Real photo." vs "Abstract photo."
        * beutiful: "Beautiful photo." vs "Ugly photo."
        * lonely: "Lonely photo." vs "Sociable photo."
        * relaxing: "Relaxing photo." vs "Stressful photo."

    Args:
        images: Either a single ``[N, C, H, W]`` tensor or a list of ``[C, H, W]`` tensors
        model_name_or_path: string indicating the version of the CLIP model to use. By default this argument is set
            to ``clip_iqa`` which corresponds to the model used in the original paper. Other available models are
            `"openai/clip-vit-base-patch16"`, `"openai/clip-vit-base-patch32"`, `"openai/clip-vit-large-patch14-336"`
            and `"openai/clip-vit-large-patch14"`
        data_range: The maximum value of the input tensor. For example, if the input images are in range [0, 255],
            data_range should be 255. The images are normalized by this value.
        prompts: A string, tuple of strings or nested tuple of strings. If a single string is provided, it must be
            one of the available prompts (see above). Else the input is expected to be a tuple, where each element
            can be one of two things: either a string or a tuple of strings. If a string is provided, it must be one
            of the available prompts (see above). If a tuple is provided, it must be of length 2 and the first string
            must be a positive prompt and the second string must be a negative prompt.

    .. note:: If using the default `clip_iqa` model, the package `piq` must be installed. Either install with
        `pip install piq` or `pip install torchmetrics[multimodal]`.

    Returns:
        A tensor of shape ``(N,)`` if a single prompt is provided. If multiple prompts are provided, a dictionary
        with the prompts as keys and tensors of shape ``(N,)`` as values.

    Raises:
        ModuleNotFoundError:
            If transformers package is not installed or version is lower than 4.10.0
        ValueError:
            If not all images have format [C, H, W]
        ValueError:
            If prompts is a tuple and it is not of length 2
        ValueError:
            If prompts is a string and it is not one of the available prompts
        ValueError:
            If prompts is a list of strings and not all strings are one of the available prompts

    Example::
        Single prompt:

        >>> from torchmetrics.functional.multimodal import clip_image_quality_assessment
        >>> import torch
        >>> _ = torch.manual_seed(42)
        >>> imgs = torch.randint(255, (2, 3, 224, 224)).float()
        >>> clip_image_quality_assessment(imgs, prompts=("quality",))
        tensor([0.8894, 0.8902])

    Example::
        Multiple prompts:

        >>> from torchmetrics.functional.multimodal import clip_image_quality_assessment
        >>> import torch
        >>> _ = torch.manual_seed(42)
        >>> imgs = torch.randint(255, (2, 3, 224, 224)).float()
        >>> clip_image_quality_assessment(imgs, prompts=("quality", "brightness"))
        {'quality': tensor([0.8894, 0.8902]), 'brightness': tensor([0.5507, 0.5208])}

    Example::
        Custom prompts. Must always be a tuple of length 2, with a positive and negative prompt.

        >>> from torchmetrics.functional.multimodal import clip_image_quality_assessment
        >>> import torch
        >>> _ = torch.manual_seed(42)
        >>> imgs = torch.randint(255, (2, 3, 224, 224)).float()
        >>> clip_image_quality_assessment(imgs, prompts=(("Super good photo.", "Super bad photo."), "brightness"))
        {'user_defined_0': tensor([0.9652, 0.9629]), 'brightness': tensor([0.5507, 0.5208])}

    """
    prompts_list, prompts_names = _clip_iqa_format_prompts(prompts)

    model, processor = _get_clip_iqa_model_and_processor(model_name_or_path)
    device = images.device
    model = model.to(device)

    with torch.inference_mode():
        anchors = _clip_iqa_get_anchor_vectors(model_name_or_path, model, processor, prompts_list, device)
        img_features = _clip_iqa_update(model_name_or_path, images, model, processor, data_range, device)
        return _clip_iqa_compute(img_features, anchors, prompts_names)
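As a complement to the doctest examples above, a short sketch of the `data_range` argument: images stored in the usual [0, 255] range should be passed together with `data_range=255` so they are rescaled to [0, 1] before the CLIP normalisation. The default `clip_iqa` checkpoint additionally requires `piq`.

```python
import torch
from torchmetrics.functional.multimodal import clip_image_quality_assessment

# Sketch: uint8-style pixel values paired with data_range=255.
imgs = torch.randint(0, 256, (2, 3, 224, 224)).float()
scores = clip_image_quality_assessment(imgs, data_range=255, prompts=("quality",))
```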

0 commit comments
