
Conversation

@rebel-jaehwang (Contributor) commented Nov 24, 2025

🚀 Summary of Changes

Enable sliding window attention (SWA) models for the torch.compile path.

Problem:
RBLN's SWA kernel always uses a single block of size window_size. When the block gets full, the kernel slides the contents of the window in place (see the custom op definitions in vllm_rbln/v1/attention/backends/flash_attention.py for details). This is quite different from the usual attention implementations (append-only KV cache) and is therefore not compatible with vLLM's KV cache manager. Specifically, vLLM's manager may allocate more blocks than necessary, deallocate the block still in use once the window logically moves past it, and corrupt the KV cache on a KV cache hit.
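
For illustration, here is a minimal, self-contained sketch (hypothetical shapes and function names, not the actual custom ops in flash_attention.py) contrasting the usual append-only cache write with the in-place sliding write our kernel performs:

import torch

window_size, head_dim = 8, 4                  # hypothetical sizes
block = torch.zeros(window_size, head_dim)    # single KV block of size window_size

def append_only_write(block, pos, new_kv):
    # usual backends: every token gets a fresh slot; when a block fills up,
    # the KV cache manager simply appends another block
    block[pos] = new_kv

def sliding_write(block, num_filled, new_kv):
    # RBLN SWA kernel (conceptually): once the single block is full, shift the
    # window contents left by one slot and write the new entry at the end,
    # mutating the same physical block in place
    if num_filled < window_size:
        block[num_filled] = new_kv
        return num_filled + 1
    block[:-1] = block[1:].clone()
    block[-1] = new_kv
    return num_filled

Because the same physical block is reused and rewritten, vLLM's assumptions about block lifetime and the immutability of filled blocks no longer hold.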

optimum-rbln handles this issue by managing blocks for SWA layers separately. We could port this logic to the torch.compile path, but we would rather leverage vLLM's existing functionality while minimizing RBLN-specific code as much as possible.

Solution: Enable vLLM's hybrid KV cache manager and provide a SingleTypeKVCacheManager for RBLN SWA. This manager (see the sketch after the list):

  • Allocates a single fixed block per request.
  • Disables prefix caching.
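
Conceptually, the manager only has to enforce these two policies. The sketch below is not the actual code in vllm_rbln/v1/kv_cache.py; the real RBLNSlidingWindowManager hooks into vLLM's single-type KV cache manager API and is registered through single_type_kv_cache_manager.spec_manager_map (visible in the review excerpt further down), whereas the class and method names here are simplified and hypothetical:

class RBLNSlidingWindowManagerSketch:
    """Hypothetical, simplified stand-in for RBLNSlidingWindowManager."""

    def __init__(self):
        self.req_to_block: dict[str, int] = {}   # request id -> physical block id

    def num_blocks_to_allocate(self, request_id: str) -> int:
        # a single fixed block per request; later steps reuse the same block
        # because the kernel slides contents inside it instead of appending
        return 0 if request_id in self.req_to_block else 1

    def find_longest_cache_hit(self, request_id: str) -> int:
        # prefix caching is effectively disabled: the block is mutated in
        # place, so previously filled blocks can never be shared safely
        return 0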

📌 Related Issues / Tickets


✅ Type of Change

  • ✨ Feature (feature)
  • 🧠 Model support (model)
  • 🧬 Core engine changes (core)
  • 🛠 Bug fix (bug-fix)
  • ⚙️ Performance improvement (perf)
  • 🔁 Refactor or code cleanup (refactor)
  • 📄 Documentation (docs)
  • ❓ Other (other): please describe

🧪 How to Test

Mistral-7B-Instruct-v0.1 (SWA-only)

VLLM_RBLN_USE_VLLM_MODEL=1 RBLN_KERNEL_MODE=triton RBLN_DEVICES=0 python examples/experimental/offline_inference_basic.py --model mistralai/Mistral-7B-Instruct-v0.1 --block-size 4096

Result

Prompt: 'Hello, my name is', Generated text: ' [Your Name], and I am an AI language model. How can I assist'
Prompt: 'The president of the United States is', Generated text: ' the head of state and the head of government of the United States. The president'
Prompt: 'The capital of France is', Generated text: ' Paris, a city of love renowned for its art, culture, and gas'
Prompt: 'The future of AI is', Generated text: ' bright, and it’s already transforming many aspects of our lives. From'

gemma-3-4b-it (hybrid)

VLLM_RBLN_USE_VLLM_MODEL=1 VLLM_RBLN_DISABLE_MM=1 RBLN_KERNEL_MODE=triton RBLN_DEVICES=0 python examples/experimental/offline_inference_basic.py --model google/gemma-3-4b-it

Result

Prompt: 'Hello, my name is', Generated text: " Alex. I'm a passionate and dedicated software engineer with a strong interest in"
Prompt: 'The president of the United States is', Generated text: ' the head of state and head of government of the United States. The president is'
Prompt: 'The capital of France is', Generated text: ' Paris.\n\nParis is a global center for art, fashion, gastronomy and culture'
Prompt: 'The future of AI is', Generated text: " uncertain, but one thing is clear: it's going to change the way"

📋 Checklist

  • PR title follows Conventional Commits format
  • This PR is linked to an existing issue
  • The test method is described, and the expected result is clearly stated
  • Relevant documentation has been updated (if applicable)

💬 Notes

torch.export.export error

The compiler prints the following error message during torch.export.export and falls back to torch.jit.trace.

[torch/_subclasses/fake_tensor.py:2721] [0/0] failed while attempting to run meta for aten.set_.source_Storage
[torch/_subclasses/fake_tensor.py:2721] [0/0] Traceback (most recent call last):
[torch/_subclasses/fake_tensor.py:2721] [0/0]   File "/vllm-rbln/.venv/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 2717, in _dispatch_impl
[torch/_subclasses/fake_tensor.py:2721] [0/0]     r = func(*args, **kwargs)
[torch/_subclasses/fake_tensor.py:2721] [0/0]   File "/vllm-rbln/.venv/lib/python3.10/site-packages/torch/_ops.py", line 829, in __call__
[torch/_subclasses/fake_tensor.py:2721] [0/0]     return self._op(*args, **kwargs)
[torch/_subclasses/fake_tensor.py:2721] [0/0] RuntimeError: Expected !size_bytes_is_heap_allocated_ to be true, but got false.

While this specific error message is new, other models fall back to torch.jit.trace anyway, so I'm not looking further into this error for now.

Supporting window_size < block_size

The current implementation assumes window_size == block_size, but we should support a smaller window_size while keeping block_size large enough.
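
For example (hypothetical sizes), with the manager handing out 4096-token blocks and the backend needing a 1024-token window, each manager-level block would have to be carved into four kernel-level windows:

block_size = 4096     # block size managed by vLLM's KV cache manager (hypothetical)
window_size = 1024    # window size required by the attention backend (hypothetical)

assert block_size % window_size == 0
windows_per_block = block_size // window_size          # -> 4

# a token at logical offset `pos` within a manager block would land in
# kernel window pos // window_size, slot pos % window_size
pos = 2500
kernel_window, slot = divmod(pos, window_size)         # -> (2, 452)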

Potential solutions

  • vLLM 0.11.1 added a feature (vllm#24486) to split the blocks managed by the KV cache manager into the size required by the attention backend. We should see if this works well for us when we bump the vllm-rbln version.
  • If that doesn't work, we could rewrite the kernel to support the window_size < block_size case.

Supporting prefix caching

The current implementation disables prefix caching for SWA because our kernel mutates the KV cache block.

One possible approach (roughly sketched after the list) would be to:

  • maintain a separate append-only KV cache block in addition to the existing sliding-contents block (introducing a minor memory-access overhead that would be mostly hidden by compute ops)
  • let vLLM's KV cache manager handle the append-only KV cache blocks
  • on a prefix cache hit, copy (D2D) the contents into the new request's sliding-contents block.
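
A rough sketch of the last step under these assumptions (hypothetical tensor layout and helper name; real code would go through block tables and issue device-to-device copies):

import torch

def seed_sliding_block(append_only_cache, sliding_block, num_cached_tokens, window_size):
    # on a prefix cache hit, copy the most recent window_size cached entries
    # (a D2D copy in practice) into the new request's sliding-contents block
    # so the kernel can continue mutating it in place
    n = min(num_cached_tokens, window_size)
    sliding_block[:n] = append_only_cache[num_cached_tokens - n:num_cached_tokens]
    return n

# example with toy sizes
window_size, head_dim = 8, 4
append_only_cache = torch.randn(32, head_dim)   # append-only KV managed by vLLM
sliding_block = torch.zeros(window_size, head_dim)
seed_sliding_block(append_only_cache, sliding_block, num_cached_tokens=20, window_size=window_size)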

However, this would require quite a bit of change in several places, and we have not yet decided how to handle prefix caching for the torch.compile path, so I plan to revisit this later.

Copilot AI left a comment

Pull request overview

This PR enables sliding window attention (SWA) models for the torch.compile path by implementing a custom KV cache manager that accommodates RBLN's unique sliding window kernel behavior. The solution leverages vLLM's hybrid KV cache manager framework to handle SWA models that slide cache contents in-place rather than using append-only semantics.

Key changes:

  • Introduced RBLNSlidingWindowSpec and RBLNSlidingWindowManager for custom KV cache handling
  • Added sliding window attention custom ops for prefill and decode phases
  • Extended attention metadata to support sliding window parameters

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 3 comments.

Summary per file

vllm_rbln/v1/kv_cache.py: New file implementing the RBLN-specific sliding window KV cache manager with single-block allocation
vllm_rbln/v1/worker/rbln_model_runner.py: Updated to use RBLNSlidingWindowSpec and to support multiple KV cache groups in warmup
vllm_rbln/v1/attention/backends/flash_attention.py: Added sliding window attention custom ops and integrated them into the attention forward path
vllm_rbln/v1/worker/optimum_model_runner.py: Added multimodal input disabling support
vllm_rbln/rbln_envs.py: Added the VLLM_RBLN_DISABLE_MM environment variable
vllm_rbln/platform.py: Added hybrid KV cache support and prefix caching configuration
examples/experimental/offline_inference_basic.py: Made the example script configurable via command-line arguments


Comment on lines +830 to 833
# unused?
self.need_mask = (self.alibi_slopes is not None
or self.sliding_window is not None)

Copilot AI commented Nov 25, 2025

The comment 'unused?' indicates uncertainty about whether self.need_mask is used. If it is genuinely unused, it should be removed to avoid confusion and reduce dead code. If it is used elsewhere, the comment should be removed or clarified.

Suggested change
# unused?
self.need_mask = (self.alibi_slopes is not None
or self.sliding_window is not None)

Comment on lines +85 to +88
single_type_kv_cache_manager.spec_manager_map.update({
RBLNSlidingWindowSpec:
RBLNSlidingWindowManager,
})

This comment was marked as off-topic.

Comment on lines +582 to 584
# TODO: construct only when not envs.RBLN_FLASH_CAUSAL_ATTN
attn_masks: Optional[torch.Tensor] = None
kv_caches: Optional[list[torch.Tensor]] = None
Copilot AI commented Nov 25, 2025

[nitpick] The TODO comment suggests conditional construction of attn_masks based on envs.RBLN_FLASH_CAUSAL_ATTN. Consider addressing this optimization or creating a tracking issue if this is deferred work.

self.uses_mrope = model_config.uses_mrope
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
    model_config)
if envs.VLLM_RBLN_DISABLE_MM:
Contributor

@rebel-jaehwang
I wonder why this new environment variable was introduced.

Contributor Author (@rebel-jaehwang)

Not necessary for the functionality of this PR, but useful for testing. Specifically, I wanted to test google/gemma-3-4b-it (a full/SWA hybrid with window_size=1k), but it is a multi-modal model, which isn't supported in the torch.compile path yet. Overriding supports_mm_inputs lets us test the underlying language model only.

So this env var doesn't need to be checked in the optimum runner, but I added it there anyway for consistency.

Contributor

The Gemma3 model in optimum-rbln, which uses sliding-window attention, works well with multimodal inputs, so I think this is not aligned with optimum-rbln.

Contributor Author (@rebel-jaehwang)

Removed in the optimum runner.

# NOTE - force dtype into fp16 for eager mode
model_config.dtype = torch.float16

cls.disable_unsupported_prefix_caching(vllm_config)
@rebel-eunji (Contributor) commented Nov 25, 2025

@rebel-jaehwang
The cls.disable_unsupported_prefix_caching setting is adopted from optimum and enables prefix caching for decoder-only generation models. Is this correct?

Contributor Author (@rebel-jaehwang)

My intention was to disable prefix caching for SWA models, which seems to work as I expected. But I just noticed this causes validation errors when running models not registered in the optimum path... I may need to reimplement something similar.

Contributor Author (@rebel-jaehwang)

fixed

@rebel-jiwoopark (Collaborator) left a comment

lgtm

@rebel-jaehwang merged commit a1d6fd5 into dev on Dec 2, 2025
10 checks passed
@rebel-jaehwang deleted the swa branch on December 2, 2025 at 07:48

Labels

torch.compile (torch.compile based implementation)
