fix: MLLM continuous batching for hybrid models #165
Thump604 wants to merge 20 commits into waybarrios:main from
Conversation
Multi-turn conversations with tool results can produce messages where system messages are out of order or consecutive same-role messages appear. Qwen 3.5's chat template rejects these with "System message must be at the beginning", crashing CLI agents on turn 2. Add _normalize_messages() to SimpleEngine.stream_chat() to:
1. Map developer -> system (OpenAI Responses API compat)
2. Merge consecutive same-role messages (alternating-role requirement)
This matches the normalization already done in BatchedEngine paths (PR waybarrios#165) but was missing from SimpleEngine's MTP text path.
Hi @waybarrios — friendly ping on this PR. CI is green and it's been open for a few days. I have a few more PRs in the queue (#169, #171, #173, #174, #175, #177, #180) that build on this work — happy to walk through any of them or adjust based on your feedback. Let me know if there's anything blocking review. Thanks!
Comprehensive tests for MLLM continuous batching with hybrid model caches (KVCache + ArraysCache). Covers merge, filter, extract, extend operations on mixed cache lists, plus message normalization for real-world client formats (OpenCode consecutive same-role messages). Tests written first — implementation follows.
Support ArraysCache (SSM layers), RotatingKVCache, and CacheList in the MLLM batch cache factory. Matches the pattern from mlx-lm's native BatchGenerator. Uses type(c) is KVCache (strict identity) to avoid catching QuantizedKVCache subclass.
Replace isinstance(sample_cache, KVCache) with a capability check (hasattr(sample_cache, "merge")). This is the actual crash point for hybrid models like Qwen 3.5 where layer 0 is ArraysCache (GatedDeltaNet). The merge loop is already polymorphic — each cache type's merge() returns the correct batched representation.
Replace hasattr(o, 'keys') check with empty() which is universal across all cache types via _BaseCache. ArraysCache uses cache[0] is None for empty(), BatchKVCache uses keys is None. This fixes silent skip of SSM layer caches during batch extension.
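The two commits above replace type checks with capability checks. A minimal sketch of that idea, using stand-in classes (FakeKVCache, FakeArraysCache are hypothetical stand-ins for the real mlx-lm cache types, which are not reproduced here):

```python
class FakeKVCache:
    """Stand-in attention-layer cache: empty when `keys` is None."""
    def __init__(self):
        self.keys = None
    def empty(self):
        return self.keys is None
    def merge(self, others):
        return "batched-kv"

class FakeArraysCache:
    """Stand-in SSM-layer cache: empty when slot 0 is None, no `keys`."""
    def __init__(self):
        self.cache = [None]
    def empty(self):
        return self.cache[0] is None
    def merge(self, others):
        return "batched-arrays"

def can_batch(sample_cache):
    # Capability check instead of isinstance(sample_cache, KVCache):
    # any cache that knows how to merge can participate in batching.
    return hasattr(sample_cache, "merge")

def is_empty(cache):
    # empty() is universal via the common base class, unlike
    # hasattr(cache, "keys"), which silently skips SSM-layer caches.
    return cache.empty()
```

With hasattr-based dispatch, an ArraysCache layer participates in batching and emptiness checks exactly like a KVCache layer, which is the behavior the two fixes above restore.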
Add _normalize_messages() preprocessing that merges consecutive messages with the same role. Prevents chat template failures when clients like OpenCode send system+system+user+user format that Qwen 3.5 and other templates reject. Only merges string content; multimodal list content is preserved as-is.
OpenAI Responses API and clients like Claude Code send messages with role "developer" instead of "system". Chat templates (Qwen 3.5, Llama, etc.) don't recognize this role, causing template failure → raw prefill fallback → potential crash during generation. Add a _ROLE_MAP dict to _normalize_messages() that maps non-standard roles before the merge logic runs. This ensures developer + system messages also merge correctly when consecutive. Closes waybarrios#137
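The normalization described in the two commits above can be sketched as follows; this is a hedged approximation, and the real _normalize_messages() may differ in details:

```python
_ROLE_MAP = {"developer": "system"}  # OpenAI Responses API compat

def normalize_messages(messages):
    merged = []
    for msg in messages:
        # Map non-standard roles before the merge logic runs, so a
        # developer message followed by a system message merges too.
        role = _ROLE_MAP.get(msg["role"], msg["role"])
        content = msg["content"]
        if (merged and merged[-1]["role"] == role
                and isinstance(content, str)
                and isinstance(merged[-1]["content"], str)):
            # Merge consecutive same-role string messages; multimodal
            # list content is preserved as-is (no merge attempted).
            merged[-1]["content"] += "\n\n" + content
        else:
            merged.append({"role": role, "content": content})
    return merged
```

A system+developer+user+user sequence collapses to a single system message followed by a single user message, satisfying the alternating-role requirement of templates like Qwen 3.5's.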
Fixes ruff F401 lint error.
_can_trim_cache() only checked the first cache layer's is_trimmable(). For hybrid models (Qwen 3.5 MoE, Nemotron Mamba+Attention), the prompt_cache mixes KVCache (trimmable) and ArraysCache (not trimmable). The first layer happened to be KVCache, so _can_trim_cache returned True. _trim_cache then trimmed KVCache layers but silently skipped ArraysCache layers, leaving KV and SSM/MoE state inconsistent. Fix: check ALL cache layers, not just the first. Also add is_trimmable guard in _trim_cache to skip non-trimmable caches explicitly. Relates to waybarrios#145
…s#142, waybarrios#136)

Add _is_kv_layer() to classify positional (KVCache) vs non-positional (ArraysCache) cache layers. _extract_block_tensor_slice() now skips non-KV layers instead of crashing with "Too many indices for array with 3 dimensions". NonKVCacheData dataclass added for storing non-positional state (used by subsequent commits for full hybrid reconstruction).
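A rough sketch of the classify-then-skip pattern above, with stand-in layer classes and plain lists in place of the real 4-D tensors (the classification heuristic shown here is an assumption, not the repository's exact test):

```python
class FakeKVLayer:
    """Stand-in positional KV layer; real keys are 4-D tensors,
    modeled here as a plain list along the sequence axis."""
    def __init__(self, keys):
        self.keys = keys

class FakeArraysLayer:
    """Stand-in non-positional SSM layer (no `keys` attribute)."""
    def __init__(self, state):
        self.state = state

def is_kv_layer(layer):
    # Positional KV layers expose key/value tensors that can be
    # block-sliced along the sequence axis; ArraysCache layers don't,
    # and slicing them crashed with "Too many indices".
    return hasattr(layer, "keys")

def extract_block_tensor_slice(layers, start, end):
    # Slice only KV layers; non-KV layers are stored whole elsewhere
    # (as NonKVCacheData in the commit above).
    return {i: layer.keys[start:end]
            for i, layer in enumerate(layers) if is_kv_layer(layer)}
```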
store_cache() separates KV (block-sliced) and non-KV (stored whole) layers. reconstruct_cache() rebuilds both: KV via block concatenation, non-KV via from_state(). If non-KV states are missing for a hybrid model, returns None to force safe recomputation.
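The store/reconstruct split above can be sketched as two small helpers (hypothetical names and simplified data shapes; the real store_cache()/reconstruct_cache() operate on block-sliced tensors and from_state() objects):

```python
def split_layers(prompt_cache, is_kv):
    # store_cache(): KV layers get block-sliced, non-KV layers are
    # stored whole, each keyed by layer index.
    kv, non_kv = {}, {}
    for i, layer in enumerate(prompt_cache):
        (kv if is_kv(layer) else non_kv)[i] = layer
    return kv, non_kv

def reconstruct_cache(kv_layers, non_kv_states, num_layers):
    # reconstruct_cache(): rebuild the per-layer list; if any layer is
    # missing (e.g. non-KV states absent for a hybrid model), return
    # None to force safe recomputation rather than half a cache.
    cache = [None] * num_layers
    for i, layer in kv_layers.items():
        cache[i] = layer
    for i, state in non_kv_states.items():
        cache[i] = state
    return None if any(c is None for c in cache) else cache
```

The None return is the safety valve: a hybrid model with KV blocks but missing SSM state recomputes from scratch instead of generating from an inconsistent cache.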
… robustness

fetch_cache() now rejects partial prefix matches when non-KV states are present but don't match the candidate block set. release_cache() and clear() clean up non-KV state. The scheduler fallback guards against non-4D tensors in cache state reconstruction.
- fork_cache() now copies has_non_kv from the source entry
- reconstruct_cache() uses a dict lookup instead of list.index() in the per-layer loop (O(n) → O(1) per iteration)
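The dict-lookup change above is the standard index-once pattern; a minimal sketch (build_block_index is a hypothetical helper name):

```python
def build_block_index(block_ids):
    # Build the position map once, up front. Each subsequent lookup is
    # O(1), versus list.index() rescanning the list (O(n)) on every
    # iteration of the per-layer loop.
    return {bid: pos for pos, bid in enumerate(block_ids)}
```

Usage: `pos = index[bid]` inside the loop replaces `pos = block_ids.index(bid)`.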
…rror handling, streaming detokenizer

Integrate four improvements from PR waybarrios#140 (janhilgard) into the hybrid model continuous batching implementation:
- patches/qwen3_5_mllm.py: Monkey-patch Qwen3_5Attention.__call__ to convert BatchKVCache mx.array offsets to int for mask slicing compatibility. Without this, batched generation crashes with "Slice indices must be integers or None".
- utils/tokenizer.py: Detect VLM models (vision_config + text_config) upfront and load with strict=False directly, avoiding a double load of ~100GB of weights that can cause OOM on memory-constrained systems. Also handles extra parameters (MTP/vision) with traceback cleanup.
- mllm_batch_generator.py: Per-request error handling in _process_prompts(), so failed preprocessing returns finish_reason="error" instead of crashing the entire batch. The oversized-prompt check is downgraded from a hard error to a warning.
- mllm_scheduler.py: NaiveStreamingDetokenizer for UTF-8-safe incremental decode, replacing raw tokenizer.decode([token]), which can produce mojibake on multi-byte characters.

Co-authored-by: janhilgard <[email protected]>
…nizer for Unicode (waybarrios#130)

Two fixes in scheduler.py:
1. _chunked_next tuple unpack (issue waybarrios#178): mlx-lm 0.31.x added prompt_checkpoints as a 7th tuple element. _chunked_next only unpacked 6, crashing when the prefix cache triggers chunked prefill. Same class of bug as PR waybarrios#169, but in a different code path.
2. NaiveStreamingDetokenizer for BatchedEngine (issue waybarrios#130): raw tokenizer.decode([token]) splits multi-byte codepoints (emoji, CJK) into surrogate pairs in streaming output. Replace it with a NaiveStreamingDetokenizer that buffers incomplete UTF-8 byte sequences and only emits valid segments. Matches the fix applied to mllm_scheduler.py in commit d2ea97c.
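The buffering behavior the streaming-detokenizer fix relies on can be demonstrated with the standard library's incremental UTF-8 decoder (this is a minimal stand-in, not the project's NaiveStreamingDetokenizer):

```python
import codecs

class Utf8StreamBuffer:
    """Buffer incomplete UTF-8 sequences and emit only complete
    codepoints, instead of decoding token-by-token and producing
    mojibake / replacement characters mid-codepoint."""

    def __init__(self):
        self._decoder = codecs.getincrementaldecoder("utf-8")()

    def feed(self, data: bytes) -> str:
        # Partial multi-byte sequences stay buffered inside the decoder
        # until the remaining bytes arrive; only valid text is emitted.
        return self._decoder.decode(data)
```

Feeding the four bytes of an emoji in two halves yields an empty string first, then the complete character, which is exactly the property raw per-token decode lacks.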
… in async_eval

Two memory fixes in the MTP decode path:
1. Skip recurrent state snapshots in optimistic mode. Every decode step was copying all 36 GatedDeltaNet SSM states (4 MB each = 147 MB per step) for rollback on MTP rejection, but optimistic mode never rejects. The lazy .copy() graph nodes held references to pre-verify Metal buffers, preventing them from being freed. With 2-3 steps of GPU pipeline depth, this created 300-450 MB of unnecessary memory pressure per step, contributing to OOM on long generations.
2. Include batch.tokens in mx.async_eval to collapse the token-history concatenation chain. Without this, the lazy concat graph grows unboundedly between periodic eval calls.
Root cause of Metal OOM during the RULER 32K benchmark (active=108.7GB, peak=119.4GB on a 128GB M2 Ultra). Memory grew ~2 GB per 256 decode steps due to accumulating unreleased Metal buffers.
1. Increase _clear_cache_interval from 32 to 256. Each mx.clear_cache() triggers Metal buffer reclamation, which stalls the GPU pipeline. For single-request generation, clearing every 32 steps causes 125 stalls per 4000-token generation with no benefit (there are no competing requests to reclaim memory for).
2. Pass output_token_ids by reference instead of copying the full list on every step in _process_batch_responses. At 4000 tokens, the old code copied 4000 ints per step: pure Python overhead.
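The pass-by-reference change in item 2 can be sketched as follows (hypothetical function names; safe only because, per the commit, consumers read the tokens rather than mutate them):

```python
def process_step_copy(output_token_ids):
    # Old behavior: copy the full token list on every decode step.
    # At 4000 generated tokens that is an O(n) copy per step.
    return list(output_token_ids)

def process_step_ref(output_token_ids):
    # New behavior: hand out the same list object, O(1) per step.
    # Callers share (and must not mutate) the underlying list.
    return output_token_ids
```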
1cc6d82 to 66bf4aa
Rebased against current main (post #180, #97, #127 merges). All CI green. This is the PR you endorsed on #140 — includes the integrated patches from janhilgard (BatchKVCache attention offset, VLM tokenizer loader, error handling, streaming detokenizer) plus the hybrid cache support for Qwen3.5 MoE continuous batching. Tested on M2 Ultra with Qwen3.5-122B-A10B (MoE, hybrid ArraysCache+KVCache): continuous batching with prefix cache, zero cache mismatches across hundreds of requests.
Integrates the admission controller, cooperative specprefill, and MLLM+MTP per-request routing into the BatchedEngine for production multi-user serving.

Key changes:
- BatchedEngine: admission gates on all 4 public methods (chat, stream_chat, generate, stream_generate) with try/finally cleanup
- MLLM+MTP routing: text-only requests → mlx_lm TextModel with MTP speculative decoding; media requests → mlx_vlm MLLM path
- System KV cache: prefix boundary detection + snapshot/restore for repeated system prompts (7x speedup on cache hits)
- Cooperative specprefill: draft scoring outside the generation lock, yielding between chunks for concurrent request progress
- Thread-safe snapshot access (threading.Lock for cross-thread reads/writes)
- Cache-hit re-verification under lock (prevents a stale flag after queuing)
- MLLM error loop: breaks after 10 consecutive errors (no infinite loop)
- CLI: --scheduler-policy, --scheduler-headroom-gb flags

Depends on: admission controller PR, cooperative specprefill PR, waybarrios#165, waybarrios#180

New files:
- specprefill.py: SpecPrefill scoring + sparse prefill (builds on merged waybarrios#180)
- text_model_from_vlm.py: zero-copy TextModel construction from a VLM backbone
Split from waybarrios#165 — prefix cache hybrid changes deferred to waybarrios#217.

Fixes:
- mllm_batch_generator: hybrid cache handling (ArraysCache + KVCache interleaved) for make_batch_cache, merge, filter, extract, extend
- mllm_scheduler: hybrid cache scheduling
- server.py: _normalize_messages (developer->system role mapping, consecutive same-role merge) applied to the MLLM and LLM paths
- tokenizer: VLM tokenizer loader with fallback
- qwen3_5_mllm: Qwen3.5 MLLM patch for hybrid batching

Fixes waybarrios#137 (developer role crash); fixes the OpenCode multi-system crash.
Splitting this PR. The non-prefix-cache fixes (hybrid batch generator, message normalization, scheduler, tokenizer) are now in #224. The prefix_cache.py changes here overlap with #217, which takes a cleaner approach (storage-type dispatch model with concat vs latest semantics). Deferring to that PR for prefix cache hybrid support. #224 can be reviewed and merged independently. Once #217 matures, the prefix cache portion of this PR becomes redundant.
MLLM continuous batching for hybrid VLM+MTP models.
What: Fix _make_batch_cache, MLLMBatch.extend(), and BlockAwarePrefixCache to handle hybrid models where some layers have KV cache and others don't. Also adds _normalize_messages(), which maps developer → system, merges consecutive same-role messages, and hoists system messages to position [0].

Why: Without this, BatchedEngine crashes on VLM+MTP models (Qwen3.5) because the cache structure doesn't match the hybrid layer layout. The normalization fixes are needed because Qwen 3.5 chat templates reject malformed message sequences.
Files: engine/batched.py, prefix_cache.py, engine/simple.py

Test: