diff --git a/optimum/onnxruntime/trainer.py b/optimum/onnxruntime/trainer.py
index 86c333adb3..66273cbcf9 100644
--- a/optimum/onnxruntime/trainer.py
+++ b/optimum/onnxruntime/trainer.py
@@ -103,14 +103,14 @@
     from transformers.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled
 
 if check_if_transformers_greater("4.39"):
-    from transformers.utils import is_torch_xla_available
+    from transformers.utils import is_torch_xla_available as is_torch_tpu_xla_available
 
-    if is_torch_xla_available():
+    if is_torch_tpu_xla_available():
         import torch_xla.core.xla_model as xm
 else:
-    from transformers.utils import is_torch_tpu_available
+    from transformers.utils import is_torch_tpu_available as is_torch_tpu_xla_available
 
-    if is_torch_tpu_available(check_device=False):
+    if is_torch_tpu_xla_available(check_device=False):
         import torch_xla.core.xla_model as xm
 
 if TYPE_CHECKING:
@@ -735,7 +735,7 @@ def get_dataloader_sampler(dataloader):
 
                 if (
                     args.logging_nan_inf_filter
-                    and not is_torch_tpu_available()
+                    and not is_torch_tpu_xla_available()
                     and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
                 ):
                     # if loss is nan or inf simply add the average of previous logged losses