diff --git a/lag_llama/gluon/estimator.py b/lag_llama/gluon/estimator.py index d0c5b52..d6aa1d9 100644 --- a/lag_llama/gluon/estimator.py +++ b/lag_llama/gluon/estimator.py @@ -477,7 +477,7 @@ def create_predictor( prediction_net=module, batch_size=self.batch_size, prediction_length=self.prediction_length, - device="cuda" if torch.cuda.is_available() else "cpu", + device=self.device, ) else: return PyTorchPredictor( @@ -486,5 +486,5 @@ def create_predictor( prediction_net=module, batch_size=self.batch_size, prediction_length=self.prediction_length, - device="cuda" if torch.cuda.is_available() else "cpu", + device=self.device, )