feat: 271 feature implement hps hpsv2 VQA #272
base: main
Conversation
image-reward = { git = "https://github.com/PrunaAI/ImageReward" }
hpsv2 = { git = "https://github.com/PrunaAI/HPSv2" }
There were some issues with both repositories: ImageReward had pushed their fixes to main, and HPSv2 had very strict constraints for protobuf, pinning it to lower than 6.
@johnrachwan123 perhaps we should publish these on our own index?
also we could add t2v-metrics with a more relaxed Python constraint
ee07fb4 to 78912da
Looking forward to having these metrics in the package! I left some comments and @begum should also have a check :)
| "PairwiseClipScore", | ||
| "CMMD", | ||
| "ImageRewardMetric", | ||
| "HPSMetric", |
If HPS does not work, I would prefer to leave it out for now, or fix it.
| "metric_cls, metric_name, model_load_kwargs", | ||
| METRIC_CLASSES_AND_NAMES, | ||
| ) | ||
| class TestRewardMetrics: |
We do not structure the tests in classes for now.
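For reference, a minimal sketch of the flat, class-free layout (the parametrization tuple and metric constructor behavior are assumptions taken from this diff):

```python
import pytest

# Placeholder for the PR's (metric_cls, metric_name, model_load_kwargs) tuples.
METRIC_CLASSES_AND_NAMES: list = []


@pytest.mark.parametrize("metric_cls, metric_name, model_load_kwargs", METRIC_CLASSES_AND_NAMES)
def test_metric_registration(metric_cls, metric_name, model_load_kwargs):
    # Module-level test function instead of a method on TestRewardMetrics
    metric = metric_cls(**model_load_kwargs)
    assert metric.metric_name == metric_name
```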
    TorchMetricWrapper("clip_score"),
    TorchMetricWrapper("clip_score", call_type="pairwise"),
    CMMD(device=device),
    ImageRewardMetric(device=device),
Why are HPS, HPSv2, and VQA not added here?
    self._load(**model_load_kwargs)

def _load(self, **kwargs: Any) -> None:
    raise NotImplementedError("Subclasses must implement this method")
Do we need these? We can probably use an abstract class and abstract methods.
+1 for abstract classes if we are going to have a new metric type :)
        The computed ImageReward score.
    """
    if not self.scores:
        pruna_logger.warning("No scores available for ImageReward computation")
It mentions ImageReward here, but this is the base class.
+1
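One possible fix, sketched from the diff lines above (attribute names are assumed from this PR), is to build the message from the metric's own name:

```python
# Sketch: same guard as above, but generic over whichever subclass is running
if not self.scores:
    pruna_logger.warning(f"No scores available for {self.metric_name} computation")
    return MetricResult(self.metric_name, self.__dict__.copy(), 0.0)
```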
        return MetricResult(self.metric_name, self.__dict__.copy(), 0.0)

    # Calculate mean score
    mean_score = torch.mean(torch.tensor(self.scores)).item()
I discussed this with @begumcig. Not necessary for now, but there is the question of adding an aggregation function option, which would be mean by default.
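A rough sketch of what that option could look like (function name and signature are hypothetical; mean stays the default):

```python
from typing import Callable, Sequence

import torch


def aggregate_scores(
    scores: Sequence[float],
    aggregation_fn: Callable[[torch.Tensor], torch.Tensor] = torch.mean,
) -> float:
    """Collapse per-sample reward scores into one number; callers could pass torch.median, torch.min, etc."""
    return aggregation_fn(torch.tensor(scores)).item()


# aggregate_scores([0.2, 0.4, 0.9]) == 0.5
```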
VQA_REWARD = "vqa"


class BaseModelRewardMetric(StatefulMetric):
Either this class should already specialize for text-to-image, or it should be made more general in its implementation (e.g. there is a notion of prompts and images in some methods).
# Preprocess image and move to device
images = preprocess(image).unsqueeze(0).to(device)
# Tokenize prompt and move to device
import clip
Same here. I also think that we already have some package capable of doing CLIP; it makes sense to check what is used in the CLIP metric.
yup!
This PR has been inactive for 10 days and is now marked as stale.
* Introduced `ImageRewardMetric` to evaluate text-to-image generation quality, outperforming existing methods in understanding human preferences.
* Registered the new metric in the metrics registry and updated the relevant files to include it.
* Added `image-reward` and `clip` as dependencies in `pyproject.toml`.
* Implemented tests for the `ImageRewardMetric` to ensure functionality and robustness.

Co-authored-by: davidberenstein1957 <[email protected]>

* Introduced the ImageRewardMetric class to evaluate text-to-image generation quality, outperforming existing methods.
* Updated task.py to integrate the new metric and adjusted metric retrieval methods for improved clarity.
* Enhanced pyproject.toml with new dependencies for ImageReward functionality.
* Added unit tests for the ImageReward metric to ensure proper functionality and error handling.

* Eliminated `timm` from the list of dependencies to streamline the project requirements.
* This change helps in reducing unnecessary package bloat and potential compatibility issues.

….lock
* Changed the GitHub repository source for the image-reward dependency from THUDM to PrunaAI in both pyproject.toml and uv.lock files.
* Removed the timm>=1.0.0 dependency from pyproject.toml to streamline the dependency list.

* Introduced HPSMetric and HPSv2Metric classes to evaluate text-to-image generation quality.
* Updated pyproject.toml to include new dependencies: hpsv2 and args.
* Created metric_reward.py to implement reward metrics and integrated them into the evaluation framework.
* Added unit tests for the new metrics, covering registration, prompt extraction, scoring, and error handling.
* Removed obsolete test file for ImageRewardMetric.

* Removed unused reward constants from the metrics module to simplify imports.
* Refactored test cases for reward metrics to utilize a fixture for metric initialization, enhancing test organization and readability.
* Ensured all tests for metric registration, prompt extraction, scoring, and error handling are properly integrated with the new structure.

* Introduced VQAMetric class to evaluate the quality of text-to-image generation using Visual Question Answering.
* Updated metric_reward.py to include the new VQA metric and its scoring method.
* Enhanced pyproject.toml to add the t2v-metrics dependency required for VQA functionality.
* Refactored existing reward metrics to accommodate additional model loading parameters.
* Updated unit tests to include the new VQA metric, ensuring proper functionality and integration with the existing metrics framework.

…guments
* Updated the reward_metrics function to accept additional keyword arguments for model loading, improving flexibility in metric initialization.
* Adjusted the METRIC_CLASSES_AND_NAMES structure to include model_load_kwargs for each metric class.

* Modified test cases in test_reward_metric.py to accept model_load_kwargs, enhancing the flexibility of metric initialization in tests.
* Adjusted method signatures for test_metric_registration, test_extract_prompts, test_score_image, test_update_and_compute, and test_error_handling to incorporate model loading arguments.

…rdMetric import
* Removed the duplicate import of ImageRewardMetric in task.py to streamline the code and improve readability.

* Refactored the HPSv2Metric class to utilize functools.partial for model scoring, allowing for dynamic version handling.
* Adjusted the test_reward_metric.py to modify the way model loading arguments are passed, enhancing test flexibility.

* Introduced SharpnessMetric to the metrics module for enhanced evaluation capabilities.
* Updated the __all__ list to include SharpnessMetric, ensuring it is accessible for imports.
6399c5a to f3fbce7
    float
        The score of the image.
    """
    score = self.model(imgs_path=image, prompt=prompt)
I remember some of the reward metrics expecting a path rather than an image; it could be beneficial to look into this warning.
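If some of the upstream scorers really do want a file path, one hedged workaround (the `imgs_path`/`prompt` keywords are taken from the diff above; the temp-file handling is an assumption) would be:

```python
import tempfile

from PIL import Image


def score_via_path(model, prompt: str, image: Image.Image) -> float:
    """Persist the in-memory PIL image so scorers that expect a path still work."""
    with tempfile.NamedTemporaryFile(suffix=".png") as tmp:
        image.save(tmp, format="PNG")
        tmp.flush()
        return float(model(imgs_path=tmp.name, prompt=prompt))
```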
license = {file = "LICENSE"}
readme = "README.md"
- requires-python = ">=3.9,<3.13"
+ requires-python = ">=3.10,<3.13"
Thank you so much for such a thought-out implementation David, I really like the RewardsMetric structure. I left some comments mostly about the base class but the metrics themselves already look pretty good overall!
| "pynvml", | ||
| "thop", | ||
| "timm", | ||
| "bitsandbytes; sys_platform != 'darwin' or platform_machine != 'arm64'", |
Why did we remove timm from the dependencies?
default_call_type: str = "y"
metric_units: str = "score"

# Type annotations for dynamically added attributes
slay
| """ | ||
| # Prepare inputs | ||
| metric_inputs = metric_data_processor(x, gt, outputs, self.call_type, device=self.device) | ||
| prompts = self._extract_prompts(x) |
The job of metric_data_processor is to handle device casting and separation of the inputs required for the metric. I see that in the next line you are using x, which could lead to device casting problems since it's not coming from the metric_data_processor. If the metric also needs the inputs we should use a different call_type for the metric (like x_y or y_x), then extract x from what the metric_data_processor returns :)
    # Prepare inputs
    metric_inputs = metric_data_processor(x, gt, outputs, self.call_type, device=self.device)
    prompts = self._extract_prompts(x)
    images = metric_inputs[1] if len(metric_inputs) > 1 else outputs
Again, we shouldn't use the outputs once the metric data processor has been called.
metric_name: str = HPSv2_REWARD


def _load(self, **kwargs: Any) -> None:
    from functools import partial
I think we can import partial at the top of the file :)
tensor = torch.randn(2, 3, 224, 224)
extracted = metric._extract_prompts(tensor)
assert len(extracted) == 2
assert all(prompt.startswith("prompt_") for prompt in extracted)
Again, I am a bit confused about this use case 🥺. Do we just return "prompt_" and then the tensor in string format?
score = metric._score_image(prompt, pil_image)
assert isinstance(score, float)
# Score should be a reasonable value (ImageReward/HPS typically outputs scores around 0-10)
assert -10 <= score <= 10
Can the score actually be negative?
Description
Implements HPS, HPSv2, VQA, and ImageReward as a single unified class with various adaptations per specific score metric.
Related Issue
Fixes #271
Fixes #270
Type of Change
How Has This Been Tested?
Checklist
Additional Notes