8 changes: 3 additions & 5 deletions configs/vision/radiology/online/segmentation/btcv.yaml
@@ -76,21 +76,19 @@ model:
common:
- class_path: eva.metrics.AverageLoss
evaluation:
- class_path: torchmetrics.segmentation.DiceScore
- class_path: eva.vision.metrics.defaults.MulticlassSegmentationMetricsV2
init_args:
num_classes: *NUM_CLASSES
include_background: false
average: macro
input_format: one-hot
- class_path: torchmetrics.ClasswiseWrapper
init_args:
metric:
class_path: eva.vision.metrics.MonaiDiceScore
class_path: torchmetrics.segmentation.DiceScore
init_args:
include_background: true
num_classes: *NUM_CLASSES
average: none
input_format: one-hot
reduction: none
prefix: DiceScore_
labels:
- "0_background"
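The same two-metric swap is applied to each of the segmentation configs below. For orientation, here is a rough Python equivalent of the updated `evaluation` list (a sketch only: eva resolves `class_path`/`init_args` entries jsonargparse-style, and `NUM_CLASSES` plus the class labels are hypothetical stand-ins for the values behind the file's `*NUM_CLASSES` anchor and its label list):

```python
# Rough Python equivalent of the updated `evaluation` metrics list.
from torchmetrics import ClasswiseWrapper
from torchmetrics.segmentation import DiceScore

from eva.vision.metrics.defaults import MulticlassSegmentationMetricsV2

NUM_CLASSES = 14  # hypothetical; stands in for the *NUM_CLASSES YAML anchor

evaluation_metrics = [
    # Macro dice, background excluded (aggregated samplewise and globally).
    MulticlassSegmentationMetricsV2(
        num_classes=NUM_CLASSES,
        include_background=False,
        input_format="one-hot",
    ),
    # Per-class dice, reported under keys like "DiceScore_0_background".
    ClasswiseWrapper(
        DiceScore(
            num_classes=NUM_CLASSES,
            include_background=True,
            average="none",
            input_format="one-hot",
        ),
        prefix="DiceScore_",
        labels=[f"{i}_label" for i in range(NUM_CLASSES)],  # placeholder labels
    ),
]
```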
8 changes: 3 additions & 5 deletions configs/vision/radiology/online/segmentation/lits17.yaml
@@ -76,21 +76,19 @@ model:
common:
- class_path: eva.metrics.AverageLoss
evaluation:
- class_path: torchmetrics.segmentation.DiceScore
- class_path: eva.vision.metrics.defaults.MulticlassSegmentationMetricsV2
init_args:
num_classes: *NUM_CLASSES
include_background: false
average: macro
input_format: one-hot
- class_path: torchmetrics.ClasswiseWrapper
init_args:
metric:
class_path: eva.vision.metrics.MonaiDiceScore
class_path: torchmetrics.segmentation.DiceScore
init_args:
include_background: true
num_classes: *NUM_CLASSES
average: none
input_format: one-hot
reduction: none
prefix: DiceScore_
labels:
- "0_background"
8 changes: 3 additions & 5 deletions configs/vision/radiology/online/segmentation/lits17_2d.yaml
@@ -81,21 +81,19 @@ model:
common:
- class_path: eva.metrics.AverageLoss
evaluation:
- class_path: torchmetrics.segmentation.DiceScore
- class_path: eva.vision.metrics.defaults.MulticlassSegmentationMetricsV2
init_args:
num_classes: *NUM_CLASSES
include_background: false
average: macro
input_format: one-hot
- class_path: torchmetrics.ClasswiseWrapper
init_args:
metric:
class_path: eva.vision.metrics.MonaiDiceScore
class_path: torchmetrics.segmentation.DiceScore
init_args:
include_background: true
num_classes: *NUM_CLASSES
average: none
input_format: one-hot
reduction: none
prefix: DiceScore_
labels:
- "0_background"
@@ -76,21 +76,19 @@ model:
common:
- class_path: eva.metrics.AverageLoss
evaluation:
- class_path: torchmetrics.segmentation.DiceScore
- class_path: eva.vision.metrics.defaults.MulticlassSegmentationMetricsV2
init_args:
num_classes: *NUM_CLASSES
include_background: false
average: macro
input_format: one-hot
- class_path: torchmetrics.ClasswiseWrapper
init_args:
metric:
class_path: eva.vision.metrics.MonaiDiceScore
class_path: torchmetrics.segmentation.DiceScore
init_args:
include_background: true
num_classes: *NUM_CLASSES
average: none
input_format: one-hot
reduction: none
prefix: DiceScore_
labels:
- "0_background"
6 changes: 5 additions & 1 deletion src/eva/vision/metrics/defaults/__init__.py
@@ -1,7 +1,11 @@
"""Default metric collections API."""

from eva.vision.metrics.defaults.segmentation import MulticlassSegmentationMetrics
from eva.vision.metrics.defaults.segmentation import (
MulticlassSegmentationMetrics,
MulticlassSegmentationMetricsV2,
)

__all__ = [
"MulticlassSegmentationMetrics",
"MulticlassSegmentationMetricsV2",
]
7 changes: 5 additions & 2 deletions src/eva/vision/metrics/defaults/segmentation/__init__.py
@@ -1,5 +1,8 @@
"""Default segmentation metric collections API."""

from eva.vision.metrics.defaults.segmentation.multiclass import MulticlassSegmentationMetrics
from eva.vision.metrics.defaults.segmentation.multiclass import (
MulticlassSegmentationMetrics,
MulticlassSegmentationMetricsV2,
)

__all__ = ["MulticlassSegmentationMetrics"]
__all__ = ["MulticlassSegmentationMetrics", "MulticlassSegmentationMetricsV2"]
56 changes: 55 additions & 1 deletion src/eva/vision/metrics/defaults/segmentation/multiclass.py
@@ -1,11 +1,14 @@
"""Default metric collection for multiclass semantic segmentation tasks."""

from typing import Literal

from eva.core.metrics import structs
from eva.core.utils import requirements
from eva.vision.metrics import segmentation


class MulticlassSegmentationMetrics(structs.MetricCollection):
"""Default metrics for multi-class semantic segmentation tasks."""
"""Metrics for multi-class semantic segmentation tasks."""

def __init__(
self,
@@ -66,3 +69,54 @@ def __init__(
prefix=prefix,
postfix=postfix,
)


class MulticlassSegmentationMetricsV2(structs.MetricCollection):
"""Metrics for multi-class semantic segmentation tasks.

In torchmetrics 1.8.0, the DiceScore implementation was
improved and should now provide enough signal on its own, so
the MONAI implementation and the IoU metric are dropped for
simplicity and computational efficiency.
"""

def __init__(
self,
num_classes: int,
include_background: bool = False,
prefix: str | None = None,
postfix: str | None = None,
input_format: Literal["one-hot", "index"] = "one-hot",
) -> None:
"""Initializes the multi-class semantic segmentation metrics.

Args:
num_classes: Integer specifying the number of classes.
include_background: Whether to include the background class in the metrics computation.
prefix: A string to add before the keys in the output dictionary.
postfix: A string to add after the keys in the output dictionary.
input_format: Input tensor format. Options are `"one-hot"` for one-hot encoded tensors,
`"index"` for index tensors.
"""
requirements.check_dependencies(requirements={"torchmetrics": "1.8.0"})
super().__init__(
metrics={
"DiceScore (macro)": segmentation.DiceScore(
num_classes=num_classes,
include_background=include_background,
average="macro",
aggregation_level="samplewise",
input_format=input_format,
),
"DiceScore (macro/global)": segmentation.DiceScore(
num_classes=num_classes,
include_background=include_background,
average="macro",
aggregation_level="global",
input_format=input_format,
),
},
prefix=prefix,
postfix=postfix,
)
self.num_classes = num_classes
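Since `MulticlassSegmentationMetricsV2` is just a two-entry `MetricCollection`, its parts can be sanity-checked directly against torchmetrics. A minimal sketch, assuming `torchmetrics>=1.8.0` (shapes and class count are illustrative):

```python
# Mirrors what MulticlassSegmentationMetricsV2 wires up: two macro DiceScore
# instances that differ only in aggregation_level (samplewise vs. global).
import torch
import torch.nn.functional as F
from torchmetrics.segmentation import DiceScore

num_classes = 3
# Build valid one-hot (N, C, H, W) tensors from random index maps.
preds = F.one_hot(torch.randint(0, num_classes, (4, 16, 16)), num_classes).movedim(-1, 1)
target = F.one_hot(torch.randint(0, num_classes, (4, 16, 16)), num_classes).movedim(-1, 1)

for level in ("samplewise", "global"):
    metric = DiceScore(
        num_classes=num_classes,
        include_background=False,
        average="macro",
        aggregation_level=level,
        input_format="one-hot",
    )
    metric.update(preds, target)
    print(level, metric.compute())
```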
39 changes: 19 additions & 20 deletions src/eva/vision/metrics/segmentation/dice.py
@@ -1,27 +1,28 @@
"""Generalized Dice Score metric for semantic segmentation."""
"""Dice Score metric for semantic segmentation."""

from typing import Any, Literal

import torch
from torchmetrics import segmentation
from torchmetrics.functional.segmentation.dice import _dice_score_update
from typing_extensions import override

from eva.vision.metrics.segmentation import _utils


class DiceScore(segmentation.DiceScore):
"""Defines the Generalized Dice Score.
"""Dice Score metric for semantic segmentation tasks.

It expands the `torchmetrics` class by including an `ignore_index`
functionality and converting tensors to one-hot format.
This implementation extends the `torchmetrics` class with
`ignore_index` support and on-the-fly conversion of index
tensors to one-hot format.
"""

def __init__(
self,
num_classes: int,
include_background: bool = True,
average: Literal["micro", "macro", "weighted", "none"] | None = "micro",
input_format: Literal["one-hot", "index", "auto"] = "auto",
ignore_index: int | None = None,
**kwargs: Any,
) -> None:
@@ -32,6 +33,9 @@ def __init__(
include_background: Whether to include the background class in the computation
average: The method to average the dice score across the different classes. Options are
`"micro"`, `"macro"`, `"weighted"`, `"none"` or `None`.
input_format: Input tensor format. Options are `"one-hot"` for one-hot encoded tensors,
`"index"` for index tensors, or `"auto"` to automatically convert the format
to one-hot.
ignore_index: Integer specifying a target class to ignore. If given, this class
index does not contribute to the returned score, regardless of reduction method.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
@@ -42,28 +46,23 @@ def __init__(
+ (ignore_index == 0 and not include_background),
include_background=include_background,
average=average,
input_format="one-hot",
input_format=input_format if input_format == "index" else "one-hot",
**kwargs,
)
self.orig_num_classes = num_classes
self.ignore_index = ignore_index
self.input_format = input_format

@override
def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
preds = _utils.index_to_one_hot(preds, num_classes=self.orig_num_classes)
target = _utils.index_to_one_hot(target, num_classes=self.orig_num_classes)
if self.input_format == "auto":
preds = _utils.index_to_one_hot(preds, num_classes=self.orig_num_classes)
target = _utils.index_to_one_hot(target, num_classes=self.orig_num_classes)
if self.ignore_index is not None:
if self.input_format == "index":
raise ValueError(
"When `ignore_index` is set, `input_format` must be 'one-hot' or 'auto'."
)
preds, target = _utils.apply_ignore_index(preds, target, self.ignore_index)

# TODO: Replace _update by super.update() once the following issue is fixed:
# https://github.com/Lightning-AI/torchmetrics/issues/2847
self._update(preds.long(), target.long())
# super().update(preds=preds.long(), target=target.long())

def _update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
numerator, denominator, support = _dice_score_update(
preds, target, self.num_classes, self.include_background, self.input_format # type: ignore
)
self.numerator.append(numerator)
self.denominator.append(denominator)
self.support.append(support)
super().update(preds.long(), target.long())
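A minimal usage sketch of the patched class, assuming the eva package from this branch (class count and tensor shapes are illustrative): with `input_format="auto"`, plain index masks are accepted and one-hot encoded on the fly, while `ignore_index` keeps a class from contributing to the score.

```python
import torch

from eva.vision.metrics import segmentation

metric = segmentation.DiceScore(
    num_classes=4,
    include_background=False,
    average="macro",
    input_format="auto",  # index tensors get one-hot encoded internally
    ignore_index=0,       # this class never contributes to the score
)
preds = torch.randint(0, 4, (2, 32, 32))   # (N, H, W) index-format predictions
target = torch.randint(0, 4, (2, 32, 32))  # (N, H, W) index-format targets
metric.update(preds, target)
print(metric.compute())
```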
30 changes: 25 additions & 5 deletions tests/eva/core/metrics/core/test_metric_module.py
@@ -3,19 +3,39 @@
from typing import List

import pytest
import torchmetrics
import torchmetrics.segmentation

from eva.core.metrics import structs

NUM_CLASSES = 3


@pytest.mark.parametrize(
"schema, expected",
[
(structs.MetricsSchema(train=torchmetrics.Dice()), [1, 0, 0]),
(structs.MetricsSchema(evaluation=torchmetrics.Dice()), [0, 1, 1]),
(structs.MetricsSchema(common=torchmetrics.Dice()), [1, 1, 1]),
(
structs.MetricsSchema(train=torchmetrics.Dice(), evaluation=torchmetrics.Dice()),
structs.MetricsSchema(
train=torchmetrics.segmentation.DiceScore(num_classes=NUM_CLASSES)
),
[1, 0, 0],
),
(
structs.MetricsSchema(
evaluation=torchmetrics.segmentation.DiceScore(num_classes=NUM_CLASSES)
),
[0, 1, 1],
),
(
structs.MetricsSchema(
common=torchmetrics.segmentation.DiceScore(num_classes=NUM_CLASSES)
),
[1, 1, 1],
),
(
structs.MetricsSchema(
train=torchmetrics.segmentation.DiceScore(num_classes=NUM_CLASSES),
evaluation=torchmetrics.segmentation.DiceScore(num_classes=NUM_CLASSES),
),
[1, 1, 1],
),
],
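The test migration is mechanical: the legacy `torchmetrics.Dice` is replaced by `torchmetrics.segmentation.DiceScore`, which takes `num_classes` up front. A minimal sketch of the new call, assuming `input_format="index"` so plain label maps can be passed (the value mirrors the test's `NUM_CLASSES` constant):

```python
import torch
import torchmetrics.segmentation

NUM_CLASSES = 3  # matches the test constant

# num_classes is now passed at construction; input_format="index" lets the
# metric consume plain (N, H, W) label maps instead of one-hot tensors.
metric = torchmetrics.segmentation.DiceScore(num_classes=NUM_CLASSES, input_format="index")
metric.update(torch.randint(0, NUM_CLASSES, (2, 8, 8)), torch.randint(0, NUM_CLASSES, (2, 8, 8)))
print(metric.compute())
```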
13 changes: 7 additions & 6 deletions tests/eva/core/metrics/core/test_schemas.py
@@ -2,6 +2,7 @@

import pytest
import torchmetrics
import torchmetrics.segmentation

from eva.core.metrics import structs
from eva.core.metrics.structs.typings import MetricModuleType
@@ -33,23 +34,23 @@
),
(
torchmetrics.Accuracy("binary"),
torchmetrics.Dice(),
torchmetrics.segmentation.DiceScore(num_classes=2),
None,
"[BinaryAccuracy(), Dice()]",
"[BinaryAccuracy(), DiceScore()]",
"BinaryAccuracy()",
),
(
torchmetrics.Accuracy("binary"),
None,
torchmetrics.Dice(),
torchmetrics.segmentation.DiceScore(num_classes=2),
"BinaryAccuracy()",
"[BinaryAccuracy(), Dice()]",
"[BinaryAccuracy(), DiceScore()]",
),
(
torchmetrics.Accuracy("binary"),
torchmetrics.Dice(),
torchmetrics.segmentation.DiceScore(num_classes=2),
torchmetrics.AUROC("binary"),
"[BinaryAccuracy(), Dice()]",
"[BinaryAccuracy(), DiceScore()]",
"[BinaryAccuracy(), BinaryAUROC()]",
),
],