SMAPE() Expected all tensors to be on the same device #1352

@jacktang

Description

  • PyTorch-Forecasting version: 1.0.0
  • PyTorch version: 1.12.1
  • Python version: 3.10.6
  • Operating System: Ubuntu 20.04

Expected behavior

The code from the N-BEATS tutorial:

actuals = torch.cat([y[0] for x, y in iter(val_dataloader)])
baseline_predictions = Baseline().predict(val_dataloader)
SMAPE()(baseline_predictions, actuals)

is expected to run without error.

Actual behavior

However, the following error occurred:

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
File ~/miniconda3/envs/pf/lib/python3.10/site-packages/torchmetrics/metric.py:446, in Metric._wrap_update..wrapped_func(*args, **kwargs)
    445 try:
--> 446     update(*args, **kwargs)
    447 except RuntimeError as err:

File ~/miniconda3/envs/pf/lib/python3.10/site-packages/pytorch_forecasting/metrics/base_metrics.py:784, in MultiHorizonMetric.update(self, y_pred, target)
    782     lengths = torch.full((target.size(0),), fill_value=target.size(1), dtype=torch.long, device=target.device)
--> 784 losses = self.loss(y_pred, target)
    785 # weight samples

File ~/miniconda3/envs/pf/lib/python3.10/site-packages/pytorch_forecasting/metrics/point.py:69, in SMAPE.loss(self, y_pred, target)
     68 y_pred = self.to_prediction(y_pred)
---> 69 loss = 2 * (y_pred - target).abs() / (y_pred.abs() + target.abs() + 1e-8)
     70 return loss

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
Cell In[10], line 5
      3 actuals = torch.cat([y[0] for x, y in iter(val_dataloader)])
      4 baseline_predictions = Baseline().predict(val_dataloader)
----> 5 SMAPE()(baseline_predictions, actuals)

File ~/miniconda3/envs/pf/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniconda3/envs/pf/lib/python3.10/site-packages/torchmetrics/metric.py:290, in Metric.forward(self, *args, **kwargs)
    288     self._forward_cache = self._forward_full_state_update(*args, **kwargs)
    289 else:
--> 290     self._forward_cache = self._forward_reduce_state_update(*args, **kwargs)
    292 return self._forward_cache

File ~/miniconda3/envs/pf/lib/python3.10/site-packages/torchmetrics/metric.py:357, in Metric._forward_reduce_state_update(self, *args, **kwargs)
    354 self._enable_grad = True  # allow grads for batch computation
    356 # calculate batch state and compute batch value
--> 357 self.update(*args, **kwargs)
    358 batch_val = self.compute()
    360 # reduce batch and global state

File ~/miniconda3/envs/pf/lib/python3.10/site-packages/torchmetrics/metric.py:449, in Metric._wrap_update..wrapped_func(*args, **kwargs)
    447     except RuntimeError as err:
    448         if "Expected all tensors to be on" in str(err):
--> 449             raise RuntimeError(
    450                 "Encountered different devices in metric calculation (see stacktrace for details)."
    451                 " This could be due to the metric class not being on the same device as input."
    452                 f" Instead of `metric={self.__class__.__name__}(...)` try to do"
    453                 f" `metric={self.__class__.__name__}(...).to(device)` where"
    454                 " device corresponds to the device of the input."
    455             ) from err
    456         raise err
    458 if self.compute_on_cpu:

RuntimeError: Encountered different devices in metric calculation (see stacktrace for details). This could be due to the metric class not being on the same device as input. Instead of `metric=SMAPE(...)` try to do `metric=SMAPE(...).to(device)` where device corresponds to the device of the input.
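The mismatch comes from `Baseline().predict(val_dataloader)` returning predictions on the GPU while `actuals`, concatenated directly from the dataloader, stays on the CPU. A likely workaround, following the error message's own hint, is to move `actuals` onto the predictions' device before calling the metric (or to instantiate `SMAPE().to(device)`). A minimal sketch below reproduces the elementwise formula from `SMAPE.loss` in the traceback; the tensor values are illustrative, not from the tutorial:

```python
import torch

def smape_loss(y_pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Same elementwise formula as pytorch_forecasting's SMAPE.loss (point.py:69).
    return 2 * (y_pred - target).abs() / (y_pred.abs() + target.abs() + 1e-8)

# Stand-ins for baseline_predictions (GPU when available) and actuals (CPU,
# as produced by concatenating dataloader batches).
device = "cuda" if torch.cuda.is_available() else "cpu"
baseline_predictions = torch.tensor([[10.0, 12.0]], device=device)
actuals = torch.tensor([[11.0, 12.0]])  # stays on CPU

# Workaround: align devices before computing the metric.
loss = smape_loss(baseline_predictions, actuals.to(baseline_predictions.device))
print(loss.mean().item())
```

Applied to the tutorial snippet, this would read `SMAPE()(baseline_predictions, actuals.to(baseline_predictions.device))`.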
