fix: guard flash-attn rotary import#42679
Force-pushed from eff8da3 to d53627d.
Code Review
This pull request updates the ApplyRotaryEmb class to check for the specific submodule flash_attn.ops.triton.rotary instead of just the base flash_attn package before attempting to import it. This ensures that the rotary embedding logic correctly handles environments where the base package is present but the required Triton operations are missing. Additionally, a unit test was added to verify this behavior using mocks. I have no feedback to provide as there were no review comments.
def test_apply_rotary_emb_skips_flash_attn_without_rotary_module(monkeypatch):
    monkeypatch.setattr(rotary_common.current_platform, "is_cpu", lambda: False)

    def fake_find_spec(name):
        if name == "flash_attn":
            return object()
        if name == "flash_attn.ops.triton.rotary":
            return None
        raise AssertionError(f"unexpected import probe: {name}")

    monkeypatch.setattr(rotary_common, "find_spec", fake_find_spec)

    compilation_config = CompilationConfig(custom_ops=["all"])
    monkeypatch.setattr(
        custom_op_module,
        "get_cached_compilation_config",
        lambda: compilation_config,
    )

    op = rotary_common.ApplyRotaryEmb()

    assert op.apply_rotary_emb_flash_attn is None
Thanks for the work! For this small update, we don't need a dedicated unit test.
@he-yufeng Are you able to finish this? Otherwise, I can take over.
Force-pushed from d53627d to d42c6dc.
I pushed d42c6dc and removed the dedicated unit test per review, keeping the PR focused on the production import guard. Local validation: py_compile, ruff check, ruff format --check, and git diff --check. CI is running now.
Force-pushed from d42c6dc to 1114322.
Rebased the branch onto the latest upstream/main and kept the PR scoped to the production import guard only. Validation run locally:
Force-pushed from 1114322 to ed20d62.
Pushed Root cause from the failed Buildkite logs:
Validation run locally: I also reproduced the failing condition with the same helper logic in a local Python snippet: missing

  self.apply_rotary_emb_flash_attn = None
- if not current_platform.is_cpu() and find_spec("flash_attn") is not None:
+ if not current_platform.is_cpu() and _has_flash_attn_rotary():
- if not current_platform.is_cpu() and _has_flash_attn_rotary():
+ if not current_platform.is_cpu() and find_spec("flash_attn.ops.triton.rotary") is not None:
Why can't we do it like this?
find_spec("flash_attn.ops.triton.rotary") is not None was the first shape I tried, but it still raises when the parent package is absent.
The Buildkite failure on the previous head was exactly that case:
ModuleNotFoundError: No module named 'flash_attn'
...
find_spec("flash_attn.ops.triton.rotary")
find_spec() only returns None when the searched module is absent under an importable parent package. If flash_attn itself is not installed, it raises ModuleNotFoundError while resolving the parent. The helper keeps the production guard to the same check, but catches that missing-parent case and treats it as unavailable.
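The two cases described above can be checked with the standard library alone. The sketch below is a minimal, hypothetical helper: the name `has_module` and its body are assumptions for illustration, not the PR's actual `_has_flash_attn_rotary`, and `json` merely stands in for an installed parent package.

```python
from importlib.util import find_spec


def has_module(dotted_name: str) -> bool:
    """Return True if `dotted_name` can be located, False otherwise.

    Hypothetical helper for illustration only.
    """
    try:
        # For a dotted name, find_spec() imports the parent package first.
        return find_spec(dotted_name) is not None
    except ModuleNotFoundError:
        # Raised when a parent package in the dotted path is missing.
        return False


# Parent package and submodule both exist.
print(has_module("json.decoder"))
# Parent exists, submodule is absent: find_spec() returns None.
print(has_module("json.no_such_submodule"))
# Parent is absent: find_spec() raises, and the helper reports False.
print(has_module("no_such_parent_pkg.ops.rotary"))
```

Note that probing a dotted name this way imports the parent package as a side effect, which is part of why a bare `find_spec` call on the submodule is not a pure lookup.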
I see. Please also add to the PR description the full reproduction command and the error it produces on main, and confirm that the same command passes on this branch. Then I will take a look.
Updated the PR description with the requested main-branch repro command, error text, branch behavior, and current validation/CI status.
yewentao256 left a comment:
Thanks @he-yufeng
It seems your issue has already been fixed in main, or there is a problem with your current environment.
[yewentao256@nm-frk-h200-03-preserve vllm-source]$ python - <<'PY'
from vllm.model_executor.layers.rotary_embedding.common import ApplyRotaryEmb
ApplyRotaryEmb(
    rotary_dim=128,
    neox_style=True,
    head_size=128,
    is_segmented=False,
)
PY
Traceback (most recent call last):
  File "<stdin>", line 3, in <module>
TypeError: ApplyRotaryEmb.__init__() got an unexpected keyword argument 'rotary_dim'
The error comes from the arguments in your command, not from a missing import.
Closing this PR as not scheduled. Feel free to reopen if I am wrong.
@yewentao256 I don't think main fixes the issue. The issue here is with calling check #42675
@joonyoo181 Please update FA4 to a newer version, like:
pip show flash-attn-4
Name: flash-attn-4
Version: 4.0.0b16
Home-page: https://github.com/Dao-AILab/flash-attention
Author: Tri Dao
Author-email:
License: BSD 3-Clause License
Location: /home/yewentao256/.venv/lib/python3.12/site-packages
Requires: apache-tvm-ffi, einops, nvidia-cutlass-dsl, quack-kernels, torch, torch-c-dlpack-ext, typing_extensions
Required-by:
Or uninstall it; vLLM has its own FA compilation.
Oops, this can be reproduced by
@@ -135,7 +142,7 @@ def __init__(
  self.enable_fp32_compute = enable_fp32_compute

  self.apply_rotary_emb_flash_attn = None
- if not current_platform.is_cpu() and find_spec("flash_attn") is not None:
+ if not current_platform.is_cpu() and _has_flash_attn_rotary():
- if not current_platform.is_cpu() and _has_flash_attn_rotary():
+ if not current_platform.is_cpu():
+     try:
+         apply_rotary = import_module(
+             "flash_attn.ops.triton.rotary"
+         ).apply_rotary
+     except ModuleNotFoundError:
+         apply_rotary = None
I'd prefer this fix.
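A self-contained sketch of the same direct-import pattern follows. The helper name `optional_attr` is hypothetical, and standard-library modules stand in for `flash_attn.ops.triton.rotary`; this is an illustration of the shape, not the code merged in the PR.

```python
from importlib import import_module
from typing import Any, Optional


def optional_attr(module_name: str, attr: str) -> Optional[Any]:
    """Import `module_name` and return `attr`, or None if the module is missing.

    Hypothetical helper for illustration only.
    """
    try:
        return getattr(import_module(module_name), attr)
    except ModuleNotFoundError:
        # Covers both a missing parent package and a missing submodule.
        return None


# Module present: the attribute is returned.
sqrt = optional_attr("math", "sqrt")
# Parent package absent: the helper falls back to None instead of raising.
missing = optional_attr("no_such_pkg.ops.rotary", "apply_rotary")

print(sqrt is not None, missing is None)
```

One trade-off of this shape is that `ModuleNotFoundError` is also raised when the target module exists but one of its own imports is missing, so that case is treated as unavailable too. Other `ImportError` failures, and an `AttributeError` for a missing attribute, still propagate.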
Signed-off-by: Yufeng He <40085740+he-yufeng@users.noreply.github.com> Assisted-by: OpenAI Codex
Force-pushed from ed20d62 to f53d644.
Implemented the direct-import approach in f53d644. The code now attempts I also added the repository-required AI assistance disclosure and
yewentao256 left a comment:
LGTM, thanks for the work and iterations!
Checked the remaining red Buildkite shard. The failing job is
The failure is in
Could someone rerun that Buildkite shard or the PR build?
yewentao256 left a comment:
Let's retry once before force merge.
Signed-off-by: Yufeng He <40085740+he-yufeng@users.noreply.github.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Signed-off-by: Yufeng He <40085740+he-yufeng@users.noreply.github.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Signed-off-by: divineearthly <divineearthly@gmail.com>
Fixes #42675.

Summary
- flash_attn.ops.triton.rotary directly when vLLM is not on CPU
- ApplyRotaryEmb construction

FA4 can leave the flash_attn root package importable while moving or removing flash_attn.ops.triton.rotary. Checking only the root package therefore selects a fast path that fails during import. The branch now follows the maintainer-requested direct import shape and catches ModuleNotFoundError around that exact import.

Duplicate work check
I searched open, closed, and recently merged vLLM PRs for #42675, flash_attn.ops.triton.rotary, ApplyRotaryEmb, and the FA4 import failure before opening and updating this PR. I did not find another active fix for this path.

AI assistance disclosure
I used OpenAI Codex to help inspect the repository, update the implementation after review, and run local validation. I reviewed the final diff and can explain the failure mode and fallback behavior. The commits include the required Assisted-by attribution trailers.

Reproduce on main
With an FA4 environment where flash_attn exists but flash_attn.ops.triton.rotary does not, run:

main imports the missing rotary module after probing only the root flash_attn package, which raises ModuleNotFoundError. On this branch, the direct rotary import is guarded and the existing fallback remains active.

Validation
Passed locally on Windows:
- python -m py_compile vllm/model_executor/layers/rotary_embedding/common.py
- python -m ruff check vllm/model_executor/layers/rotary_embedding/common.py
- python -m ruff format --check vllm/model_executor/layers/rotary_embedding/common.py
- git diff --check

The previous branch head also passed the vLLM Buildkite PR suite. I do not have the required FA4/GPU environment to rerun the full latency reproduction locally.
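
The failure mode in the description can be simulated without FA4 or a GPU. The sketch below builds a throwaway package on disk whose root imports fine but which has no `ops.triton.rotary` submodule. The package name `fake_fa_root` is a made-up stand-in for `flash_attn`, and the snippet does not touch vLLM; it only shows why a root-package probe and a direct submodule import can disagree.

```python
import sys
import tempfile
from importlib import import_module, invalidate_caches
from importlib.util import find_spec
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    # Root package exists, but ops/triton/rotary.py is deliberately absent.
    pkg = Path(tmp) / "fake_fa_root"
    pkg.mkdir()
    (pkg / "__init__.py").write_text("")
    sys.path.insert(0, tmp)
    invalidate_caches()
    try:
        # Probing only the root package reports "available".
        root_found = find_spec("fake_fa_root") is not None

        # Importing the submodule directly exposes the real state.
        try:
            import_module("fake_fa_root.ops.triton.rotary")
            rotary_found = True
        except ModuleNotFoundError:
            rotary_found = False
    finally:
        # Undo the path and module-cache changes made for the demo.
        sys.path.remove(tmp)
        sys.modules.pop("fake_fa_root", None)

print(root_found, rotary_found)  # True False
```

A guard based on `root_found` would select the fast path and then fail at import time, while a guard based on the direct import falls back cleanly.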