Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 27 additions & 20 deletions tests/ignite/metrics/test_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@


class DummyLoss1(Loss):
def __init__(self, loss_fn, true_output, output_transform=lambda x: x):
super(DummyLoss1, self).__init__(loss_fn, output_transform=output_transform)
def __init__(self, loss_fn, true_output, output_transform=lambda x: x, device="cpu"):
super().__init__(loss_fn, output_transform=output_transform, device=device)
print(true_output)
self.true_output = true_output

Expand All @@ -30,23 +30,23 @@ def update(self, output):
assert output == self.true_output


def test_output_as_mapping_without_criterion_kwargs():
def test_output_as_mapping_without_criterion_kwargs(available_device):
y_pred = torch.tensor([[2.0], [-2.0]])
y = torch.zeros(2)
criterion_kwargs = {}

loss_metric = DummyLoss1(nll_loss, true_output=(y_pred, y, criterion_kwargs))
loss_metric = DummyLoss1(nll_loss, true_output=(y_pred, y, criterion_kwargs), device=available_device)
state = State(output=({"y_pred": y_pred, "y": y, "criterion_kwargs": {}}))
engine = MagicMock(state=state)
loss_metric.iteration_completed(engine)


def test_output_as_mapping_with_criterion_kwargs():
def test_output_as_mapping_with_criterion_kwargs(available_device):
y_pred = torch.tensor([[2.0], [-2.0]])
y = torch.zeros(2)
criterion_kwargs = {"reduction": "sum"}

loss_metric = DummyLoss1(nll_loss, true_output=(y_pred, y, criterion_kwargs))
loss_metric = DummyLoss1(nll_loss, true_output=(y_pred, y, criterion_kwargs), device=available_device)
state = State(output=({"y_pred": y_pred, "y": y, "criterion_kwargs": {"reduction": "sum"}}))
engine = MagicMock(state=state)
loss_metric.iteration_completed(engine)
Expand Down Expand Up @@ -79,8 +79,9 @@ def test_zero_div():


@pytest.mark.parametrize("criterion", [nll_loss, nn.NLLLoss()])
def test_compute(criterion):
loss = Loss(criterion)
def test_compute(criterion, available_device):
loss = Loss(criterion, device=available_device)
assert loss._device == torch.device(available_device)

y_pred, y, expected_loss = y_test_1()
loss.update((y_pred, y))
Expand All @@ -99,7 +100,7 @@ def test_non_averaging_loss():
loss.update((y_pred, y))


def test_gradient_based_loss():
def test_gradient_based_loss(available_device):
# Tests https://github.com/pytorch/ignite/issues/1674
x = torch.tensor([[0.1, 0.4, 0.5], [0.1, 0.7, 0.2]], requires_grad=True)
y_pred = x.mm(torch.randn(size=(3, 1)))
Expand All @@ -113,12 +114,14 @@ def loss_fn(y_pred, x):

return gradients.norm(2, dim=1).mean()

loss = Loss(loss_fn)
loss = Loss(loss_fn, device=available_device)
assert loss._device == torch.device(available_device)
loss.update((y_pred, x))


def test_kwargs_loss():
loss = Loss(nll_loss)
def test_kwargs_loss(available_device):
loss = Loss(nll_loss, device=available_device)
assert loss._device == torch.device(available_device)

y_pred, y, _ = y_test_1()
kwargs = {"weight": torch.tensor([0.1, 0.1, 0.1])}
Expand All @@ -127,8 +130,9 @@ def test_kwargs_loss():
assert_almost_equal(loss.compute(), expected_value)


def test_reset():
loss = Loss(nll_loss)
def test_reset(available_device):
loss = Loss(nll_loss, device=available_device)
assert loss._device == torch.device(available_device)

y_pred, y = y_test_3()
loss.update((y_pred, y))
Expand Down Expand Up @@ -194,8 +198,9 @@ def _test_distrib_accumulator_device(device, y_test_1):
), f"{type(loss._sum.device)}:{loss._sum.device} vs {type(metric_device)}:{metric_device}"


def test_sum_detached():
loss = Loss(nll_loss)
def test_sum_detached(available_device):
loss = Loss(nll_loss, device=available_device)
assert loss._device == torch.device(available_device)

y_pred, y, _ = y_test_1(requires_grad=True)
loss.update((y_pred, y))
Expand Down Expand Up @@ -330,8 +335,8 @@ def forward(


class DummyLoss3(Loss):
def __init__(self, loss_fn, expected_loss, output_transform=lambda x: x, skip_unrolling=False):
super(DummyLoss3, self).__init__(loss_fn, output_transform=output_transform, skip_unrolling=skip_unrolling)
def __init__(self, loss_fn, expected_loss, output_transform=lambda x: x, skip_unrolling=False, device="cpu"):
super().__init__(loss_fn, output_transform=output_transform, skip_unrolling=skip_unrolling, device=device)
self._expected_loss = expected_loss
self._loss_fn = loss_fn

Expand All @@ -347,7 +352,7 @@ def update(self, output):
assert calculated_loss == self._expected_loss


def test_skip_unrolling_loss():
def test_skip_unrolling_loss(available_device):
a_pred = torch.rand(8, 1)
b_pred = torch.rand(8, 1)
y_pred = [a_pred, b_pred]
Expand All @@ -358,7 +363,9 @@ def test_skip_unrolling_loss():
multi_output_mse_loss = CustomMultiMSELoss()
expected_loss = multi_output_mse_loss(y_pred=y_pred, y_true=y_true)

loss_metric = DummyLoss3(loss_fn=multi_output_mse_loss, expected_loss=expected_loss, skip_unrolling=True)
loss_metric = DummyLoss3(
loss_fn=multi_output_mse_loss, expected_loss=expected_loss, skip_unrolling=True, device=available_device
)
state = State(output=(y_pred, y_true))
engine = MagicMock(state=state)
loss_metric.iteration_completed(engine)