6 changes: 4 additions & 2 deletions pyproject.toml
@@ -60,6 +60,7 @@ gptqmodel = [
{ index = "pruna_internal", marker = "sys_platform != 'darwin' or platform_machine != 'arm64'"},
{ index = "pypi", marker = "sys_platform == 'darwin' and platform_machine == 'arm64'"},
]
image-reward = { git = "https://github.com/PrunaAI/ImageReward" }

[project]
name = "pruna"
@@ -115,8 +116,9 @@ dependencies = [
"hqq==0.2.6",
"torchao",
"llmcompressor",
"gliner; python_version >= '3.10'"

"gliner; python_version >= '3.10'",
"image-reward",
"clip",
]

[project.optional-dependencies]
2 changes: 2 additions & 0 deletions src/pruna/evaluation/metrics/__init__.py
@@ -17,6 +17,7 @@
from pruna.evaluation.metrics.metric_cmmd import CMMD
from pruna.evaluation.metrics.metric_elapsed_time import LatencyMetric, ThroughputMetric, TotalTimeMetric
from pruna.evaluation.metrics.metric_energy import CO2EmissionsMetric, EnergyConsumedMetric
from pruna.evaluation.metrics.metric_imagereward import ImageRewardMetric
from pruna.evaluation.metrics.metric_memory import DiskMemoryMetric, InferenceMemoryMetric, TrainingMemoryMetric
from pruna.evaluation.metrics.metric_model_architecture import TotalMACsMetric, TotalParamsMetric
from pruna.evaluation.metrics.metric_pairwise_clip import PairwiseClipScore
@@ -37,4 +38,5 @@
"TotalMACsMetric",
"PairwiseClipScore",
"CMMD",
"ImageRewardMetric",
]
193 changes: 193 additions & 0 deletions src/pruna/evaluation/metrics/metric_imagereward.py
@@ -0,0 +1,193 @@
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Any, List

import PIL
import torch
from torch import Tensor
from torchvision.transforms import ToPILImage

from pruna.engine.utils import set_to_best_available_device
from pruna.evaluation.metrics.metric_stateful import StatefulMetric
from pruna.evaluation.metrics.registry import MetricRegistry
from pruna.evaluation.metrics.result import MetricResult
from pruna.evaluation.metrics.utils import SINGLE, get_call_type_for_single_metric, metric_data_processor
from pruna.logging.logger import pruna_logger

IMAGE_REWARD = "image_reward"


@MetricRegistry.register(IMAGE_REWARD)
class ImageRewardMetric(StatefulMetric):
"""
ImageReward metric for evaluating text-to-image generation quality.

ImageReward is a human preference reward model for text-to-image generation that
outperforms existing methods like CLIP, Aesthetic, and BLIP in understanding
human preferences.

Parameters
----------
device : str | torch.device | None, optional
The device to use for the model. If None, the best available device will be used.
model_name : str, optional
The ImageReward model to use. Default is "ImageReward-v1.0".
call_type : str
The type of call to use for the metric. IQA metrics, like image_reward, are only supported for single mode.
**kwargs : Any
Additional keyword arguments for the metric.
"""

higher_is_better: bool = True
default_call_type: str = "y"
metric_name: str = IMAGE_REWARD
metric_units: str = "score"

# Type annotations for dynamically added attributes
scores: List[float]
prompts: List[str]

def __init__(
self,
device: str | torch.device | None = None,
model_name: str = "ImageReward-v1.0",
call_type: str = SINGLE,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.device = set_to_best_available_device(device)
self.model_name = model_name
self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type)

# Import ImageReward here to avoid dependency issues
import ImageReward as RM # noqa: N814

# Load the ImageReward model
pruna_logger.info(f"Loading ImageReward model: {model_name}")
self.model = RM.load(model_name, device=self.device)
self.to_pil = ToPILImage()

# Initialize state for accumulating scores
self.add_state("scores", [])
self.add_state("prompts", [])

def update(self, x: List[str] | Tensor, gt: Tensor, outputs: Tensor) -> None:
"""
Update the metric with new batch data.

Parameters
----------
x : List[str] | Tensor
The input prompts for text-to-image generation.
gt : Tensor
The ground truth images (not used for ImageReward).
outputs : Tensor
The generated images to evaluate.
"""
# Prepare inputs
metric_inputs = metric_data_processor(x, gt, outputs, self.call_type, device=self.device)
prompts = self._extract_prompts(x)
images = metric_inputs[1] if len(metric_inputs) > 1 else outputs

# Format images as PIL Images
formatted_images = [self._format_image(image) for image in images]

# Score images with prompts
for prompt, image in zip(prompts, formatted_images):
score = self.model.score(prompt, image)
self.scores.append(score)
self.prompts.append(prompt)

def compute(self) -> MetricResult:
"""
Compute the final ImageReward metric.

Returns
-------
MetricResult
The computed ImageReward score.
"""
if not self.scores:
pruna_logger.warning("No scores available for ImageReward computation")
return MetricResult(self.metric_name, self.__dict__.copy(), 0.0)

# Calculate mean score
mean_score = torch.mean(torch.tensor(self.scores)).item()

return MetricResult(self.metric_name, self.__dict__.copy(), mean_score)

def _extract_prompts(self, x: List[str] | Tensor) -> List[str]:
"""
Extract prompts from input data.

Parameters
----------
x : List[str] | Tensor
The input data containing prompts.

Returns
-------
List[str]
The extracted prompts.
"""
if isinstance(x, list):
return x
elif isinstance(x, Tensor):
# If x is a tensor, we need to handle it differently
# This might be the case for some data formats
pruna_logger.warning("Input x is a tensor, assuming it contains encoded prompts")
return [f"prompt_{i}" for i in range(x.shape[0])]
else:
pruna_logger.error(f"Unexpected input type for prompts: {type(x)}")
return []

def _format_image(self, image: Tensor) -> PIL.Image.Image:
"""
Format a single image with its prompt using ImageReward.

Parameters
----------
image : Tensor
The image to score.

Returns
-------
float
The ImageReward score for the image.
"""
# Convert tensor to PIL Image
if image.dim() == 4:
# Batch dimension, take first image
image = image[0]

# Ensure image is in the correct format (C, H, W) with values in [0, 1]
if image.dim() == 3 and image.shape[0] in [1, 3, 4]:
# Image is in CHW format
if image.shape[0] == 1:
# Grayscale, convert to RGB
image = image.repeat(3, 1, 1)
elif image.shape[0] == 4:
# RGBA, take only RGB channels
image = image[:3]

# Normalize to [0, 1] if needed
if image.max() > 1.0:
image = image / 255.0

# Convert to PIL Image
pil_image = self.to_pil(image)
return pil_image
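Taken together, the metric follows the usual StatefulMetric flow: construct it, feed each batch through update(), and call compute() for the aggregated result. Below is a minimal standalone usage sketch; the prompts, tensor shapes, and CPU device are illustrative, and it assumes the "ImageReward-v1.0" weights can be downloaded when the metric is constructed.

import torch

from pruna.evaluation.metrics.metric_imagereward import ImageRewardMetric

# Construct the metric; the ImageReward model is loaded during __init__.
metric = ImageRewardMetric(device="cpu")

# One batch: prompts, ground-truth images (not used by ImageReward), generated images.
prompts = ["a watercolor painting of a fox", "a city skyline at night"]
ground_truth = torch.rand(2, 3, 224, 224)
generated = torch.rand(2, 3, 224, 224)

metric.update(prompts, ground_truth, generated)

# compute() returns a MetricResult wrapping the mean ImageReward score.
result = metric.compute()
print(result)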
2 changes: 2 additions & 0 deletions src/pruna/evaluation/task.py
@@ -25,6 +25,7 @@
from pruna.evaluation.metrics.metric_cmmd import CMMD
from pruna.evaluation.metrics.metric_elapsed_time import LATENCY, THROUGHPUT, TOTAL_TIME
from pruna.evaluation.metrics.metric_energy import CO2_EMISSIONS, ENERGY_CONSUMED
from pruna.evaluation.metrics.metric_imagereward import ImageRewardMetric
from pruna.evaluation.metrics.metric_memory import DISK_MEMORY
from pruna.evaluation.metrics.metric_model_architecture import TOTAL_MACS, TOTAL_PARAMS
from pruna.evaluation.metrics.metric_stateful import StatefulMetric
@@ -199,6 +200,7 @@ def _process_single_request(request: str, device: str | torch.device | None) ->
TorchMetricWrapper("clip_score"),
TorchMetricWrapper("clip_score", call_type="pairwise"),
CMMD(device=device),
ImageRewardMetric(device=device),
]
else:
pruna_logger.error(f"Metric {request} not found. Available requests: {AVAILABLE_REQUESTS}.")
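In addition to the direct import used in task.py, the class is registered under the name "image_reward" and can be resolved through the metric registry, as the registration test below exercises. A small lookup sketch (the "cpu" device is illustrative):

from pruna.evaluation.metrics.registry import MetricRegistry

# The @MetricRegistry.register(IMAGE_REWARD) decorator makes the class
# resolvable by its string name.
metric = MetricRegistry.get_metric("image_reward", device="cpu")
assert type(metric).__name__ == "ImageRewardMetric"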
82 changes: 82 additions & 0 deletions tests/evaluation/test_image_reward_metric.py
@@ -0,0 +1,82 @@
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch
from PIL import Image

from pruna.evaluation.metrics.metric_imagereward import ImageRewardMetric, IMAGE_REWARD


def test_metric_registration():
"""Test that the metric is properly registered."""
from pruna.evaluation.metrics.registry import MetricRegistry

metric = MetricRegistry.get_metric(IMAGE_REWARD, device="cpu")
assert isinstance(metric, ImageRewardMetric)

def test_extract_prompts():
"""Test prompt extraction from different input types."""
metric = ImageRewardMetric(device="cpu")

# Test with list of strings
prompts = ["a beautiful sunset", "a cat playing"]
extracted = metric._extract_prompts(prompts)
assert extracted == prompts

# Test with tensor (should generate default prompts)
tensor = torch.randn(2, 3, 224, 224)
extracted = metric._extract_prompts(tensor)
assert len(extracted) == 2
assert all(prompt.startswith("prompt_") for prompt in extracted)


def test_score_image():
Review comment (Member); a sketch responding to this suggestion appears after the test file:
These tests are really comprehensive thanks a lot! Shall we also add a case or cases using either PrunaModel.run_inference() or the EvaluationAgent, similar to our tests in cmmd metric? That way we would also ensure the metric is compatible with our engine and agent!

"""Test image scoring functionality."""
metric = ImageRewardMetric(device="cpu")

# Create a simple test image
image = torch.randn(3, 224, 224) # RGB image
prompt = "a beautiful landscape"

score = metric._score_image(prompt, image)
assert isinstance(score, float)
# Score should be a reasonable value (ImageReward typically outputs scores around 0-10)
assert -10 <= score <= 10


def test_update_and_compute():
"""Test the update and compute methods."""
metric = ImageRewardMetric(device="cpu")

# Create test data
prompts = ["a beautiful sunset", "a cat playing"]
images = torch.randn(2, 3, 224, 224) # 2 RGB images
gt_images = torch.randn(2, 3, 224, 224) # Ground truth images

# Update the metric
metric.update(prompts, gt_images, images)

# Compute the result
result = metric.compute()
import pdb; pdb.set_trace()
Review comment: Bug: Debugging code left in test. A pdb.set_trace() call was left in test_update_and_compute. This debugging statement pauses test execution, which breaks automated test runs and CI/CD pipelines.


def test_format_image_edge_cases():
    """Test image formatting for non-RGB input shapes."""
    metric = ImageRewardMetric(device="cpu")

    # A single-channel image is expanded to RGB rather than raising an error.
    grayscale_image = torch.rand(1, 224, 224)
    pil_image = metric._format_image(grayscale_image)
    assert pil_image.mode == "RGB"
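Responding to the review comment above about exercising the metric through the engine: one possible shape for such a test is sketched here. This is a sketch only; the EvaluationAgent and Task import paths, the Task signature, and the pytest fixtures (text_to_image_model, datamodule) are assumptions rather than established project API, and the existing cmmd metric test the reviewer points to would be the template to follow.

import pytest

from pruna.evaluation.metrics.metric_imagereward import ImageRewardMetric

# Assumed import paths; mirror whatever the cmmd metric test actually uses.
from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.evaluation.task import Task


@pytest.mark.slow
def test_image_reward_with_evaluation_agent(text_to_image_model, datamodule):
    """Sketch: run the metric end to end through the evaluation agent."""
    # Passing a metric instance and a datamodule to Task is an assumption here.
    task = Task([ImageRewardMetric(device="cpu")], datamodule=datamodule, device="cpu")
    agent = EvaluationAgent(task)

    # Assumed entry point; the agent drives inference and feeds the metric.
    results = agent.evaluate(text_to_image_model)
    assert results is not None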