137 changes: 137 additions & 0 deletions configs/vision/pathology/offline/regression/tiger_til_score.yaml
Review comment (Collaborator):
Please add this config to tests/eva/vision/test_vision_cli.py (at least to test_configuration_initialization, ideally also to test_predict_fit_from_configuration), so we can test for instantiation errors.
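A hedged sketch of the requested addition; the parametrization style below is an assumption about how test_vision_cli.py is organized, not its actual contents:

    import pytest

    # Path of the config added in this PR.
    TIGER_TIL_CONFIG = "configs/vision/pathology/offline/regression/tiger_til_score.yaml"


    @pytest.mark.parametrize("configuration_file", [TIGER_TIL_CONFIG])
    def test_configuration_initialization(configuration_file: str) -> None:
        """Parsing and instantiating the YAML should surface bad class paths or missing init args."""
        ...  # body as in the existing test; only the parametrization gains the new path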

@@ -0,0 +1,137 @@
---
trainer:
class_path: eva.Trainer
init_args:
n_runs: &N_RUNS ${oc.env:N_RUNS, 20}
default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/offline/tiger_til}
max_epochs: &MAX_EPOCHS ${oc.env:MAX_EPOCHS, 100}
checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best}
callbacks:
- class_path: eva.callbacks.ConfigurationLogger
- class_path: lightning.pytorch.callbacks.TQDMProgressBar
init_args:
refresh_rate: ${oc.env:TQDM_REFRESH_RATE, 1}
- class_path: lightning.pytorch.callbacks.LearningRateMonitor
init_args:
logging_interval: epoch
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
init_args:
filename: best
save_last: ${oc.env:SAVE_LAST, false}
save_top_k: 1
monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/MAE}
mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max}
Review comment (Collaborator):
This monitors val/MAE with mode: max, so checkpoints and early stopping will keep the worst models rather than the best; MAE should be minimized, i.e. mode: min.
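A minimal sketch of the suggested change, assuming the metric should indeed be minimized (only the mode default differs):

    monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/MAE}
    # lower MAE is better, so default the monitor mode to "min"
    mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, min}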

- class_path: lightning.pytorch.callbacks.EarlyStopping
init_args:
min_delta: 0
patience: ${oc.env:PATIENCE, 20}
monitor: *MONITOR_METRIC
mode: *MONITOR_METRIC_MODE
- class_path: eva.callbacks.ClassificationEmbeddingsWriter
init_args:
output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings/${oc.env:MODEL_NAME, dino_vits16}/tiger_til}
dataloader_idx_map:
0: train
1: val
2: test
metadata_keys: ["wsi_id"]
backbone:
class_path: eva.vision.models.ModelFromRegistry
init_args:
model_name: ${oc.env:MODEL_NAME, universal/vit_small_patch16_224_dino}
model_extra_kwargs: ${oc.env:MODEL_EXTRA_KWARGS, null}
overwrite: false
logger:
- class_path: lightning.pytorch.loggers.TensorBoardLogger
init_args:
save_dir: *OUTPUT_ROOT
name: ""
model:
class_path: eva.HeadModule
init_args:
head:
class_path: eva.vision.models.networks.ABMIL
init_args:
input_size: ${oc.env:IN_FEATURES, 384}
Review comment (Collaborator):
Do we need to set projected_input_size here? (ABMIL doesn't have a default value for it, so instantiation should fail without setting it.)
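One way to satisfy the signature (a sketch; the env var name and the 128 default are illustrative assumptions, not values from this PR):

    head:
      class_path: eva.vision.models.networks.ABMIL
      init_args:
        input_size: ${oc.env:IN_FEATURES, 384}
        # projected_input_size has no default in ABMIL, so it must be set;
        # PROJECTED_IN_FEATURES and 128 are placeholder assumptions.
        projected_input_size: ${oc.env:PROJECTED_IN_FEATURES, 128}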

criterion: torch.nn.MSELoss
optimizer:
class_path: torch.optim.AdamW
init_args:
lr: ${oc.env:LR_VALUE, 0.001}
betas: [0.9, 0.999]
metrics:
common:
- class_path: eva.core.metrics.AverageLoss
- class_path: eva.core.metrics.RegressionMetrics
init_args:
prefix: null
postfix: null
data:
class_path: eva.DataModule
init_args:
datasets:
train:
class_path: eva.datasets.MultiEmbeddingsRegressionDataset
init_args: &DATASET_ARGS
root: *DATASET_EMBEDDINGS_ROOT
manifest_file: manifest.csv
split: train
embeddings_transforms:
class_path: eva.core.data.transforms.Pad2DTensor
init_args:
pad_size: &N_PATCHES ${oc.env:N_PATCHES, 200}
target_transforms:
class_path: eva.vision.data.transforms.common.Squeeze
init_args:
dim: -1
val:
class_path: eva.datasets.MultiEmbeddingsRegressionDataset
init_args:
<<: *DATASET_ARGS
split: val
test:
class_path: eva.datasets.MultiEmbeddingsRegressionDataset
init_args:
<<: *DATASET_ARGS
split: test
predict:
- class_path: eva.vision.datasets.TIGERTILScore
init_args: &PREDICT_DATASET_ARGS
root: ${oc.env:DATA_ROOT, ./data/training/wsitils}
sampler:
class_path: eva.vision.data.wsi.patching.samplers.ForegroundGridSampler
init_args:
max_samples: *N_PATCHES
width: 224
height: 224
target_mpp: 0.5
Review comment (Collaborator):
This init arg was removed from the TIGER dataset class.
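The comment is anchored on the target_mpp line, so presumably that is the argument to drop; a sketch, with the surrounding lines unchanged:

    width: 224
    height: 224
    # target_mpp: 0.5  # removed: no longer an init arg of TIGERTILScore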

split: train
coords_path: ${data.init_args.datasets.train.init_args.root}/coords_${.split}.csv
image_transforms:
class_path: eva.vision.data.transforms.common.ResizeAndCrop
init_args:
size: ${oc.env:RESIZE_DIM, 224}
mean: ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]}
std: ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]}
- class_path: eva.vision.datasets.TIGERTILScore
init_args:
<<: *PREDICT_DATASET_ARGS
split: val
- class_path: eva.vision.datasets.TIGERTILScore
init_args:
<<: *PREDICT_DATASET_ARGS
split: test
dataloaders:
train:
batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 32}
num_workers: &N_DATA_WORKERS ${oc.env:N_DATA_WORKERS, 4}
shuffle: true
val:
batch_size: *BATCH_SIZE
num_workers: *N_DATA_WORKERS
test:
batch_size: *BATCH_SIZE
num_workers: *N_DATA_WORKERS
predict:
batch_size: &PREDICT_BATCH_SIZE ${oc.env:PREDICT_BATCH_SIZE, 64}
num_workers: *N_DATA_WORKERS
6 changes: 6 additions & 0 deletions src/eva/core/data/datasets/__init__.py
@@ -6,13 +6,19 @@
MultiEmbeddingsClassificationDataset,
)
from eva.core.data.datasets.dataset import TorchDataset
from eva.core.data.datasets.regression import (
EmbeddingsRegressionDataset,
MultiEmbeddingsRegressionDataset,
)
from eva.core.data.datasets.typings import DataSample

__all__ = [
"Dataset",
"MapDataset",
"EmbeddingsClassificationDataset",
"MultiEmbeddingsClassificationDataset",
"EmbeddingsRegressionDataset",
"MultiEmbeddingsRegressionDataset",
"TorchDataset",
"DataSample",
]
108 changes: 7 additions & 101 deletions src/eva/core/data/datasets/classification/multi_embeddings.py
@@ -1,110 +1,16 @@
"""Dataset class for where a sample corresponds to multiple embeddings."""

import os
from typing import Callable, Dict, List, Literal
"""Dataset class for where a classification task sample corresponds to multiple embeddings."""

import numpy as np
import torch
from typing_extensions import override

from eva.core.data.datasets import embeddings as embeddings_base
from eva.core.data.datasets.multi_embeddings import MultiEmbeddingsDataset


class MultiEmbeddingsClassificationDataset(embeddings_base.EmbeddingsDataset[torch.Tensor]):
class MultiEmbeddingsClassificationDataset(MultiEmbeddingsDataset):
"""Dataset class for where a sample corresponds to multiple embeddings.

Example use case: Slide level dataset where each slide has multiple patch embeddings.
Specialised for classification data with an int target type.
"""

def __init__(
self,
root: str,
manifest_file: str,
split: Literal["train", "val", "test"],
column_mapping: Dict[str, str] = embeddings_base.default_column_mapping,
embeddings_transforms: Callable | None = None,
target_transforms: Callable | None = None,
):
"""Initialize dataset.

Expects a manifest file listing the paths of `.pt` files containing tensor embeddings.

The manifest must have a `column_mapping["multi_id"]` column that contains the
unique identifier group of embeddings. For oncology datasets, this would be usually
the slide id. Each row in the manifest file points to a .pt file that can contain
one or multiple embeddings (either as a list or stacked tensors). There can also be
multiple rows for the same `multi_id`, in which case the embeddings from the different
.pt files corresponding to that same `multi_id` will be stacked along the first dimension.

Args:
root: Root directory of the dataset.
manifest_file: The path to the manifest file, which is relative to
the `root` argument.
split: The dataset split to use. The `split` column of the manifest
file will be splitted based on this value.
column_mapping: Defines the map between the variables and the manifest
columns. It will overwrite the `default_column_mapping` with
the provided values, so that `column_mapping` can contain only the
values which are altered or missing.
embeddings_transforms: A function/transform that transforms the embedding.
target_transforms: A function/transform that transforms the target.
"""
super().__init__(
manifest_file=manifest_file,
root=root,
split=split,
column_mapping=column_mapping,
embeddings_transforms=embeddings_transforms,
target_transforms=target_transforms,
)

self._multi_ids: List[int]

@override
def setup(self):
super().setup()
self._multi_ids = list(self._data[self._column_mapping["multi_id"]].unique())

@override
def load_embeddings(self, index: int) -> torch.Tensor:
"""Loads and stacks all embedding corresponding to the `index`'th multi_id."""
# Get all embeddings for the given index (multi_id)
multi_id = self._multi_ids[index]
embedding_paths = self._data.loc[
self._data[self._column_mapping["multi_id"]] == multi_id, self._column_mapping["path"]
].to_list()

# Load embeddings and stack them accross the first dimension
embeddings = []
for path in embedding_paths:
embedding = torch.load(os.path.join(self._root, path), map_location="cpu")
if isinstance(embedding, list):
embedding = torch.stack(embedding, dim=0)
embeddings.append(embedding.unsqueeze(0) if embedding.ndim == 1 else embedding)
embeddings = torch.cat(embeddings, dim=0)

if not embeddings.ndim == 2:
raise ValueError(f"Expected 2D tensor, got {embeddings.ndim} for {multi_id}.")

return embeddings

@override
def load_target(self, index: int) -> np.ndarray:
"""Returns the target corresponding to the `index`'th multi_id.

This method assumes that all the embeddings corresponding to the same `multi_id`
have the same target. If this is not the case, it will raise an error.
"""
multi_id = self._multi_ids[index]
targets = self._data.loc[
self._data[self._column_mapping["multi_id"]] == multi_id, self._column_mapping["target"]
]

if not targets.nunique() == 1:
raise ValueError(f"Multiple targets found for {multi_id}.")

return np.asarray(targets.iloc[0], dtype=np.int64)

@override
def __len__(self) -> int:
return len(self._multi_ids)
def __init__(self, *args, **kwargs):
"""Initialize dataset with the correct return type."""
super().__init__(*args, target_type=np.int64, **kwargs)
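For context, a usage sketch of the refactored class, using only the constructor arguments and methods visible in this diff (the root mirrors the embeddings path from the config above):

    from eva.core.data.datasets import MultiEmbeddingsClassificationDataset

    dataset = MultiEmbeddingsClassificationDataset(
        root="./data/embeddings/dino_vits16/tiger_til",
        manifest_file="manifest.csv",
        split="train",
    )
    dataset.setup()  # collects the unique multi_ids from the manifest
    embeddings = dataset.load_embeddings(0)  # 2D tensor: (n_embeddings, dim)
    target = dataset.load_target(0)          # np.int64 array, shared by the group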
114 changes: 114 additions & 0 deletions src/eva/core/data/datasets/multi_embeddings.py
@@ -0,0 +1,114 @@
"""Dataset class for where a sample corresponds to multiple embeddings."""

import os
from typing import Any, Callable, Dict, List, Literal

import numpy as np
import numpy.typing as npt
import torch
from typing_extensions import override

from eva.core.data.datasets import embeddings as embeddings_base
Review comment (Collaborator):
nit: from eva.core.data.datasets import embeddings as base (for conciseness)



class MultiEmbeddingsDataset(embeddings_base.EmbeddingsDataset[torch.Tensor]):
"""Dataset class for where a sample corresponds to multiple embeddings.

Example use case: Slide level dataset where each slide has multiple patch embeddings.
"""

def __init__(
self,
root: str,
manifest_file: str,
split: Literal["train", "val", "test"],
column_mapping: Dict[str, str] = embeddings_base.default_column_mapping,
embeddings_transforms: Callable | None = None,
target_transforms: Callable | None = None,
target_type: type[np.generic] = np.int64,
):
"""Initialize dataset.

Expects a manifest file listing the paths of `.pt` files containing tensor embeddings.

The manifest must have a `column_mapping["multi_id"]` column that contains the
unique identifier group of embeddings. For oncology datasets, this would usually be
the slide id. Each row in the manifest file points to a .pt file that can contain
one or multiple embeddings (either as a list or stacked tensors). There can also be
multiple rows for the same `multi_id`, in which case the embeddings from the different
.pt files corresponding to that same `multi_id` will be stacked along the first dimension.

Args:
root: Root directory of the dataset.
manifest_file: The path to the manifest file, which is relative to
the `root` argument.
split: The dataset split to use. The `split` column of the manifest
file will be filtered based on this value.
column_mapping: Defines the map between the variables and the manifest
columns. It will overwrite the `default_column_mapping` with
the provided values, so that `column_mapping` can contain only the
values which are altered or missing.
embeddings_transforms: A function/transform that transforms the embedding.
target_transforms: A function/transform that transforms the target.
target_type: The desired type of the target data.
"""
super().__init__(
manifest_file=manifest_file,
root=root,
split=split,
column_mapping=column_mapping,
embeddings_transforms=embeddings_transforms,
target_transforms=target_transforms,
)

self._multi_ids: List[int]
self._target_type = target_type

@override
def setup(self):
super().setup()
self._multi_ids = list(self._data[self._column_mapping["multi_id"]].unique())

@override
def load_embeddings(self, index: int) -> torch.Tensor:
"""Loads and stacks all embedding corresponding to the `index`'th multi_id."""
# Get all embeddings for the given index (multi_id)
multi_id = self._multi_ids[index]
embedding_paths = self._data.loc[
self._data[self._column_mapping["multi_id"]] == multi_id, self._column_mapping["path"]
].to_list()

# Load embeddings and stack them across the first dimension
embeddings = []
for path in embedding_paths:
embedding = torch.load(os.path.join(self._root, path), map_location="cpu")
if isinstance(embedding, list):
embedding = torch.stack(embedding, dim=0)
embeddings.append(embedding.unsqueeze(0) if embedding.ndim == 1 else embedding)
embeddings = torch.cat(embeddings, dim=0)

if embeddings.ndim != 2:
raise ValueError(f"Expected 2D tensor, got {embeddings.ndim} for {multi_id}.")

return embeddings

@override
def load_target(self, index: int) -> npt.NDArray[Any]:
"""Returns the target corresponding to the `index`'th multi_id.

This method assumes that all the embeddings corresponding to the same `multi_id`
have the same target. If this is not the case, it will raise an error.
"""
multi_id = self._multi_ids[index]
targets = self._data.loc[
self._data[self._column_mapping["multi_id"]] == multi_id, self._column_mapping["target"]
]

if targets.nunique() != 1:
raise ValueError(f"Multiple targets found for {multi_id}.")

return np.asarray(targets.iloc[0], dtype=self._target_type)

@override
def __len__(self) -> int:
return len(self._multi_ids)
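The regression counterpart imported in __init__.py below is not shown in this diff; presumably it mirrors the classification subclass with a float target type. A sketch under that assumption:

    # Hypothetical sketch of src/eva/core/data/datasets/regression/multi_embeddings.py
    # (not part of the excerpt above): mirror the classification subclass,
    # but with a float target type.
    import numpy as np

    from eva.core.data.datasets.multi_embeddings import MultiEmbeddingsDataset


    class MultiEmbeddingsRegressionDataset(MultiEmbeddingsDataset):
        """Dataset where a regression-task sample corresponds to multiple embeddings."""

        def __init__(self, *args, **kwargs):
            """Initialize dataset with a float target type."""
            super().__init__(*args, target_type=np.float32, **kwargs)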
6 changes: 6 additions & 0 deletions src/eva/core/data/datasets/regression/__init__.py
@@ -0,0 +1,6 @@
"""Embedding regression datasets API."""

from eva.core.data.datasets.regression.embeddings import EmbeddingsRegressionDataset
from eva.core.data.datasets.regression.multi_embeddings import MultiEmbeddingsRegressionDataset

__all__ = ["EmbeddingsRegressionDataset", "MultiEmbeddingsRegressionDataset"]