feat: sliding window attention #167
Conversation
This can be useful for independently testing the main language model.
Pull request overview
This PR enables sliding window attention (SWA) models for the torch.compile path by implementing a custom KV cache manager that accommodates RBLN's unique sliding window kernel behavior. The solution leverages vLLM's hybrid KV cache manager framework to handle SWA models that slide cache contents in-place rather than using append-only semantics.
Key changes:
- Introduced `RBLNSlidingWindowSpec` and `RBLNSlidingWindowManager` for custom KV cache handling
- Added sliding window attention custom ops for the prefill and decode phases (see the sketch after this list)
- Extended attention metadata to support sliding window parameters
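The PR's actual custom ops live in `vllm_rbln/v1/attention/backends/flash_attention.py` and are not reproduced here. Below is only a minimal, hypothetical sketch of what a cache-mutating SWA decode op could look like, assuming PyTorch's `torch.library.custom_op` API; the op name `rbln_sketch::swa_decode`, the tensor layouts, and the reference attention math are illustrative and not the PR's code.

```python
import torch

# Hypothetical illustration only: a decode-phase SWA op that mutates the
# single KV-cache block in place instead of appending to new blocks.
@torch.library.custom_op("rbln_sketch::swa_decode",
                         mutates_args=("k_cache", "v_cache"))
def swa_decode(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
               k_cache: torch.Tensor, v_cache: torch.Tensor,
               cache_len: int) -> torch.Tensor:
    # q/k/v: [num_heads, head_size]; k_cache/v_cache: [window_size, heads, head_size]
    window_size = k_cache.shape[0]
    if cache_len >= window_size:
        # Block is full: slide the window contents left by one position, in place.
        k_cache[:-1].copy_(k_cache[1:].clone())
        v_cache[:-1].copy_(v_cache[1:].clone())
        pos = window_size - 1
    else:
        pos = cache_len
    k_cache[pos].copy_(k)
    v_cache[pos].copy_(v)
    # Reference attention over the valid part of the window (not a real kernel).
    keys = k_cache[:pos + 1]
    vals = v_cache[:pos + 1]
    scores = torch.einsum("hd,lhd->hl", q, keys) / keys.shape[-1] ** 0.5
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("hl,lhd->hd", probs, vals)

@swa_decode.register_fake
def _(q, k, v, k_cache, v_cache, cache_len):
    # Shape/dtype inference for torch.compile tracing.
    return torch.empty_like(q)
```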
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| vllm_rbln/v1/kv_cache.py | New file implementing RBLN-specific sliding window KV cache manager with single-block allocation |
| vllm_rbln/v1/worker/rbln_model_runner.py | Updated to use RBLNSlidingWindowSpec and support multiple KV cache groups in warmup |
| vllm_rbln/v1/attention/backends/flash_attention.py | Added sliding window attention custom ops and integrated them into the attention forward path |
| vllm_rbln/v1/worker/optimum_model_runner.py | Added multimodal input disabling support |
| vllm_rbln/rbln_envs.py | Added VLLM_RBLN_DISABLE_MM environment variable |
| vllm_rbln/platform.py | Added hybrid KV cache support and prefix caching configuration |
| examples/experimental/offline_inference_basic.py | Made example script configurable via command-line arguments |
    # unused?
    self.need_mask = (self.alibi_slopes is not None
                      or self.sliding_window is not None)
Copilot AI · Nov 25, 2025
The comment 'unused?' indicates uncertainty about whether self.need_mask is used. If it is genuinely unused, it should be removed to avoid confusion and reduce dead code. If it is used elsewhere, the comment should be removed or clarified.
    single_type_kv_cache_manager.spec_manager_map.update({
        RBLNSlidingWindowSpec: RBLNSlidingWindowManager,
    })
    # TODO: construct only when not envs.RBLN_FLASH_CAUSAL_ATTN
    attn_masks: Optional[torch.Tensor] = None
    kv_caches: Optional[list[torch.Tensor]] = None
Copilot AI · Nov 25, 2025
[nitpick] The TODO comment suggests conditional construction of attn_masks based on envs.RBLN_FLASH_CAUSAL_ATTN. Consider addressing this optimization or creating a tracking issue if this is deferred work.
    self.uses_mrope = model_config.uses_mrope
    self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
        model_config)
    if envs.VLLM_RBLN_DISABLE_MM:
@rebel-jaehwang
I wonder why this new environment variable was introduced.
Not necessary for the functionality of this PR, but useful for testing. Specifically, I wanted to test google/gemma-3-4b-it (a full/SWA hybrid with window_size=1k), but it is a multi-modal model, which isn't supported in the torch.compile path yet. Overriding supports_mm_inputs lets us test the underlying language model only.
So this env var doesn't need to be checked in the optimum runner, but I added it there anyway for consistency.
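For reference, a minimal sketch of how such a flag could be wired up; the exact parsing in `rbln_envs.py` and the helper name `resolve_supports_mm_inputs` are assumptions for illustration, not the PR's code.

```python
import os

# rbln_envs.py (sketch): treat any value other than "0"/empty as enabled.
VLLM_RBLN_DISABLE_MM: bool = os.getenv("VLLM_RBLN_DISABLE_MM", "0") not in ("0", "")


def resolve_supports_mm_inputs(registry_says_multimodal: bool) -> bool:
    """Model-runner-side sketch: force text-only mode when the flag is set,
    so the underlying language model of a multi-modal checkpoint
    (e.g. gemma-3-4b-it) can be tested in isolation."""
    if VLLM_RBLN_DISABLE_MM:
        return False
    return registry_says_multimodal
```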
The Gemma3 model in optimum-rbln, which uses sliding-window attention, works well with multimodal inputs, so I think this is not aligned with optimum-rbln.
removed in optimum runner
vllm_rbln/platform.py (outdated)
    # NOTE - force dtype into fp16 for eager mode
    model_config.dtype = torch.float16

    cls.disable_unsupported_prefix_caching(vllm_config)
@rebel-jaehwang
The cls.disable_unsupported_prefix_caching setting is adopted from optimum and enables prefix caching for decoder-only generation models. Is this correct?
My intention was to disable prefix caching for SWA models, which seems to work as I expected. But I just noticed this causes validation errors when running models not registered in the optimum path... Maybe I need to reimplement something similar.
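Not the actual `disable_unsupported_prefix_caching` implementation, just a sketch of the kind of check it could perform for the torch.compile path; the function name is hypothetical, and the `model_config.get_sliding_window()` / `cache_config.enable_prefix_caching` attributes are assumed from recent vLLM versions and may differ.

```python
import logging

logger = logging.getLogger(__name__)


def disable_prefix_caching_for_swa(vllm_config) -> None:
    """Sketch only: prefix caching assumes append-only, shareable KV blocks,
    but the RBLN SWA kernel mutates its single block in place."""
    sliding_window = vllm_config.model_config.get_sliding_window()
    if sliding_window is not None and vllm_config.cache_config.enable_prefix_caching:
        logger.warning(
            "Sliding window attention (window=%d) is not compatible with "
            "prefix caching on RBLN; disabling prefix caching.", sliding_window)
        vllm_config.cache_config.enable_prefix_caching = False
```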
fixed
rebel-jiwoopark left a comment
lgtm
🚀 Summary of Changes
Enable sliding window attention models for the torch.compile path.
Problem:
RBLN's SWA kernel always uses a single block of size `window_size`. When the block gets full, the kernel slides the contents of the window in-place (a bit more detail in the custom op definitions in `vllm_rbln/v1/attention/backends/flash_attention.py`). This is quite different from usual attention implementations (append-only KV cache), and thus is not compatible with vLLM's KV cache manager. Specifically, vLLM's manager may allocate more blocks than necessary, deallocate the block in use when the window logically goes out of the block, and corrupt the KV cache in the case of a KV cache hit. optimum-rbln handles this issue by separately managing blocks for SWA layers. We could port this logic to the torch.compile path, but we'd like to leverage vLLM's existing functionality while minimizing RBLN-specific code as much as possible.
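To make the difference concrete, here is a small, framework-free sketch (not the RBLN kernel) contrasting the two cache disciplines: the usual append-only paging that vLLM's manager assumes, versus a single block whose contents slide in place once the window is full.

```python
import torch

window_size = 4
num_kv_heads, head_size = 2, 8

# Append-only discipline (what vLLM's default manager assumes): every new
# token's KV goes into a fresh slot, possibly in a newly allocated block.
append_only_cache: list[torch.Tensor] = []
for step in range(6):
    append_only_cache.append(torch.randn(num_kv_heads, head_size))
# 6 tokens -> 6 slots; the manager decides how many blocks back these slots.

# RBLN-style SWA discipline (sketch): one block of exactly window_size slots.
swa_block = torch.zeros(window_size, num_kv_heads, head_size)
cache_len = 0
for step in range(6):
    new_kv = torch.randn(num_kv_heads, head_size)
    if cache_len == window_size:
        # Block is full: slide contents left by one position, in place.
        swa_block[:-1] = swa_block[1:].clone()
        swa_block[-1] = new_kv
    else:
        swa_block[cache_len] = new_kv
        cache_len += 1
# The same block is reused for the whole request; freeing it mid-request or
# treating its contents as a stable, shareable prefix would corrupt the cache.
```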
Solution:
Enable vLLM's hybrid KV cache manager and provide a `SingleTypeKVCacheManager` for RBLN SWA. This manager allocates a single block per request for the SWA layers (see `vllm_rbln/v1/kv_cache.py`).
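A sketch of the allocation accounting this implies (illustrative only; the real logic lives in `vllm_rbln/v1/kv_cache.py` and vLLM's hybrid manager): full-attention layers need a number of blocks proportional to the sequence length, while an RBLN SWA layer always needs exactly one block of `window_size` tokens.

```python
import math


def blocks_needed_full_attention(seq_len: int, block_size: int) -> int:
    """Append-only full attention: one slot per token, paged into blocks."""
    return math.ceil(seq_len / block_size)


def blocks_needed_rbln_swa(seq_len: int, window_size: int) -> int:
    """RBLN SWA (sketch): a single block of window_size slots, reused in
    place for the whole request, no matter how long the sequence grows."""
    return 1


for seq_len in (100, 4096, 100_000):
    print(seq_len,
          blocks_needed_full_attention(seq_len, block_size=4096),
          blocks_needed_rbln_swa(seq_len, window_size=4096))
```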
📌 Related Issues / Tickets
✅ Type of Change
feature / model / core / bug-fix / perf / refactor / docs / other (please describe)

🧪 How to Test
- Mistral-7B-Instruct-v0.1 (swa-only)
- gemma-3-4b-it (hybrid)
📋 Checklist
💬 Notes
`torch.export.export` error

The compiler prints an error message during `torch.export.export` and falls back to `torch.jit.trace`. While this specific error message is new, other models fall back to `torch.jit.trace` anyway, so I'm not looking further into this error for now.

Supporting `window_size < block_size`

The current implementation assumes `window_size == block_size`, but we should support a smaller `window_size` while maintaining a big enough `block_size`. There are a few potential solutions for the `window_size < block_size` case.

Supporting prefix caching

The current implementation disables prefix caching for SWA, because our kernel mutates the KV cache block. There is one possible approach, but it would require quite a bit of change in several places, and we have not yet decided how to handle prefix caching for the torch.compile path. So I plan to revisit this later.