
Conversation

@davidberenstein1957 (Member) commented Jul 22, 2025

Description

Implements HPS, HPSv2, VQA, and ImageReward as a single unified class with adaptations for each specific score metric.
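
In rough shape (a minimal sketch, not the merged code: `StatefulMetric` and the `_load`/`_score_image` hooks are taken from snippets quoted in the review below, while the `ImageRewardMetric` body is illustrative):

```python
from typing import Any

from PIL import Image


class BaseModelRewardMetric(StatefulMetric):  # StatefulMetric: pruna's stateful metric base
    metric_name: str

    def __init__(self, device: str = "cpu", **model_load_kwargs: Any) -> None:
        super().__init__()
        self.device = device
        self.scores: list[float] = []  # per-sample scores, aggregated in compute()
        self._load(**model_load_kwargs)

    def _load(self, **kwargs: Any) -> None:
        raise NotImplementedError("Subclasses must implement this method")

    def _score_image(self, prompt: str, image: Image.Image) -> float:
        raise NotImplementedError("Subclasses must implement this method")


class ImageRewardMetric(BaseModelRewardMetric):
    metric_name = "image_reward"

    def _load(self, **kwargs: Any) -> None:
        import ImageReward as RM  # provided by the image-reward dependency

        self.model = RM.load("ImageReward-v1.0", device=self.device)

    def _score_image(self, prompt: str, image: Image.Image) -> float:
        return float(self.model.score(prompt, image))
```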

Related Issue

Fixes #271
Fixes #270

Type of Change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

How Has This Been Tested?

Checklist

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Additional Notes

Comment on lines +63 to +64
image-reward = { git = "https://github.com/PrunaAI/ImageReward" }
hpsv2 = { git = "https://github.com/PrunaAI/HPSv2" }
Member Author

There were some issues with both repositories.
ImageReward had pushed their fixes to main.
HPSv2 had very strict constraints for protobuf, pinning it to lower than 6.

Member Author

@johnrachwan123 perhaps we should publish these on our own index?

Member Author

We could also add t2v-metrics with a more relaxed Python constraint.

@davidberenstein1957 davidberenstein1957 changed the title feat: 271 feature implement hps hpsv2 feat: 271 feature implement hps hpsv2 VQ Jul 23, 2025
@davidberenstein1957 davidberenstein1957 changed the title feat: 271 feature implement hps hpsv2 VQ feat: 271 feature implement hps hpsv2 VQA Jul 23, 2025
@davidberenstein1957 davidberenstein1957 force-pushed the feat/271-feature-implement-hps-hpsv2 branch 2 times, most recently from ee07fb4 to 78912da Compare July 24, 2025 08:27
@sharpenb (Member) left a comment

Looking forward to having these metrics in the package! I left some comments, and @begum should also have a check :)

"PairwiseClipScore",
"CMMD",
"ImageRewardMetric",
"HPSMetric",
Member

If HPS does not work, I would prefer to leave it out for now, or fix it.

"metric_cls, metric_name, model_load_kwargs",
METRIC_CLASSES_AND_NAMES,
)
class TestRewardMetrics:
Member

We do not structure the tests in classes for now.
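
A minimal sketch of the flat, function-based alternative (reusing the `METRIC_CLASSES_AND_NAMES` parametrization quoted above; the assertion body is illustrative):

```python
import pytest


@pytest.mark.parametrize(
    "metric_cls, metric_name, model_load_kwargs",
    METRIC_CLASSES_AND_NAMES,
)
def test_metric_registration(metric_cls, metric_name, model_load_kwargs):
    # Plain parametrized function instead of a TestRewardMetrics class.
    metric = metric_cls(**model_load_kwargs)
    assert metric.metric_name == metric_name
```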

TorchMetricWrapper("clip_score"),
TorchMetricWrapper("clip_score", call_type="pairwise"),
CMMD(device=device),
ImageRewardMetric(device=device),
Member

Why are HPS, HPSv2, and VQA not added here?

        self._load(**model_load_kwargs)

    def _load(self, **kwargs: Any) -> None:
        raise NotImplementedError("Subclasses must implement this method")
Member

Do we need these? We can probably use abstract classes and methods.
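
A minimal sketch of that alternative (assuming `StatefulMetric` has a metaclass compatible with `ABCMeta`):

```python
from abc import ABC, abstractmethod
from typing import Any

from PIL import Image


class BaseModelRewardMetric(StatefulMetric, ABC):
    @abstractmethod
    def _load(self, **kwargs: Any) -> None:
        """Subclasses load their scoring model here."""

    @abstractmethod
    def _score_image(self, prompt: str, image: Image.Image) -> float:
        """Subclasses score one prompt/image pair here."""
```

Instantiating a subclass that misses one of these methods then fails at construction time rather than at call time.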

Member

+1 for abstract classes if we are going to have a new metric type :)

            The computed ImageReward score.
        """
        if not self.scores:
            pruna_logger.warning("No scores available for ImageReward computation")
Member

It mentions ImageReward here, but it is for the base class.

Member

+1

            return MetricResult(self.metric_name, self.__dict__.copy(), 0.0)

        # Calculate mean score
        mean_score = torch.mean(torch.tensor(self.scores)).item()
Member

I discussed this with @begumcig. Not necessary for now, but there is the question of adding an aggregation function option, which would be mean by default.
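
A minimal sketch of what that option could look like (the `aggregate` parameter is hypothetical; mean stays the default):

```python
from typing import Callable

import torch


def compute_score(
    scores: list[float],
    aggregate: Callable[[torch.Tensor], torch.Tensor] = torch.mean,  # hypothetical option
) -> float:
    # Mean by default, matching the current behaviour; callers could
    # pass e.g. torch.median instead.
    return aggregate(torch.tensor(scores)).item()
```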

VQA_REWARD = "vqa"


class BaseModelRewardMetric(StatefulMetric):
Member

Either this class should already specialize for text-to-image, or it should be made more general in its implementation (e.g., there is a notion of prompts and images in some methods).

# Preprocess image and move to device
images = preprocess(image).unsqueeze(0).to(device)
# Tokenize prompt and move to device
import clip
Member

Same here. I also think we already have some package capable of doing CLIP; it makes sense to check what is used in the CLIP metric.
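
For reference, a minimal sketch going through torchmetrics instead of importing the standalone `clip` package (presumably what `TorchMetricWrapper("clip_score")` wraps; the checkpoint name is one common choice):

```python
import torch
from torchmetrics.multimodal.clip_score import CLIPScore

clip_score = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16")
images = torch.randint(0, 255, (2, 3, 224, 224), dtype=torch.uint8)  # dummy batch
score = clip_score(images, ["a photo of a cat", "a photo of a dog"])
```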

Member

yup!

@github-actions

This PR has been inactive for 10 days and is now marked as stale.

cursor[bot]

This comment was marked as outdated.

@github-actions github-actions bot removed the stale label Aug 21, 2025
* Introduced `ImageRewardMetric` to evaluate text-to-image generation quality, outperforming existing methods in understanding human preferences.
* Registered the new metric in the metrics registry and updated the relevant files to include it.
* Added `image-reward` and `clip` as dependencies in `pyproject.toml`.
* Implemented tests for the `ImageRewardMetric` to ensure functionality and robustness.

Co-authored-by: davidberenstein1957 <[email protected]>
* Introduced the ImageRewardMetric class to evaluate text-to-image generation quality, outperforming existing methods.
* Updated task.py to integrate the new metric and adjusted metric retrieval methods for improved clarity.
* Enhanced pyproject.toml with new dependencies for ImageReward functionality.
* Added unit tests for the ImageReward metric to ensure proper functionality and error handling.
* Eliminated `timm` from the list of dependencies to streamline the project requirements.
* This change helps in reducing unnecessary package bloat and potential compatibility issues.
….lock

* Changed the GitHub repository source for the image-reward dependency from THUDM to PrunaAI in both pyproject.toml and uv.lock files.
* Removed the timm>=1.0.0 dependency from pyproject.toml to streamline the dependency list.
* Introduced HPSMetric and HPSv2Metric classes to evaluate text-to-image generation quality.
* Updated pyproject.toml to include new dependencies: hpsv2 and args.
* Created metric_reward.py to implement reward metrics and integrated them into the evaluation framework.
* Added unit tests for the new metrics, covering registration, prompt extraction, scoring, and error handling.
* Removed obsolete test file for ImageRewardMetric.
* Removed unused reward constants from the metrics module to simplify imports.
* Refactored test cases for reward metrics to utilize a fixture for metric initialization, enhancing test organization and readability.
* Ensured all tests for metric registration, prompt extraction, scoring, and error handling are properly integrated with the new structure.
* Introduced VQAMetric class to evaluate the quality of text-to-image generation using Visual Question Answering.
* Updated metric_reward.py to include the new VQA metric and its scoring method.
* Enhanced pyproject.toml to add the t2v-metrics dependency required for VQA functionality.
* Refactored existing reward metrics to accommodate additional model loading parameters.
* Updated unit tests to include the new VQA metric, ensuring proper functionality and integration with the existing metrics framework.
…guments

* Updated the reward_metrics function to accept additional keyword arguments for model loading, improving flexibility in metric initialization.
* Adjusted the METRIC_CLASSES_AND_NAMES structure to include model_load_kwargs for each metric class.
* Modified test cases in test_reward_metric.py to accept model_load_kwargs, enhancing the flexibility of metric initialization in tests.
* Adjusted method signatures for test_metric_registration, test_extract_prompts, test_score_image, test_update_and_compute, and test_error_handling to incorporate model loading arguments.
…rdMetric import

* Removed the duplicate import of ImageRewardMetric in task.py to streamline the code and improve readability.
* Refactored the HPSv2Metric class to utilize functools.partial for model scoring, allowing for dynamic version handling.
* Adjusted the test_reward_metric.py to modify the way model loading arguments are passed, enhancing test flexibility.
* Introduced SharpnessMetric to the metrics module for enhanced evaluation capabilities.
* Updated the __all__ list to include SharpnessMetric, ensuring it is accessible for imports.
@davidberenstein1957 davidberenstein1957 force-pushed the feat/271-feature-implement-hps-hpsv2 branch from 6399c5a to f3fbce7 Compare August 30, 2025 04:32
@github-actions

This PR has been inactive for 10 days and is now marked as stale.

        float
            The score of the image.
        """
        score = self.model(imgs_path=image, prompt=prompt)

Bug: Image Object Passed to Path Parameter

In HPSv2Metric._score_image, the hpsv2.score function receives a PIL.Image.Image object for its imgs_path parameter. This parameter name suggests a file path string is expected, which will likely cause a runtime error.
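
One possible fix, as a hedged sketch (assuming `hpsv2.score` really does require a file path; the method mirrors `HPSv2Metric._score_image` from this PR, and the exact signature should be verified against the pinned fork):

```python
import tempfile

from PIL import Image


def _score_image(self, prompt: str, image: Image.Image) -> float:
    # Write the in-memory image to a temporary file so hpsv2 can read it
    # from disk; hpsv2.score is assumed to return one score per image.
    with tempfile.NamedTemporaryFile(suffix=".png") as tmp:
        image.save(tmp.name)
        score = self.model(imgs_path=tmp.name, prompt=prompt)
    return float(score[0])
```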


Member

I remember some of the reward metrics expecting a path rather than an image; it could be beneficial to look into this warning.

license = {file = "LICENSE"}
readme = "README.md"
requires-python = ">=3.9,<3.13"
requires-python = ">=3.10,<3.13"

Bug: Python Version Mismatch in Package Metadata

The pyproject.toml declares Python 3.9 support in its classifiers, which conflicts with the requires-python field now set to >=3.10,<3.13. This creates a mismatch in the package metadata.


@github-actions github-actions bot removed the stale label Oct 15, 2025
@begumcig (Member) left a comment

Thank you so much for such a thought-out implementation, David. I really like the RewardsMetric structure. I left some comments, mostly about the base class, but the metrics themselves already look pretty good overall!

"pynvml",
"thop",
"timm",
"bitsandbytes; sys_platform != 'darwin' or platform_machine != 'arm64'",
Member

Why did we remove timm from the dependencies?

default_call_type: str = "y"
metric_units: str = "score"

# Type annotations for dynamically added attributes
Member

slay

"""
# Prepare inputs
metric_inputs = metric_data_processor(x, gt, outputs, self.call_type, device=self.device)
prompts = self._extract_prompts(x)
Member

The job of metric_data_processor is to handle device casting and separation of the inputs required for the metric. I see that in the next line you are using x, which could lead to device-casting problems since it's not coming from the metric_data_processor. If the metric also needs the inputs, we should use a different call_type for the metric (like x_y or y_x), then extract x from what the metric_data_processor returns :)
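
A minimal sketch of that suggestion (the "x_y" call type and the layout of the returned tuple are assumptions based on this comment, not confirmed API):

```python
# Request inputs and outputs together so both get the same device casting
# (hypothetical "x_y" call type; tuple order is assumed).
metric_inputs = metric_data_processor(x, gt, outputs, "x_y", device=self.device)
x_on_device, images = metric_inputs[0], metric_inputs[1]
prompts = self._extract_prompts(x_on_device)  # no raw x past this point
```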

# Prepare inputs
metric_inputs = metric_data_processor(x, gt, outputs, self.call_type, device=self.device)
prompts = self._extract_prompts(x)
images = metric_inputs[1] if len(metric_inputs) > 1 else outputs
Member

Again, we shouldn't use outputs once the metric data processor has been called.

    metric_name: str = HPSv2_REWARD

    def _load(self, **kwargs: Any) -> None:
        from functools import partial
Member

I think we can import partial at the top of the file :)

tensor = torch.randn(2, 3, 224, 224)
extracted = metric._extract_prompts(tensor)
assert len(extracted) == 2
assert all(prompt.startswith("prompt_") for prompt in extracted)
Member

Again, I am a bit confused about this use case 🥺. Do we just return prompt_ and then the tensor in string format?

score = metric._score_image(prompt, pil_image)
assert isinstance(score, float)
# Score should be a reasonable value (ImageReward/HPS typically outputs scores around 0-10)
assert -10 <= score <= 10
Member

Can the score actually be negative?

@github-actions

This PR has been inactive for 10 days and is now marked as stale.

@github-actions github-actions bot added the stale label Oct 26, 2025
@github-actions github-actions bot removed the stale label Oct 28, 2025