Skip to content

Commit 78912da

Browse files
refactor: update HPSv2Metric to use partial function for model scoring
* Refactored the HPSv2Metric class to utilize functools.partial for model scoring, allowing for dynamic version handling. * Adjusted the test_reward_metric.py to modify the way model loading arguments are passed, enhancing test flexibility.
1 parent 495d804 commit 78912da

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

src/pruna/evaluation/metrics/metric_reward.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -343,9 +343,11 @@ class HPSv2Metric(BaseModelRewardMetric):
343343
metric_name: str = HPSv2_REWARD
344344

345345
def _load(self, **kwargs: Any) -> None:
346+
from functools import partial
347+
346348
import hpsv2
347349

348-
self.model = hpsv2
350+
self.model = partial(hpsv2.score, hps_version=kwargs.get("hps_version", "v2.1"))
349351

350352
def _score_image(self, prompt: str, image: PIL.Image.Image) -> float:
351353
"""
@@ -363,7 +365,7 @@ def _score_image(self, prompt: str, image: PIL.Image.Image) -> float:
363365
float
364366
The score of the image.
365367
"""
366-
score = self.model.score(imgs_path=image, prompt=prompt, hps_version="v2.1")
368+
score = self.model(imgs_path=image, prompt=prompt)
367369
# Handle case where score might be a list or array
368370
if isinstance(score, (list, tuple)):
369371
return float(score[0])

tests/evaluation/test_reward_metric.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
(ImageRewardMetric, IMAGE_REWARD, {}),
3232
(HPSMetric, HPS_REWARD, {}),
3333
(HPSv2Metric, HPSv2_REWARD, {"hps_version": "v2.1"}),
34-
(VQAMetric, VQA_REWARD, {"model": "clip-flant5-xl"}),
34+
# (VQAMetric, VQA_REWARD, {"model": "clip-flant5-xl"}), # custom very large architecture
3535
]
3636

3737

@@ -65,7 +65,7 @@ def test_metric_registration(self, metric_cls, metric_name, model_load_kwargs):
6565
metric_name,
6666
device=set_to_best_available_device(device=None),
6767
call_type="y",
68-
**model_load_kwargs,
68+
model_load_kwargs=model_load_kwargs,
6969
)
7070
assert isinstance(metric, metric_cls)
7171

0 commit comments

Comments
 (0)