Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions src/eva/core/models/transforms/extract_cls_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,22 @@ class ExtractCLSFeatures:
"""Extracts the CLS token from a ViT model output."""

def __init__(
self, cls_index: int = 0, num_register_tokens: int = 0, include_patch_tokens: bool = False
self,
cls_index: int = 0,
num_register_tokens: int = 0,
concat_mean_patch_tokens: bool = False,
) -> None:
"""Initializes the transformation.

Args:
cls_index: The index of the CLS token in the output tensor.
num_register_tokens: The number of register tokens in the model output.
include_patch_tokens: Whether to concat the mean aggregated patch tokens with
concat_mean_patch_tokens: Whether to concat the mean aggregated patch tokens with
the cls token.
"""
self._cls_index = cls_index
self._num_register_tokens = num_register_tokens
self._include_patch_tokens = include_patch_tokens
self._concat_mean_patch_tokens = concat_mean_patch_tokens

def __call__(
self, tensor: torch.Tensor | modeling_outputs.BaseModelOutputWithPooling
Expand All @@ -34,7 +37,7 @@ def __call__(
tensor = tensor.last_hidden_state

cls_token = tensor[:, self._cls_index, :]
if self._include_patch_tokens:
if self._concat_mean_patch_tokens:
patch_tokens = tensor[:, 1 + self._num_register_tokens :, :]
return torch.cat([cls_token, patch_tokens.mean(1)], dim=-1)

Expand Down
16 changes: 12 additions & 4 deletions src/eva/core/models/wrappers/from_torchhub.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ def __init__(
self,
model_name: str,
repo_or_dir: str,
pretrained: bool = True,
checkpoint_path: str = "",
out_indices: int | Tuple[int, ...] | None = None,
norm: bool = False,
trust_repo: bool = True,
forward_features: bool = False,
model_kwargs: Dict[str, Any] | None = None,
tensor_transforms: Callable | None = None,
) -> None:
Expand All @@ -30,14 +30,15 @@ def __init__(
Args:
model_name: Name of model to instantiate.
repo_or_dir: The torch.hub repository or local directory to load the model from.
pretrained: If set to `True`, load pretrained ImageNet-1k weights.
checkpoint_path: Path of checkpoint to load.
out_indices: Returns last n blocks if `int`, all if `None`, select
matching indices if sequence.
norm: Whether to apply norm layer to all intermediate features. Only
used when `out_indices` is not `None`.
trust_repo: If set to `False`, a prompt will ask the user whether the
repo should be trusted.
forward_features: Use `forward_features` method instead of `forward` in
forward pass.
model_kwargs: Extra model arguments.
tensor_transforms: The transforms to apply to the output tensor
produced by the model.
Expand All @@ -46,11 +47,11 @@ def __init__(

self._model_name = model_name
self._repo_or_dir = repo_or_dir
self._pretrained = pretrained
self._checkpoint_path = checkpoint_path
self._out_indices = out_indices
self._norm = norm
self._trust_repo = trust_repo
self._forward_features = forward_features
self._model_kwargs = model_kwargs or {}

self.load_model()
Expand All @@ -62,7 +63,6 @@ def load_model(self) -> None:
repo_or_dir=self._repo_or_dir,
model=self._model_name,
trust_repo=self._trust_repo,
pretrained=self._pretrained,
**self._model_kwargs,
) # type: ignore

Expand All @@ -79,6 +79,8 @@ def model_forward(self, tensor: torch.Tensor) -> torch.Tensor | List[torch.Tenso
"Only models with `get_intermediate_layers` are supported "
"when using `out_indices`."
)
if self._forward_features:
raise ValueError("`forward_features` is not supported when `out_indices` is enabled.")

return list(
self._model.get_intermediate_layers(
Expand All @@ -89,5 +91,11 @@ def model_forward(self, tensor: torch.Tensor) -> torch.Tensor | List[torch.Tenso
norm=self._norm,
)
)
elif self._forward_features:
if not hasattr(self._model, "forward_features"):
raise ValueError(
f"`forward_features` method not available for model {type(self._model)}"
)
return self._model.forward_features(tensor)

return self._model(tensor)
14 changes: 9 additions & 5 deletions src/eva/vision/models/networks/backbones/pathology/bioptimus.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@

from typing import Tuple

import timm
from torch import nn

from eva.vision.models import wrappers
from eva.vision.models.networks.backbones.registry import register_model


@register_model("pathology/bioptimus_h_optimus_0")
def bioptimus_h_optimus_0(
dynamic_img_size: bool = True,
out_indices: int | Tuple[int, ...] | None = None,
concat_mean_patch_tokens: bool = False,
) -> nn.Module:
"""Initializes the h_optimus_0 pathology FM by Bioptimus.

Expand All @@ -20,15 +21,18 @@ def bioptimus_h_optimus_0(
to be interpolated at `forward()` time when image grid changes
from original.
out_indices: Whether and which multi-level patch embeddings to return.
concat_mean_patch_tokens: Concat the CLS token with mean aggregated patch tokens.

Returns:
The model instance.
"""
return timm.create_model(
return wrappers.TimmModel(
model_name="hf-hub:bioptimus/H-optimus-0",
pretrained=True,
init_values=1e-5,
dynamic_img_size=dynamic_img_size,
out_indices=out_indices,
features_only=out_indices is not None,
model_kwargs={
"dynamic_img_size": dynamic_img_size,
"init_values": 1e-5,
},
concat_mean_patch_tokens=concat_mean_patch_tokens,
)
12 changes: 8 additions & 4 deletions src/eva/vision/models/networks/backbones/pathology/gigapath.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@

from typing import Tuple

import timm
from torch import nn

from eva.vision.models import wrappers
from eva.vision.models.networks.backbones.registry import register_model


@register_model("pathology/prov_gigapath")
def prov_gigapath(
dynamic_img_size: bool = True,
out_indices: int | Tuple[int, ...] | None = None,
concat_mean_patch_tokens: bool = False,
) -> nn.Module:
"""Initializes the Prov-GigaPath pathology FM.

Expand All @@ -20,14 +21,17 @@ def prov_gigapath(
to be interpolated at `forward()` time when image grid changes
from original.
out_indices: Whether and which multi-level patch embeddings to return.
concat_mean_patch_tokens: Concat the CLS token with mean aggregated patch tokens.

Returns:
The model instance.
"""
return timm.create_model(
return wrappers.TimmModel(
model_name="hf_hub:prov-gigapath/prov-gigapath",
pretrained=True,
dynamic_img_size=dynamic_img_size,
out_indices=out_indices,
features_only=out_indices is not None,
model_kwargs={
"dynamic_img_size": dynamic_img_size,
},
concat_mean_patch_tokens=concat_mean_patch_tokens,
)
24 changes: 20 additions & 4 deletions src/eva/vision/models/networks/backbones/pathology/histai.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@


@register_model("pathology/histai_hibou_b")
def histai_hibou_b(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
def histai_hibou_b(
out_indices: int | Tuple[int, ...] | None = None, concat_mean_patch_tokens: bool = False
) -> nn.Module:
"""Initializes the hibou-B pathology FM by hist.ai (https://huggingface.co/histai/hibou-B).

Uses a customized implementation of the DINOv2 architecture from the transformers
Expand All @@ -18,20 +20,28 @@ def histai_hibou_b(out_indices: int | Tuple[int, ...] | None = None) -> nn.Modul
Args:
out_indices: Whether and which multi-level patch embeddings to return.
Currently only out_indices=1 is supported.
concat_mean_patch_tokens: Concat the CLS token with mean aggregated patch tokens.

Returns:
The model instance.
"""
transform_args = {"num_register_tokens": 4}
return _utils.load_hugingface_model(
model_name="histai/hibou-B",
out_indices=out_indices,
model_kwargs={"trust_remote_code": True},
transform_args={"num_register_tokens": 4} if out_indices is not None else None,
transform_args=(
transform_args | {"concat_mean_patch_tokens": concat_mean_patch_tokens}
if concat_mean_patch_tokens
else transform_args
),
)


@register_model("pathology/histai_hibou_l")
def histai_hibou_l(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
def histai_hibou_l(
out_indices: int | Tuple[int, ...] | None = None, concat_mean_patch_tokens: bool = False
) -> nn.Module:
"""Initializes the hibou-L pathology FM by hist.ai (https://huggingface.co/histai/hibou-L).

Uses a customized implementation of the DINOv2 architecture from the transformers
Expand All @@ -40,13 +50,19 @@ def histai_hibou_l(out_indices: int | Tuple[int, ...] | None = None) -> nn.Modul
Args:
out_indices: Whether and which multi-level patch embeddings to return.
Currently only out_indices=1 is supported.
concat_mean_patch_tokens: Concat the CLS token with mean aggregated patch tokens.

Returns:
The model instance.
"""
transform_args = {"num_register_tokens": 4}
return _utils.load_hugingface_model(
model_name="histai/hibou-L",
out_indices=out_indices,
model_kwargs={"trust_remote_code": True},
transform_args={"num_register_tokens": 4} if out_indices is not None else None,
transform_args=(
transform_args | {"concat_mean_patch_tokens": concat_mean_patch_tokens}
if concat_mean_patch_tokens
else transform_args
),
)
Loading
Loading