Skip to content

Conversation

@haqatak
Copy link
Contributor

@haqatak haqatak commented Sep 9, 2025

This change adds support for Apple Silicon, allowing the library to run on MPS-compatible devices.

Key changes:

  • Modified xfuser/envs.py to detect and support the MPS backend.
  • Replaced hardcoded .cuda() calls with dynamic device placement (.to(device)).
  • Made CUDA-specific dependencies like yunchang and flash-attn conditional, preventing import errors on non-CUDA systems.
  • Updated setup.py to be more platform-agnostic.
  • Added a new test file tests/core/test_envs.py to verify the new device-detection logic.

Limitation:
The yunchang library for long context attention is CUDA-specific and is therefore disabled when running on MPS. This means that long context attention features will not be available on Apple Silicon devices.

This change adds support for Apple Silicon, allowing the library to run on MPS-compatible devices.

Key changes:
- Modified `xfuser/envs.py` to detect and support the MPS backend.
- Replaced hardcoded `.cuda()` calls with dynamic device placement (`.to(device)`).
- Made CUDA-specific dependencies like `yunchang` and `flash-attn` conditional, preventing import errors on non-CUDA systems.
- Updated `setup.py` to be more platform-agnostic.
- Added a new test file `tests/core/test_envs.py` to verify the new device-detection logic.

Limitation:
The `yunchang` library for long context attention is CUDA-specific and is therefore disabled when running on MPS. This means that long context attention features will not be available on Apple Silicon devices.
This commit fixes a `TypeError` that occurred when running the library on an MPS device. The error was caused by `get_device_version()` returning `None` for MPS, which `packaging.version.parse()` cannot handle.

The fix modifies the lambda function for `CUDA_VERSION` in `xfuser/envs.py` to provide a default version string ("0.0") when `get_device_version()` returns `None`.

A new test case has also been added to `tests/core/test_envs.py` to ensure that accessing `envs.CUDA_VERSION` in a mocked MPS environment does not raise an error, preventing future regressions.
This commit pins the `torch` version to `2.4.1` in `setup.py`. This is to resolve dependency conflicts with other libraries like `torchaudio` and `torchvision` that may be present in the user's environment and require a specific version of `torch`.

This change will ensure a more stable and predictable installation process for users on all platforms.
packages=find_packages(),
install_requires=[
"torch>=2.1.0",
"torch==2.4.1",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you make it as torch>=2.4.1

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants