diff --git a/tests/ignite/metrics/test_loss.py b/tests/ignite/metrics/test_loss.py
index 0e945bec58cf..368ed347185d 100644
--- a/tests/ignite/metrics/test_loss.py
+++ b/tests/ignite/metrics/test_loss.py
@@ -15,8 +15,8 @@
 
 
 class DummyLoss1(Loss):
-    def __init__(self, loss_fn, true_output, output_transform=lambda x: x):
-        super(DummyLoss1, self).__init__(loss_fn, output_transform=output_transform)
+    def __init__(self, loss_fn, true_output, output_transform=lambda x: x, device="cpu"):
+        super().__init__(loss_fn, output_transform=output_transform, device=device)
         print(true_output)
         self.true_output = true_output
 
@@ -30,23 +30,23 @@ def update(self, output):
         assert output == self.true_output
 
 
-def test_output_as_mapping_without_criterion_kwargs():
+def test_output_as_mapping_without_criterion_kwargs(available_device):
     y_pred = torch.tensor([[2.0], [-2.0]])
     y = torch.zeros(2)
     criterion_kwargs = {}
 
-    loss_metric = DummyLoss1(nll_loss, true_output=(y_pred, y, criterion_kwargs))
+    loss_metric = DummyLoss1(nll_loss, true_output=(y_pred, y, criterion_kwargs), device=available_device)
     state = State(output=({"y_pred": y_pred, "y": y, "criterion_kwargs": {}}))
     engine = MagicMock(state=state)
     loss_metric.iteration_completed(engine)
 
 
-def test_output_as_mapping_with_criterion_kwargs():
+def test_output_as_mapping_with_criterion_kwargs(available_device):
     y_pred = torch.tensor([[2.0], [-2.0]])
     y = torch.zeros(2)
     criterion_kwargs = {"reduction": "sum"}
 
-    loss_metric = DummyLoss1(nll_loss, true_output=(y_pred, y, criterion_kwargs))
+    loss_metric = DummyLoss1(nll_loss, true_output=(y_pred, y, criterion_kwargs), device=available_device)
     state = State(output=({"y_pred": y_pred, "y": y, "criterion_kwargs": {"reduction": "sum"}}))
     engine = MagicMock(state=state)
     loss_metric.iteration_completed(engine)
@@ -79,8 +79,9 @@ def test_zero_div():
 
 
 @pytest.mark.parametrize("criterion", [nll_loss, nn.NLLLoss()])
-def test_compute(criterion):
-    loss = Loss(criterion)
+def test_compute(criterion, available_device):
+    loss = Loss(criterion, device=available_device)
+    assert loss._device == torch.device(available_device)
 
     y_pred, y, expected_loss = y_test_1()
     loss.update((y_pred, y))
@@ -99,7 +100,7 @@ def test_non_averaging_loss():
         loss.update((y_pred, y))
 
 
-def test_gradient_based_loss():
+def test_gradient_based_loss(available_device):
     # Tests https://github.com/pytorch/ignite/issues/1674
     x = torch.tensor([[0.1, 0.4, 0.5], [0.1, 0.7, 0.2]], requires_grad=True)
     y_pred = x.mm(torch.randn(size=(3, 1)))
@@ -113,12 +114,14 @@ def loss_fn(y_pred, x):
 
         return gradients.norm(2, dim=1).mean()
 
-    loss = Loss(loss_fn)
+    loss = Loss(loss_fn, device=available_device)
+    assert loss._device == torch.device(available_device)
     loss.update((y_pred, x))
 
 
-def test_kwargs_loss():
-    loss = Loss(nll_loss)
+def test_kwargs_loss(available_device):
+    loss = Loss(nll_loss, device=available_device)
+    assert loss._device == torch.device(available_device)
 
     y_pred, y, _ = y_test_1()
     kwargs = {"weight": torch.tensor([0.1, 0.1, 0.1])}
@@ -330,8 +333,8 @@ def forward(
 
 
 class DummyLoss3(Loss):
-    def __init__(self, loss_fn, expected_loss, output_transform=lambda x: x, skip_unrolling=False):
-        super(DummyLoss3, self).__init__(loss_fn, output_transform=output_transform, skip_unrolling=skip_unrolling)
+    def __init__(self, loss_fn, expected_loss, output_transform=lambda x: x, skip_unrolling=False, device="cpu"):
+        super().__init__(loss_fn, output_transform=output_transform, skip_unrolling=skip_unrolling, device=device)
         self._expected_loss = expected_loss
         self._loss_fn = loss_fn
 
@@ -347,7 +350,7 @@ def update(self, output):
         assert calculated_loss == self._expected_loss
 
 
-def test_skip_unrolling_loss():
+def test_skip_unrolling_loss(available_device):
     a_pred = torch.rand(8, 1)
     b_pred = torch.rand(8, 1)
     y_pred = [a_pred, b_pred]
@@ -358,7 +361,9 @@ def test_skip_unrolling_loss():
     multi_output_mse_loss = CustomMultiMSELoss()
     expected_loss = multi_output_mse_loss(y_pred=y_pred, y_true=y_true)
 
-    loss_metric = DummyLoss3(loss_fn=multi_output_mse_loss, expected_loss=expected_loss, skip_unrolling=True)
+    loss_metric = DummyLoss3(
+        loss_fn=multi_output_mse_loss, expected_loss=expected_loss, skip_unrolling=True, device=available_device
+    )
     state = State(output=(y_pred, y_true))
     engine = MagicMock(state=state)
     loss_metric.iteration_completed(engine)
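
Note: each updated test now takes an available_device argument, a pytest fixture defined elsewhere in the repository's test suite; its exact implementation is not part of this diff. As a hedged illustration only, a minimal fixture of this kind could look like the sketch below, assuming it simply parametrizes tests over whichever backends are present on the machine:

# Hypothetical sketch of an available_device-style fixture; the real one
# lives in the project's conftest and may differ.
import pytest
import torch


@pytest.fixture(params=["cpu"] + (["cuda"] if torch.cuda.is_available() else []))
def available_device(request):
    # Each test that accepts available_device runs once per listed device,
    # so the Loss metric is exercised on CPU and, when present, on GPU.
    return request.param

With such a fixture in place, the added assertions like assert loss._device == torch.device(available_device) check that the metric actually stores its accumulators on the device it was constructed with.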