
Conversation

@will-cromar (Contributor) commented Feb 20, 2024

What does this PR do?

#2176 replaces the TPU device type with XLA, letting us use GPUs with accelerate now 🎊

This PR fixes some issues that pop up on TPU after that PR:

  • Don't check xm.xla_device in is_torch_xla_available. Calling xm.xla_device before xmp.spawn causes torch_xla to initialize the runtime in the parent process, reserving memory on the GPU that the child processes can't use and causing TPU workloads to crash outright (error message below; see the sketch after this list).
  • Fix menu of options in accelerate config to offer XLA as an option. Selecting TPU causes an error because that device type no longer exists.
  • Allow bf16 mixed precision on TPU, matching the old behavior before Make torch xla available on GPU #2176.
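
For the first item, here is a minimal sketch of the intended check, assuming Accelerate-style check_is_tpu/check_is_gpu flags (an illustration, not necessarily the exact code in this PR): verify only that torch_xla is importable, and when the device type matters, read it through torch_xla.runtime.device_type() rather than calling xm.xla_device().

```python
# Minimal sketch, not the exact Accelerate implementation; the check_is_* flags are illustrative.
import importlib.util


def is_torch_xla_available(check_is_tpu: bool = False, check_is_gpu: bool = False) -> bool:
    # Only verify that torch_xla is importable; never call xm.xla_device() here,
    # since that would initialize the PJRT runtime in the parent process.
    if importlib.util.find_spec("torch_xla") is None:
        return False
    if check_is_tpu or check_is_gpu:
        import torch_xla.runtime as xr

        # device_type() reports the configured PJRT device without creating one.
        device = xr.device_type()
        return device == "TPU" if check_is_tpu else device in ("GPU", "CUDA", "ROCM")
    return True
```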

Currently, running accelerate on TPU causes this crash due to the first issue:

...
F0000 00:00:1708382221.197251   23274 pjrt_registry.cc:117] Non-OK-status: pjrt::LoadPjrtPlugin("tpu", tpu_library_path).status() status: ALREADY_EXISTS: PJRT_Api already exists for device type tpu
...
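
For illustration, a hedged reconstruction of the failure mode described above (not code from this PR): the XLA device is touched in the parent process before xmp.spawn, so the runtime is already initialized when the child processes start.

```python
# Hedged illustration of the anti-pattern; not code from this PR.
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _worker(index):
    device = xm.xla_device()  # fine: each child initializes its own runtime
    ...


xm.xla_device()     # problem: initializes the PJRT runtime in the parent process
xmp.spawn(_worker)  # children may then crash, e.g. "PJRT_Api already exists for device type tpu"
```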

Tested with accelerate test on a TPU v4-8.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

cc @muellerzr @anw90 @vanbasten23


@anw90 (Contributor) commented Feb 21, 2024

  • Don't check the xm.xla_device if we don't need to know the device type in is_torch_xla_available. Calling xm.xla_device before xmp.spawn causes issues. This causes torch_xla to initialize the runtime in the parent process, reserving some space on GPU that can't be used by the child processes and causing TPU workloads to crash outright (message below). (Can we just check torch_xla.runtime.device_type() instead? @anw90)

Sorry that the code breaks the task on TPU. In one of my earliest versions, I checked the device type using the PJRT_DEVICE value in is_torch_xla_available. Later, I changed it to the current implementation to decouple it from external environment variables. I think it's okay to use torch_xla.runtime.device_type() to check the device type if the current implementation crashes on TPU.

@will-cromar (Contributor, Author) commented

Sorry that the code breaks the task on TPU. In one of my earliest versions, I checked the device type using the PJRT_DEVICE value in is_torch_xla_available. Later, I changed it to the current implementation to decouple it from external environment variables. I think it's okay to use torch_xla.runtime.device_type() to check the device type if the current implementation crashes on TPU.

No worries. This is a subtle bug on GPU, and unfortunately we don't have any TPU CI set up in this repository.

I'll go ahead and replace this check with torch_xla.runtime.device_type(), since it's more straightforward and less risky than digging into the device hardware type.
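
As a hedged illustration of the two checks discussed above (neither snippet is verbatim from the PR):

```python
import os

# Earlier approach: read the device type straight from the PJRT_DEVICE
# environment variable, which couples the check to how that variable is set.
device_from_env = os.environ.get("PJRT_DEVICE")  # e.g. "TPU" or "CUDA"

# Approach settled on here: ask torch_xla itself, which stays decoupled from
# the environment and, per the discussion above, avoids initializing the runtime.
import torch_xla.runtime as xr

device_from_runtime = xr.device_type()
```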

@anw90 (Contributor) commented Feb 22, 2024

LGTM, thanks!

@muellerzr (Contributor) left a comment

Big fan of not having a try/catch + complicated logic there. Very nice!

And glad to see these bugs are fixed. To be clear: now Accelerate won't crash on TPU-XLA? :) (I think we had an issue in Transformers about it)

Also please run make style; make quality to fix the quality check :)

@will-cromar (Contributor, Author) commented

Also please run make style; make quality to fix the quality check :)

Oops, fixed.

To be clear: now Accelerate won't crash on TPU-XLA?

There's still an outstanding issue on TPU v2 and v3 that @vanbasten23 is working on. accelerate test won't crash on v4 and v5 after this change.

@muellerzr (Contributor) commented

Great, thanks @will-cromar!

Also found the transformers issue for posterity: huggingface/transformers#28204

@muellerzr merged commit c0b441f into huggingface:main on Feb 26, 2024