Conversation

@yitongh (Contributor) commented Feb 28, 2024

What does this PR do?

Make torch_xla available on GPU. Currently, torch_xla can be used in a GPU environment, but there are conflicts between XLA and native PyTorch CUDA in an environment where torch_xla is installed. This PR introduces the environment variable USE_TORCH_XLA to address this issue: when USE_TORCH_XLA is set to false, native PyTorch CUDA can be used seamlessly, even if torch_xla is installed.
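
For illustration, a minimal sketch of the kind of import-time gating this enables; the variable names here are assumptions, not the exact implementation:

    import importlib.util
    import os

    # Hypothetical sketch of gating torch_xla detection on USE_TORCH_XLA.
    USE_TORCH_XLA = os.environ.get("USE_TORCH_XLA", "AUTO").upper()

    if USE_TORCH_XLA in ("1", "ON", "YES", "TRUE", "AUTO"):
        # Default: treat torch_xla as available whenever it is installed.
        _torch_xla_available = importlib.util.find_spec("torch_xla") is not None
    else:
        # USE_TORCH_XLA=false: behave as if torch_xla were absent, so native
        # PyTorch CUDA is used even though torch_xla is installed.
        _torch_xla_available = False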

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@muellerzr and @pacman100

@yitongh (Contributor Author) commented Feb 28, 2024

The main changes:

  1. Change is_torch_tpu_available to is_torch_xla_available
  2. Change require_torch_tpu to require_torch_xla
  3. Add USE_TORCH_XLA to enable or disable torch_xla
  4. Fix amp check
  5. Move grad_norm.item() into _maybe_log_save_evaluate to prevent a performance degradation in XLA
  6. Copy the xla_fsdp_config to avoid modifying the original config (a rough sketch follows this list)
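
As a rough sketch of item 6 (the function and key names here are hypothetical, not the exact Trainer code):

    import copy

    def build_xla_fsdp_config(user_config: dict) -> dict:
        # Hypothetical helper: copy first, so the in-place addition below does
        # not mutate the dict the user passed in, and reuse stays safe.
        xla_fsdp_config = copy.copy(user_config)
        xla_fsdp_config.setdefault("xla", True)  # illustrative key only
        return xla_fsdp_config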

@yitongh (Contributor Author) commented Feb 28, 2024

This PR is related to huggingface/accelerate#2176 and huggingface/accelerate#2467.

@will-cromar Could you please check if this PR has any impact on the TPU environment? Thanks.

@yitongh (Contributor Author) commented Feb 28, 2024

The CI only failed on tests/test_modeling_utils.py::ModelUtilsTest::test_use_safetensors. I ran this test on master and it also hangs there. From pystack, the issue looks related to an SSL read. Stack log: https://gist.github.com/yitongh/34dc9c9f3de79d208533964bd63bb6f5

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@yitongh (Contributor Author) commented Mar 5, 2024

@muellerzr, would you be available to take a look at this PR when you have a moment? If not, perhaps you could suggest someone else who might be suited to review it? Many thanks.

@ArthurZucker ArthurZucker requested a review from muellerzr March 7, 2024 11:34
@muellerzr (Contributor) left a comment

Overall this is fine by me; we have very similar logic in Accelerate, if I'm not mistaken. Thanks!

@muellerzr (Contributor) commented

Let's make sure we can fix those failing tests, though. Can you try rebasing from main?

@yitongh (Contributor Author) commented Mar 8, 2024

@muellerzr I have rebased from main. I reran the failing tests on my machine: both main and this PR pass test_run_ner_no_trainer and test_run_squad_no_trainer, but fail test_run_glue_no_trainer. It looks unrelated to this PR.

pytest -s -v examples/pytorch/test_accelerate_examples.py::ExamplesTestsNoTrainer::test_run_ner_no_trainer examples/pytorch/test_accelerate_examples.py::ExamplesTestsNoTrainer::test_run_squad_no_trainer examples/pytorch/test_accelerate_examples.py::ExamplesTestsNoTrainer::test_run_glue_no_trainer
======================================================================================================== test session starts =========================================================================================================
platform linux -- Python 3.10.12, pytest-8.0.0, pluggy-1.4.0 -- /usr/bin/python3.10
cachedir: .pytest_cache
hypothesis profile 'default' -> database=DirectoryBasedExampleDatabase(PosixPath('/root/hyt/github/transformers/.hypothesis/examples'))
rootdir: /root/hyt/github/transformers
configfile: pyproject.toml
plugins: hypothesis-6.98.17, xdist-3.5.0, subtests-0.12.1, anyio-4.3.0, timeout-2.3.1
collected 3 items

examples/pytorch/test_accelerate_examples.py::ExamplesTestsNoTrainer::test_run_ner_no_trainer PASSED
examples/pytorch/test_accelerate_examples.py::ExamplesTestsNoTrainer::test_run_squad_no_trainer PASSED
examples/pytorch/test_accelerate_examples.py::ExamplesTestsNoTrainer::test_run_glue_no_trainer FAILED

============================================================================================================== FAILURES ==============================================================================================================
__________________________________________________________________________________________ ExamplesTestsNoTrainer.test_run_glue_no_trainer ___________________________________________________________________________________________

self = <test_accelerate_examples.ExamplesTestsNoTrainer testMethod=test_run_glue_no_trainer>

    @mock.patch.dict(os.environ, {"WANDB_MODE": "offline", "DVCLIVE_TEST": "true"})
    def test_run_glue_no_trainer(self):
        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            {self.examples_dir}/pytorch/text-classification/run_glue_no_trainer.py
            --model_name_or_path distilbert/distilbert-base-uncased
            --output_dir {tmp_dir}
            --train_file ./tests/fixtures/tests_samples/MRPC/train.csv
            --validation_file ./tests/fixtures/tests_samples/MRPC/dev.csv
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=1
            --learning_rate=1e-4
            --seed=42
            --num_warmup_steps=2
            --checkpointing_steps epoch
            --with_tracking
        """.split()

        run_command(self._launch_args + testargs)
        result = get_results(tmp_dir)
>       self.assertGreaterEqual(result["eval_accuracy"], 0.75)
E       AssertionError: 0.6666666666666666 not greater than or equal to 0.75

examples/pytorch/test_accelerate_examples.py:98: AssertionError
========================================================================================================== warnings summary ==========================================================================================================
../../../../usr/local/lib/python3.10/dist-packages/_pytest/config/__init__.py:1394
  /usr/local/lib/python3.10/dist-packages/_pytest/config/__init__.py:1394: PytestConfigWarning: Unknown config option: doctest_glob

    self._warn_or_fail_if_strict(f"Unknown config option: {key}\n")

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
====================================================================================================== short test summary info =======================================================================================================
FAILED examples/pytorch/test_accelerate_examples.py::ExamplesTestsNoTrainer::test_run_glue_no_trainer - AssertionError: 0.6666666666666666 not greater than or equal to 0.75
========================================================================================= 1 failed, 2 passed, 1 warning in 143.07s (0:02:23) =========================================================================================

@muellerzr (Contributor) commented

They look to be timeout issues; I'm rerunning the tests now. However, if they still fail, note that they were not failing before this 😅

@muellerzr (Contributor) commented

Tests pass on our CI, so this looks fine.

@muellerzr (Contributor) left a comment

cc @amyeroberts for final review

@muellerzr muellerzr requested a review from amyeroberts March 8, 2024 16:00
@amyeroberts (Contributor) left a comment

Thanks for working on this!

Mostly just small comments about the deprecation handling.

My main concern is that check_device was previously True by default, so replacing is_torch_tpu_available() with is_torch_xla_available() isn't an equivalent call.
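
To make the distinction concrete, a hedged sketch of the two default behaviours (the helper names are mine, not the library's):

    import importlib.util

    def xla_device_reachable() -> bool:
        # Old-style default (check_device=True): require a live XLA device.
        try:
            import torch_xla.core.xla_model as xm

            xm.xla_device()  # raises RuntimeError when no device is available
            return True
        except (ImportError, RuntimeError):
            return False

    def torch_xla_installed() -> bool:
        # New-style default: only check that the torch_xla package is importable.
        return importlib.util.find_spec("torch_xla") is not None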

                return True
            except RuntimeError:
                return False
        return True
@amyeroberts (Contributor):

The final return should remain

Suggested change:
-        return True
+        return True
+    return False

@yitongh (Contributor Author):

Done

    is_torch_tensorrt_fx_available,
    is_torch_tf32_available,
-   is_torch_tpu_available,
+   is_torch_xla_available,
@amyeroberts (Contributor):

We need to leave this as importable whilst it's still going through the deprecation cycle

Suggested change:
-   is_torch_xla_available,
+   is_torch_tpu_available,
+   is_torch_xla_available,

@yitongh (Contributor Author):

Done

        if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
            return None
-       elif is_torch_tpu_available():
+       elif is_torch_xla_available():
@amyeroberts (Contributor):

This isn't equivalent: previously we were checking for a device, but by default that isn't happening anymore.

@yitongh (Contributor Author):

This is to align with the corresponding PR in the accelerate library. If users do not wish to use torch_xla in an environment where torch_xla is installed, they can disable it via USE_TORCH_XLA, which is also the purpose of this PR.
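
For example, a user on a machine with torch_xla installed could opt out of XLA like this (a usage sketch; setting the variable before importing transformers is assumed necessary, since the flag is read at import time):

    import os

    # Opt out of torch_xla for this process; set before importing transformers.
    os.environ["USE_TORCH_XLA"] = "false"

    from transformers import is_torch_xla_available

    print(is_torch_xla_available())  # expected: False, even with torch_xla installed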

logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
if grad_norm is not None:
logs["grad_norm"] = grad_norm
logs["grad_norm"] = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm
@amyeroberts (Contributor):

This change doesn't seem to have anything to do with the goal of this PR.

@yitongh (Contributor Author):

This modification is needed because evaluating the tensor (grad_norm.item()) causes XLA to execute the entire computation graph prematurely, degrading performance. The grad_norm.item() call should be performed after the XLA mark_step.
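
A small sketch of the idea (illustrative, not the Trainer's exact code): on XLA, tensors are lazy, so .item() forces compilation and execution of the whole pending graph at that point; keeping grad_norm as a tensor in the hot loop and converting only at logging time avoids that.

    import torch

    def training_step_hot_path(_grad_norm):
        # Keep the tensor lazy inside the step; no .item() here.
        grad_norm = _grad_norm
        return grad_norm

    def log_value(grad_norm):
        # Safe once xm.mark_step() has materialized the step: .item() is cheap.
        return grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm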

            grad_norm = grad_norm.item()
        else:
-           grad_norm = _grad_norm.item() if _grad_norm is not None else None
+           grad_norm = _grad_norm
@amyeroberts (Contributor):

same here

"is_torch_neuroncore_available",
"is_torch_npu_available",
"is_torch_tpu_available",
"is_torchvision_available",
@amyeroberts (Contributor):

We need to keep this whilst it's still being deprecated.

Suggested change:
-   "is_torchvision_available",
+   "is_torch_tpu_available",
+   "is_torchvision_available",

@yitongh (Contributor Author):

Done

    is_torch_neuroncore_available,
    is_torch_npu_available,
-   is_torch_tpu_available,
+   is_torch_xla_available,
@amyeroberts (Contributor):

Suggested change:
-   is_torch_xla_available,
+   is_torch_tpu_available,
+   is_torch_xla_available,

@yitongh (Contributor Author):

Done

    def is_torch_tpu_available(check_device=True):
        "Checks if `torch_xla` is installed and potentially if a TPU is in the environment"
        warnings.warn(
            "`is_torch_tpu_available` is deprecated and will be removed in 4.39.0. "
@amyeroberts (Contributor):

This will be the next release, so it would need to be removed right away! As it's a public object, it should go through at least two release cycles.

Suggested change:
-       "`is_torch_tpu_available` is deprecated and will be removed in 4.39.0. "
+       "`is_torch_tpu_available` is deprecated and will be removed in 4.41.0. "

@yitongh (Contributor Author):

Done
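
Put together, the deprecated wrapper would look roughly like this (a sketch assembled from the pieces discussed above, not copied verbatim from the final diff):

    import importlib.util
    import warnings

    def is_torch_tpu_available(check_device=True):
        "Checks if `torch_xla` is installed and potentially if a TPU is in the environment"
        warnings.warn(
            "`is_torch_tpu_available` is deprecated and will be removed in 4.41.0. "
            "Please use `is_torch_xla_available` instead.",
            FutureWarning,
        )
        if importlib.util.find_spec("torch_xla") is not None:
            if check_device:
                # Check that an `xla_device` can actually be found; this raises
                # a RuntimeError when no XLA device is present.
                try:
                    import torch_xla.core.xla_model as xm

                    xm.xla_device()
                    return True
                except RuntimeError:
                    return False
            return True
        return False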

@amyeroberts (Contributor) left a comment

Changes look good to me - thanks for iterating and explaining the design choices!
