diff --git a/docs/guides/reasoning.md b/docs/guides/reasoning.md new file mode 100644 index 00000000..45d0ed20 --- /dev/null +++ b/docs/guides/reasoning.md @@ -0,0 +1,267 @@ +# Reasoning Models + +vllm-mlx supports reasoning models that show their thinking process before giving an answer. Models like Qwen3 and DeepSeek-R1 wrap their reasoning in `...` tags, and vllm-mlx can parse these tags to separate the reasoning from the final response. + +## Why Use Reasoning Parsing? + +When a reasoning model generates output, it typically looks like this: + +``` + +Let me analyze this step by step. +First, I need to consider the constraints. +The answer should be a prime number less than 10. +Checking: 2, 3, 5, 7 are all prime and less than 10. + +The prime numbers less than 10 are: 2, 3, 5, 7. +``` + +Without reasoning parsing, you get the raw output with the tags included. With reasoning parsing enabled, the thinking process and final answer are separated into distinct fields in the API response. + +## Getting Started + +### Start the Server with Reasoning Parser + +```bash +# For Qwen3 models +vllm-mlx serve mlx-community/Qwen3-8B-4bit --reasoning-parser qwen3 + +# For DeepSeek-R1 models +vllm-mlx serve mlx-community/DeepSeek-R1-Distill-Qwen-7B-4bit --reasoning-parser deepseek_r1 +``` + +### API Response Format + +When reasoning parsing is enabled, the API response includes a `reasoning` field: + +**Non-streaming response:** + +```json +{ + "choices": [{ + "message": { + "role": "assistant", + "content": "The prime numbers less than 10 are: 2, 3, 5, 7.", + "reasoning": "Let me analyze this step by step.\nFirst, I need to consider the constraints.\nThe answer should be a prime number less than 10.\nChecking: 2, 3, 5, 7 are all prime and less than 10." + } + }] +} +``` + +**Streaming response:** + +Chunks are sent separately for reasoning and content. During the reasoning phase, chunks have `reasoning` populated. 
When the model transitions to the final answer, chunks have `content` populated: + +```json +{"delta": {"reasoning": "Let me analyze"}} +{"delta": {"reasoning": " this step by step."}} +{"delta": {"reasoning": "\nFirst, I need to"}} +... +{"delta": {"content": "The prime"}} +{"delta": {"content": " numbers less than 10"}} +{"delta": {"content": " are: 2, 3, 5, 7."}} +``` + +## Using with OpenAI SDK + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed") + +# Non-streaming +response = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "What are the prime numbers less than 10?"}] +) + +message = response.choices[0].message +print("Reasoning:", message.reasoning) # The thinking process +print("Answer:", message.content) # The final answer +``` + +### Streaming with Reasoning + +```python +reasoning_text = "" +content_text = "" + +stream = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "Solve: 2 + 2 = ?"}], + stream=True +) + +for chunk in stream: + delta = chunk.choices[0].delta + if hasattr(delta, 'reasoning') and delta.reasoning: + reasoning_text += delta.reasoning + print(f"[Thinking] {delta.reasoning}", end="") + if delta.content: + content_text += delta.content + print(delta.content, end="") + +print(f"\n\nFinal reasoning: {reasoning_text}") +print(f"Final answer: {content_text}") +``` + +## Supported Parsers + +### Qwen3 Parser (`qwen3`) + +For Qwen3 models that use explicit `` and `` tags. + +- Requires **both** opening and closing tags +- If tags are missing, output is treated as regular content +- Best for: Qwen3-0.6B, Qwen3-4B, Qwen3-8B and similar models + +```bash +vllm-mlx serve mlx-community/Qwen3-8B-4bit --reasoning-parser qwen3 +``` + +### DeepSeek-R1 Parser (`deepseek_r1`) + +For DeepSeek-R1 models that may omit the opening `` tag. 
+ +- More lenient than Qwen3 parser +- Handles cases where `` is implicit +- Content before `` is treated as reasoning even without `` + +```bash +vllm-mlx serve mlx-community/DeepSeek-R1-Distill-Qwen-7B-4bit --reasoning-parser deepseek_r1 +``` + +## How It Works + +The reasoning parser uses text-based detection to identify thinking tags in the model output. During streaming, it tracks the current position in the output to correctly route each token to either `reasoning` or `content`. + +``` +Model Output: Step 1: analyze...The answer is 42. + ├─────────────────────┤├─────────────────────┤ +Parsed: │ reasoning ││ content │ + └─────────────────────┘└─────────────────────┘ +``` + +The parsing is stateless and uses the accumulated text to determine context, making it robust for streaming scenarios where tokens may arrive in arbitrary chunks. + +## Tips for Best Results + +### Prompting + +Reasoning models work best when you encourage step-by-step thinking: + +```python +messages = [ + {"role": "system", "content": "Think through problems step by step before answering."}, + {"role": "user", "content": "What is 17 × 23?"} +] +``` + +### Handling Missing Reasoning + +Some prompts may not trigger reasoning. 
In these cases, `reasoning` will be `None` and all output goes to `content`: + +```python +message = response.choices[0].message +if message.reasoning: + print(f"Model's thought process: {message.reasoning}") +print(f"Answer: {message.content}") +``` + +### Temperature and Reasoning + +Lower temperatures tend to produce more consistent reasoning patterns: + +```python +response = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "Explain quantum entanglement"}], + temperature=0.3 # More focused reasoning +) +``` + +## Backward Compatibility + +When `--reasoning-parser` is not specified, the server behaves as before: +- Thinking tags are included in the `content` field +- No `reasoning` field is added to responses + +This ensures existing applications continue to work without changes. + +## Example: Math Problem Solver + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed") + +def solve_math(problem: str) -> dict: + """Solve a math problem and return reasoning + answer.""" + response = client.chat.completions.create( + model="default", + messages=[ + {"role": "system", "content": "You are a math tutor. 
Show your work."}, + {"role": "user", "content": problem} + ], + temperature=0.2 + ) + + message = response.choices[0].message + return { + "problem": problem, + "work": message.reasoning, + "answer": message.content + } + +result = solve_math("If a train travels 120 km in 2 hours, what is its average speed?") +print(f"Problem: {result['problem']}") +print(f"\nWork shown:\n{result['work']}") +print(f"\nFinal answer: {result['answer']}") +``` + +## Curl Examples + +### Non-streaming + +```bash +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "default", + "messages": [{"role": "user", "content": "What is 15% of 80?"}] + }' +``` + +### Streaming + +```bash +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "default", + "messages": [{"role": "user", "content": "What is 15% of 80?"}], + "stream": true + }' +``` + +## Troubleshooting + +### No reasoning field in response + +- Make sure you started the server with `--reasoning-parser` +- Check that the model actually uses thinking tags (not all prompts trigger reasoning) + +### Reasoning appears in content + +- The model may not be using the expected tag format +- Try a different parser (`qwen3` vs `deepseek_r1`) + +### Truncated reasoning + +- Increase `--max-tokens` if the model is hitting the token limit mid-thought + +## Related + +- [Supported Models](../reference/models.md) - Models that support reasoning +- [Server Configuration](server.md) - All server options +- [CLI Reference](../reference/cli.md) - Command line options diff --git a/pyproject.toml b/pyproject.toml index 4d480d42..ce33c3b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -107,7 +107,6 @@ mlx = "vllm_mlx.plugin:mlx_platform_plugin" [project.scripts] vllm-mlx = "vllm_mlx.cli:main" -vllm-mlx-serve = "vllm_mlx.server_v2:main" vllm-mlx-chat = "vllm_mlx.gradio_app:main" vllm-mlx-bench = "vllm_mlx.benchmark:main" @@ -121,6 +120,8 @@ 
target-version = ["py310", "py311", "py312", "py313"] [tool.ruff] line-length = 88 + +[tool.ruff.lint] select = ["E", "F", "W", "I", "N", "UP", "B", "SIM"] ignore = ["E501", "B905"] diff --git a/tests/test_reasoning_parser.py b/tests/test_reasoning_parser.py new file mode 100644 index 00000000..ad539985 --- /dev/null +++ b/tests/test_reasoning_parser.py @@ -0,0 +1,681 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Tests for reasoning content extraction parsers. + +Tests cover: +- Parser registry (registration, lookup, listing) +- Qwen3 parser (non-streaming and streaming) +- DeepSeek-R1 parser (non-streaming and streaming) +- Edge cases (no tags, partial tags, etc.) +""" + +import pytest + +from vllm_mlx.reasoning import ( + DeltaMessage, + ReasoningParser, + get_parser, + list_parsers, + register_parser, +) + + +class TestParserRegistry: + """Tests for the parser registry functions.""" + + def test_list_parsers_includes_builtin(self): + """Built-in parsers should be registered.""" + parsers = list_parsers() + assert "qwen3" in parsers + assert "deepseek_r1" in parsers + + def test_get_parser_qwen3(self): + """Should be able to get Qwen3 parser.""" + parser_cls = get_parser("qwen3") + parser = parser_cls() + assert isinstance(parser, ReasoningParser) + + def test_get_parser_deepseek(self): + """Should be able to get DeepSeek-R1 parser.""" + parser_cls = get_parser("deepseek_r1") + parser = parser_cls() + assert isinstance(parser, ReasoningParser) + + def test_get_unknown_parser_raises(self): + """Unknown parser name should raise KeyError.""" + with pytest.raises(KeyError) as exc_info: + get_parser("unknown_parser") + assert "unknown_parser" in str(exc_info.value) + assert "Available parsers" in str(exc_info.value) + + def test_register_custom_parser(self): + """Should be able to register custom parsers.""" + + class CustomParser(ReasoningParser): + def extract_reasoning(self, model_output): + return None, model_output + + def extract_reasoning_streaming(self, 
prev, curr, delta): + return DeltaMessage(content=delta) + + register_parser("custom_test", CustomParser) + assert "custom_test" in list_parsers() + + parser = get_parser("custom_test")() + assert isinstance(parser, CustomParser) + + +class TestQwen3Parser: + """Tests for the Qwen3 reasoning parser.""" + + @pytest.fixture + def parser(self): + """Create a fresh Qwen3 parser for each test.""" + return get_parser("qwen3")() + + # Non-streaming tests + + def test_extract_with_both_tags(self, parser): + """Should extract reasoning when both tags present.""" + output = "Let me analyze this problemThe answer is 42." + reasoning, content = parser.extract_reasoning(output) + assert reasoning == "Let me analyze this problem" + assert content == "The answer is 42." + + def test_extract_only_reasoning(self, parser): + """Should handle case where only reasoning is present.""" + output = "Just thinking out loud" + reasoning, content = parser.extract_reasoning(output) + assert reasoning == "Just thinking out loud" + assert content is None + + def test_extract_multiline_reasoning(self, parser): + """Should preserve newlines in reasoning content.""" + output = ( + "Step 1: Analyze\nStep 2: Solve\nStep 3: VerifyResult: 42" + ) + reasoning, content = parser.extract_reasoning(output) + assert "Step 1" in reasoning + assert "Step 2" in reasoning + assert "Step 3" in reasoning + assert content == "Result: 42" + + def test_no_tags_returns_content_only(self, parser): + """Qwen3 requires both tags - no tags means pure content.""" + output = "Just a regular response without thinking." 
+ reasoning, content = parser.extract_reasoning(output) + assert reasoning is None + assert content == output + + def test_only_start_tag_no_reasoning(self, parser): + """Qwen3 requires both tags - missing end tag means no reasoning.""" + output = "Started thinking but never finished" + reasoning, content = parser.extract_reasoning(output) + assert reasoning is None + assert content == output + + def test_only_end_tag_implicit_mode(self, parser): + """Qwen3 supports implicit mode - when is in prompt, only in output.""" + output = "Some textmore text" + reasoning, content = parser.extract_reasoning(output) + # Implicit mode: everything before is reasoning + assert reasoning == "Some text" + assert content == "more text" + + # Streaming tests + + def test_streaming_simple_flow(self, parser): + """Test basic streaming with reasoning then content.""" + parser.reset_state() + + # Simulate streaming tokens + deltas = ["", "think", "ing", "", "answer"] + accumulated = "" + results = [] + + for delta in deltas: + prev = accumulated + accumulated += delta + result = parser.extract_reasoning_streaming(prev, accumulated, delta) + if result: + results.append(result) + + # Collect reasoning and content + reasoning_parts = [r.reasoning for r in results if r.reasoning] + content_parts = [r.content for r in results if r.content] + + assert "".join(reasoning_parts) == "thinking" + assert "".join(content_parts) == "answer" + + def test_streaming_skip_tags(self, parser): + """Special tokens themselves should be skipped.""" + parser.reset_state() + + # Just the start tag + result = parser.extract_reasoning_streaming("", "", "") + assert result is None + + # Just the end tag + result = parser.extract_reasoning_streaming( + "reasoning", "reasoning", "" + ) + assert result is None + + def test_streaming_transition_chunk(self, parser): + """Chunk containing end tag should split reasoning and content.""" + parser.reset_state() + + # Previous has start, delta contains end and content + prev 
= "reasoning" + delta = " morecontent here" + curr = prev + delta + + result = parser.extract_reasoning_streaming(prev, curr, delta) + + assert result is not None + assert result.reasoning == " more" + assert result.content == "content here" + + +class TestDeepSeekR1Parser: + """Tests for the DeepSeek-R1 reasoning parser.""" + + @pytest.fixture + def parser(self): + """Create a fresh DeepSeek-R1 parser for each test.""" + return get_parser("deepseek_r1")() + + # Non-streaming tests + + def test_extract_with_both_tags(self, parser): + """Should extract reasoning when both tags present.""" + output = "Step by step analysisFinal answer: 42" + reasoning, content = parser.extract_reasoning(output) + assert reasoning == "Step by step analysis" + assert content == "Final answer: 42" + + def test_extract_implicit_start_tag(self, parser): + """DeepSeek-R1 handles implicit start tag (missing ).""" + output = "Implicit reasoning contentThe answer" + reasoning, content = parser.extract_reasoning(output) + assert reasoning == "Implicit reasoning content" + assert content == "The answer" + + def test_extract_no_tags_pure_content(self, parser): + """No tags should return pure content.""" + output = "Just a regular response." 
+ reasoning, content = parser.extract_reasoning(output) + assert reasoning is None + assert content == output + + def test_extract_multiline_reasoning(self, parser): + """Should preserve newlines in reasoning content.""" + output = "Line 1\nLine 2\nLine 3Result" + reasoning, content = parser.extract_reasoning(output) + assert "Line 1" in reasoning + assert "Line 2" in reasoning + assert "Line 3" in reasoning + assert content == "Result" + + # Streaming tests + + def test_streaming_simple_flow(self, parser): + """Test basic streaming with reasoning then content.""" + parser.reset_state() + + deltas = ["", "think", "ing", "", "answer"] + accumulated = "" + results = [] + + for delta in deltas: + prev = accumulated + accumulated += delta + result = parser.extract_reasoning_streaming(prev, accumulated, delta) + if result: + results.append(result) + + reasoning_parts = [r.reasoning for r in results if r.reasoning] + content_parts = [r.content for r in results if r.content] + + assert "".join(reasoning_parts) == "thinking" + assert "".join(content_parts) == "answer" + + +class TestDeltaMessage: + """Tests for the DeltaMessage dataclass.""" + + def test_reasoning_content_alias(self): + """reasoning_content should alias reasoning.""" + msg = DeltaMessage(reasoning="test reasoning") + assert msg.reasoning == "test reasoning" + assert msg.reasoning_content == "test reasoning" + + def test_content_only(self): + """Should handle content-only messages.""" + msg = DeltaMessage(content="just content") + assert msg.content == "just content" + assert msg.reasoning is None + assert msg.reasoning_content is None + + def test_both_fields(self): + """Should handle transition messages with both.""" + msg = DeltaMessage(reasoning="ending", content="starting") + assert msg.reasoning == "ending" + assert msg.content == "starting" + + +class TestEdgeCases: + """Test edge cases across parsers.""" + + @pytest.fixture(params=["qwen3", "deepseek_r1"]) + def parser(self, request): + 
"""Parametrized fixture for both parsers.""" + return get_parser(request.param)() + + def test_empty_output(self, parser): + """Empty output should return (None, '').""" + reasoning, content = parser.extract_reasoning("") + # Either both None or content is empty string + assert reasoning is None or reasoning == "" + + def test_whitespace_only_reasoning(self, parser): + """Whitespace-only reasoning should be treated as None.""" + output = " content" + reasoning, content = parser.extract_reasoning(output) + # Whitespace-only should be stripped to None + if reasoning is not None: + assert reasoning.strip() == "" or reasoning is None + + def test_nested_tags_not_supported(self, parser): + """Nested tags are not officially supported - behavior may vary.""" + output = "outerinnerstill outercontent" + # Just ensure it doesn't crash + reasoning, content = parser.extract_reasoning(output) + # Result may vary by parser implementation + + def test_streaming_reset_state(self, parser): + """reset_state should allow reuse of parser.""" + # First stream + parser.reset_state() + parser.extract_reasoning_streaming("", "", "") + + # Reset for new stream + parser.reset_state() + + # Should work fresh + result = parser.extract_reasoning_streaming("", "content", "content") + assert result is not None + + +class TestRealisticStreaming: + """Tests for realistic streaming scenarios simulating actual model output.""" + + @pytest.fixture(params=["qwen3", "deepseek_r1"]) + def parser(self, request): + """Parametrized fixture for both parsers.""" + return get_parser(request.param)() + + def test_token_by_token_streaming(self, parser): + """Simulate realistic token-by-token streaming.""" + # Typical model output broken into tokens + tokens = [ + "<", + "think", + ">", # Start tag split across tokens + "Let", + " me", + " analyze", + " this", + ".", + "\n", + "Step", + " 1", + ":", + " check", + " input", + "\n", + "Step", + " 2", + ":", + " compute", + "", # End tag split across tokens + 
"The", + " answer", + " is", + " 42", + ".", + ] + + parser.reset_state() + accumulated = "" + reasoning_parts = [] + content_parts = [] + + for token in tokens: + prev = accumulated + accumulated += token + result = parser.extract_reasoning_streaming(prev, accumulated, token) + if result: + if result.reasoning: + reasoning_parts.append(result.reasoning) + if result.content: + content_parts.append(result.content) + + full_reasoning = "".join(reasoning_parts) + full_content = "".join(content_parts) + + # Verify reasoning was captured + assert "Let me analyze" in full_reasoning + assert "Step 1" in full_reasoning + assert "Step 2" in full_reasoning + + # Verify content was captured + assert "The answer is 42" in full_content + + def test_long_reasoning_streaming(self, parser): + """Test streaming with extended reasoning.""" + # Long reasoning content + reasoning_text = """ + First, I need to understand the problem. + The user is asking about quantum computing. + + Let me break this down: + 1. Quantum bits (qubits) can be in superposition + 2. Entanglement allows correlated states + 3. Quantum gates perform operations + + After careful analysis, I can provide an answer. + """ + + output = f"{reasoning_text}Quantum computing uses qubits." 
+ + # Simulate character-by-character streaming + parser.reset_state() + accumulated = "" + reasoning_parts = [] + content_parts = [] + + for char in output: + prev = accumulated + accumulated += char + result = parser.extract_reasoning_streaming(prev, accumulated, char) + if result: + if result.reasoning: + reasoning_parts.append(result.reasoning) + if result.content: + content_parts.append(result.content) + + full_reasoning = "".join(reasoning_parts) + full_content = "".join(content_parts) + + assert "quantum computing" in full_reasoning.lower() + assert "qubits" in full_reasoning.lower() + assert "Quantum computing uses qubits" in full_content + + def test_streaming_no_content_after_reasoning(self, parser): + """Test streaming when there's only reasoning, no content.""" + tokens = ["", "just", " thinking", ""] + + parser.reset_state() + accumulated = "" + reasoning_parts = [] + content_parts = [] + + for token in tokens: + prev = accumulated + accumulated += token + result = parser.extract_reasoning_streaming(prev, accumulated, token) + if result: + if result.reasoning: + reasoning_parts.append(result.reasoning) + if result.content: + content_parts.append(result.content) + + assert "just thinking" in "".join(reasoning_parts) + assert len(content_parts) == 0 or "".join(content_parts).strip() == "" + + +class TestUnicodeAndSpecialCharacters: + """Tests for Unicode and special characters in reasoning.""" + + @pytest.fixture(params=["qwen3", "deepseek_r1"]) + def parser(self, request): + """Parametrized fixture for both parsers.""" + return get_parser(request.param)() + + def test_unicode_reasoning(self, parser): + """Test reasoning with Unicode characters.""" + output = "分析这个问题:日本語テスト émojis: 🤔💭答案是42" + reasoning, content = parser.extract_reasoning(output) + assert "分析" in reasoning + assert "日本語" in reasoning + assert "🤔" in reasoning + assert "42" in content + + def test_code_in_reasoning(self, parser): + """Test reasoning containing code snippets.""" + output = 
""" +Let me analyze the code: +```python +def factorial(n): + if n <= 1: + return 1 + return n * factorial(n-1) +``` +This is a recursive implementation. +The factorial function uses recursion.""" + + reasoning, content = parser.extract_reasoning(output) + assert "def factorial" in reasoning + assert "recursive" in reasoning + assert "uses recursion" in content + + def test_html_like_content(self, parser): + """Test that HTML-like content doesn't confuse the parser.""" + output = "The user mentioned
and tagsUse CSS for styling." + reasoning, content = parser.extract_reasoning(output) + assert "
" in reasoning + assert "" in reasoning + assert "CSS" in content + + def test_math_expressions(self, parser): + """Test reasoning with mathematical expressions.""" + output = "Given: x² + 2x + 1 = 0, so (x+1)² = 0, x = -1x = -1" + reasoning, content = parser.extract_reasoning(output) + assert "x²" in reasoning + assert "(x+1)²" in reasoning + assert "-1" in content + + +class TestAPIModelsIntegration: + """Tests for integration with API models.""" + + def test_assistant_message_with_reasoning(self): + """Test that AssistantMessage can hold reasoning content.""" + from vllm_mlx.api.models import AssistantMessage + + msg = AssistantMessage( + content="The answer is 42.", reasoning="Let me think step by step..." + ) + assert msg.content == "The answer is 42." + assert msg.reasoning == "Let me think step by step..." + assert msg.role == "assistant" + + def test_assistant_message_reasoning_none(self): + """Test AssistantMessage with no reasoning.""" + from vllm_mlx.api.models import AssistantMessage + + msg = AssistantMessage(content="Simple response without reasoning.") + assert msg.content == "Simple response without reasoning." + assert msg.reasoning is None + + def test_chat_completion_chunk_delta_with_reasoning(self): + """Test that ChatCompletionChunkDelta can hold reasoning.""" + from vllm_mlx.api.models import ChatCompletionChunkDelta + + delta = ChatCompletionChunkDelta(reasoning="thinking...") + assert delta.reasoning == "thinking..." 
+ assert delta.content is None + + delta2 = ChatCompletionChunkDelta(content="response text") + assert delta2.content == "response text" + assert delta2.reasoning is None + + def test_delta_transition(self): + """Test delta during transition from reasoning to content.""" + from vllm_mlx.api.models import ChatCompletionChunkDelta + + # During transition, both might have values + delta = ChatCompletionChunkDelta( + reasoning="final thought", content="starting answer" + ) + assert delta.reasoning == "final thought" + assert delta.content == "starting answer" + + +class TestParserPerformance: + """Basic performance tests for parsers.""" + + @pytest.fixture(params=["qwen3", "deepseek_r1"]) + def parser(self, request): + """Parametrized fixture for both parsers.""" + return get_parser(request.param)() + + def test_large_output_extraction(self, parser): + """Test extraction from large output.""" + # Generate large reasoning content + reasoning_lines = [f"Step {i}: processing data chunk {i}" for i in range(100)] + reasoning_text = "\n".join(reasoning_lines) + output = f"{reasoning_text}Processing complete." + + reasoning, content = parser.extract_reasoning(output) + + assert reasoning is not None + assert "Step 0" in reasoning + assert "Step 99" in reasoning + assert content == "Processing complete." 
+ + def test_streaming_many_chunks(self, parser): + """Test streaming with many small chunks.""" + parser.reset_state() + + # Generate many small chunks + base_output = "A" * 100 + "" + "B" * 50 + accumulated = "" + chunk_count = 0 + + for char in base_output: + prev = accumulated + accumulated += char + result = parser.extract_reasoning_streaming(prev, accumulated, char) + if result: + chunk_count += 1 + + # Should have processed all characters + assert chunk_count > 0 + + def test_repeated_parsing(self, parser): + """Test parsing same output multiple times.""" + output = "Quick thoughtQuick answer" + + for _ in range(100): + reasoning, content = parser.extract_reasoning(output) + assert reasoning == "Quick thought" + assert content == "Quick answer" + + +class TestDeepSeekSpecificCases: + """Tests specific to DeepSeek-R1 parser behavior.""" + + @pytest.fixture + def parser(self): + """Create DeepSeek-R1 parser.""" + return get_parser("deepseek_r1")() + + def test_implicit_reasoning_streaming(self, parser): + """Test streaming when start tag is implicit (DeepSeek-R1 specific).""" + # DeepSeek-R1 sometimes omits but includes + tokens = ["reasoning", " text", " here", "", "answer"] + + parser.reset_state() + accumulated = "" + reasoning_parts = [] + content_parts = [] + + for token in tokens: + prev = accumulated + accumulated += token + result = parser.extract_reasoning_streaming(prev, accumulated, token) + if result: + if result.reasoning: + reasoning_parts.append(result.reasoning) + if result.content: + content_parts.append(result.content) + + # For DeepSeek-R1, content before without is treated as content + # until appears in the delta + all_parts = reasoning_parts + content_parts + assert len(all_parts) > 0 + + def test_deepseek_long_implicit_reasoning(self, parser): + """Test long implicit reasoning without start tag.""" + output = """Let me think about this problem carefully. + +First, I need to consider the constraints. +Then, I'll apply the algorithm. 
+Finally, I'll verify the result.The answer is 42.""" + + reasoning, content = parser.extract_reasoning(output) + assert reasoning is not None + assert "think about this problem" in reasoning + assert "42" in content + + +class TestQwen3SpecificCases: + """Tests specific to Qwen3 parser behavior.""" + + @pytest.fixture + def parser(self): + """Create Qwen3 parser.""" + return get_parser("qwen3")() + + def test_qwen3_implicit_mode_support(self, parser): + """Qwen3 supports implicit mode for OpenCode compatibility.""" + # Only end tag - implicit mode (think injected in prompt) + output1 = "some textmore text" + reasoning, content = parser.extract_reasoning(output1) + # Implicit mode: everything before is reasoning + assert reasoning == "some text" + assert content == "more text" + + # Only start tag - no means model is still generating + # Qwen3 requires to extract reasoning (treats as pure content until then) + output2 = "incomplete reasoning" + reasoning, content = parser.extract_reasoning(output2) + # No = no reasoning extraction, entire output is content + assert reasoning is None + assert content == output2 + + def test_qwen3_empty_think_tags(self, parser): + """Test empty think tags.""" + output = "Just the answer." + reasoning, content = parser.extract_reasoning(output) + # Empty reasoning should be None + assert reasoning is None or reasoning.strip() == "" + assert content == "Just the answer." 
+ + def test_qwen3_whitespace_between_tags(self, parser): + """Test various whitespace patterns.""" + test_cases = [ + (" answer", None, "answer"), + ("\n\nanswer", None, "answer"), + ("\t\tanswer", None, "answer"), + ] + + for output, expected_reasoning, expected_content in test_cases: + reasoning, content = parser.extract_reasoning(output) + if expected_reasoning is None: + assert reasoning is None or reasoning.strip() == "" + assert expected_content in (content or "") diff --git a/tests/test_streaming_latency.py b/tests/test_streaming_latency.py index 2219d1e1..cae95f5f 100644 --- a/tests/test_streaming_latency.py +++ b/tests/test_streaming_latency.py @@ -10,7 +10,7 @@ Usage: # Start server first: - python -m vllm_mlx.server_v2 --model mlx-community/Llama-3.2-3B-Instruct-4bit + vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit # Run test: python tests/test_streaming_latency.py diff --git a/tests/test_tool_parsers.py b/tests/test_tool_parsers.py index 789bf65a..2884e3a3 100644 --- a/tests/test_tool_parsers.py +++ b/tests/test_tool_parsers.py @@ -680,3 +680,63 @@ def test_auto_streaming(self): delta_text="Hello world", ) assert result == {"content": "Hello world"} + + +class TestThinkTagStripping: + """Test tag stripping in tool parsers (Issue #26).""" + + def test_strip_think_tags_utility(self): + """Test the strip_think_tags static method.""" + from vllm_mlx.tool_parsers.abstract_tool_parser import ToolParser + + # Basic stripping + text = "Let me analyze thisThe answer is 42" + assert ToolParser.strip_think_tags(text) == "The answer is 42" + + # Multi-line thinking + text = "Step 1\nStep 2\nStep 3Result" + assert ToolParser.strip_think_tags(text) == "Result" + + # No think tags + text = "Just regular text" + assert ToolParser.strip_think_tags(text) == "Just regular text" + + # Empty think tags + text = "Content" + assert ToolParser.strip_think_tags(text) == "Content" + + def test_hermes_with_think_tags(self): + """Test Hermes parser strips think tags 
before parsing tool calls.""" + parser = HermesToolParser() + + # Model output with think tags AND tool call (Ring-Mini-Linear-2.0 style) + output = """Let me search for that information. +{"name": "search", "arguments": {"query": "weather"}}""" + + result = parser.extract_tool_calls(output) + assert result.tools_called is True + assert len(result.tool_calls) == 1 + assert result.tool_calls[0]["name"] == "search" + + def test_qwen_with_think_tags(self): + """Test Qwen parser strips think tags before parsing tool calls.""" + parser = QwenToolParser() + + # Model output with think tags AND tool call + output = """I need to get the weather data. +[Calling tool: get_weather({"city": "Tokyo"})]""" + + result = parser.extract_tool_calls(output) + assert result.tools_called is True + assert len(result.tool_calls) == 1 + assert result.tool_calls[0]["name"] == "get_weather" + + def test_think_tags_with_no_tool_call(self): + """Test that think tags are stripped even when no tool call is present.""" + parser = HermesToolParser() + + output = "Let me think about thisThe answer is 42." + result = parser.extract_tool_calls(output) + + assert result.tools_called is False + assert result.content == "The answer is 42." 
diff --git a/vllm_mlx/api/models.py b/vllm_mlx/api/models.py index b8ce7342..34822a58 100644 --- a/vllm_mlx/api/models.py +++ b/vllm_mlx/api/models.py @@ -11,9 +11,8 @@ import time import uuid -from typing import List, Optional, Union -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, computed_field # ============================================================================= # Content Types (for multimodal messages) @@ -24,7 +23,7 @@ class ImageUrl(BaseModel): """Image URL with optional detail level.""" url: str - detail: Optional[str] = None + detail: str | None = None class VideoUrl(BaseModel): @@ -52,11 +51,11 @@ class ContentPart(BaseModel): """ type: str # "text", "image_url", "video", "video_url", "audio_url" - text: Optional[str] = None - image_url: Optional[Union[ImageUrl, dict, str]] = None - video: Optional[str] = None - video_url: Optional[Union[VideoUrl, dict, str]] = None - audio_url: Optional[Union[AudioUrl, dict, str]] = None + text: str | None = None + image_url: ImageUrl | dict | str | None = None + video: str | None = None + video_url: VideoUrl | dict | str | None = None + audio_url: AudioUrl | dict | str | None = None # ============================================================================= @@ -76,11 +75,11 @@ class Message(BaseModel): """ role: str - content: Optional[Union[str, List[ContentPart], List[dict]]] = None + content: str | list[ContentPart] | list[dict] | None = None # For assistant messages with tool calls - tool_calls: Optional[List[dict]] = None + tool_calls: list[dict] | None = None # For tool response messages (role="tool") - tool_call_id: Optional[str] = None + tool_call_id: str | None = None # ============================================================================= @@ -119,9 +118,9 @@ class ResponseFormatJsonSchema(BaseModel): """JSON Schema definition for structured output.""" name: str - description: Optional[str] = None + description: str | None = None schema_: dict = 
Field(alias="schema") # JSON Schema specification - strict: Optional[bool] = False + strict: bool | None = False class Config: populate_by_name = True @@ -138,7 +137,7 @@ class ResponseFormat(BaseModel): """ type: str = "text" # "text", "json_object", "json_schema" - json_schema: Optional[ResponseFormatJsonSchema] = None + json_schema: ResponseFormatJsonSchema | None = None # ============================================================================= @@ -156,33 +155,42 @@ class ChatCompletionRequest(BaseModel): """Request for chat completion.""" model: str - messages: List[Message] + messages: list[Message] temperature: float = 0.7 top_p: float = 0.9 - max_tokens: Optional[int] = None + max_tokens: int | None = None stream: bool = False - stream_options: Optional[StreamOptions] = ( + stream_options: StreamOptions | None = ( None # Streaming options (include_usage, etc.) ) - stop: Optional[List[str]] = None + stop: list[str] | None = None # Tool calling - tools: Optional[List[ToolDefinition]] = None - tool_choice: Optional[Union[str, dict]] = None # "auto", "none", or specific tool + tools: list[ToolDefinition] | None = None + tool_choice: str | dict | None = None # "auto", "none", or specific tool # Structured output - response_format: Optional[Union[ResponseFormat, dict]] = None + response_format: ResponseFormat | dict | None = None # MLLM-specific parameters - video_fps: Optional[float] = None - video_max_frames: Optional[int] = None + video_fps: float | None = None + video_max_frames: int | None = None # Request timeout in seconds (None = use server default) - timeout: Optional[float] = None + timeout: float | None = None class AssistantMessage(BaseModel): """Response message from the assistant.""" role: str = "assistant" - content: Optional[str] = None - tool_calls: Optional[List[ToolCall]] = None + content: str | None = None + reasoning: str | None = ( + None # Reasoning/thinking content (when --reasoning-parser is used) + ) + tool_calls: list[ToolCall] | 
None = None + + @computed_field + @property + def reasoning_content(self) -> str | None: + """Alias for reasoning field. Serialized for backwards compatibility with clients expecting reasoning_content.""" + return self.reasoning class ChatCompletionChoice(BaseModel): @@ -190,7 +198,7 @@ class ChatCompletionChoice(BaseModel): index: int = 0 message: AssistantMessage - finish_reason: Optional[str] = "stop" + finish_reason: str | None = "stop" class Usage(BaseModel): @@ -208,7 +216,7 @@ class ChatCompletionResponse(BaseModel): object: str = "chat.completion" created: int = Field(default_factory=lambda: int(time.time())) model: str - choices: List[ChatCompletionChoice] + choices: list[ChatCompletionChoice] usage: Usage = Field(default_factory=Usage) @@ -221,14 +229,14 @@ class CompletionRequest(BaseModel): """Request for text completion.""" model: str - prompt: Union[str, List[str]] + prompt: str | list[str] temperature: float = 0.7 top_p: float = 0.9 - max_tokens: Optional[int] = None + max_tokens: int | None = None stream: bool = False - stop: Optional[List[str]] = None + stop: list[str] | None = None # Request timeout in seconds (None = use server default) - timeout: Optional[float] = None + timeout: float | None = None class CompletionChoice(BaseModel): @@ -236,7 +244,7 @@ class CompletionChoice(BaseModel): index: int = 0 text: str - finish_reason: Optional[str] = "stop" + finish_reason: str | None = "stop" class CompletionResponse(BaseModel): @@ -246,7 +254,7 @@ class CompletionResponse(BaseModel): object: str = "text_completion" created: int = Field(default_factory=lambda: int(time.time())) model: str - choices: List[CompletionChoice] + choices: list[CompletionChoice] usage: Usage = Field(default_factory=Usage) @@ -268,7 +276,7 @@ class ModelsResponse(BaseModel): """Response for listing models.""" object: str = "list" - data: List[ModelInfo] + data: list[ModelInfo] # ============================================================================= @@ -288,7 +296,7 @@ 
class MCPToolInfo(BaseModel): class MCPToolsResponse(BaseModel): """Response for listing MCP tools.""" - tools: List[MCPToolInfo] + tools: list[MCPToolInfo] count: int @@ -299,13 +307,13 @@ class MCPServerInfo(BaseModel): state: str transport: str tools_count: int - error: Optional[str] = None + error: str | None = None class MCPServersResponse(BaseModel): """Response for listing MCP servers.""" - servers: List[MCPServerInfo] + servers: list[MCPServerInfo] class MCPExecuteRequest(BaseModel): @@ -319,9 +327,9 @@ class MCPExecuteResponse(BaseModel): """Response from executing an MCP tool.""" tool_name: str - content: Optional[Union[str, list, dict]] = None + content: str | list | dict | None = None is_error: bool = False - error_message: Optional[str] = None + error_message: str | None = None # ============================================================================= @@ -333,19 +341,19 @@ class AudioTranscriptionRequest(BaseModel): """Request for audio transcription (STT).""" model: str = "whisper-large-v3" - language: Optional[str] = None + language: str | None = None response_format: str = "json" temperature: float = 0.0 - timestamp_granularities: Optional[List[str]] = None + timestamp_granularities: list[str] | None = None class AudioTranscriptionResponse(BaseModel): """Response from audio transcription.""" text: str - language: Optional[str] = None - duration: Optional[float] = None - segments: Optional[List[dict]] = None + language: str | None = None + duration: float | None = None + segments: list[dict] | None = None class AudioSpeechRequest(BaseModel): @@ -362,7 +370,7 @@ class AudioSeparationRequest(BaseModel): """Request for audio source separation.""" model: str = "htdemucs" - stems: List[str] = Field(default_factory=lambda: ["vocals", "accompaniment"]) + stems: list[str] = Field(default_factory=lambda: ["vocals", "accompaniment"]) # ============================================================================= @@ -373,9 +381,18 @@ class 
AudioSeparationRequest(BaseModel): class ChatCompletionChunkDelta(BaseModel): """Delta content in a streaming chunk.""" - role: Optional[str] = None - content: Optional[str] = None - tool_calls: Optional[List[dict]] = None + role: str | None = None + content: str | None = None + reasoning: str | None = ( + None # Reasoning/thinking content (when --reasoning-parser is used) + ) + tool_calls: list[dict] | None = None + + @computed_field + @property + def reasoning_content(self) -> str | None: + """Alias for reasoning field. Serialized for backwards compatibility with clients expecting reasoning_content.""" + return self.reasoning class ChatCompletionChunkChoice(BaseModel): @@ -383,7 +400,7 @@ class ChatCompletionChunkChoice(BaseModel): index: int = 0 delta: ChatCompletionChunkDelta - finish_reason: Optional[str] = None + finish_reason: str | None = None class ChatCompletionChunk(BaseModel): @@ -393,5 +410,5 @@ class ChatCompletionChunk(BaseModel): object: str = "chat.completion.chunk" created: int = Field(default_factory=lambda: int(time.time())) model: str - choices: List[ChatCompletionChunkChoice] + choices: list[ChatCompletionChunkChoice] usage: Usage | None = None # Included when stream_options.include_usage=true diff --git a/vllm_mlx/cli.py b/vllm_mlx/cli.py index d5718247..571113a3 100644 --- a/vllm_mlx/cli.py +++ b/vllm_mlx/cli.py @@ -53,6 +53,29 @@ def serve_command(args): server._enable_auto_tool_choice = False server._tool_call_parser = None + # Configure reasoning parser + if args.reasoning_parser: + try: + from .reasoning import get_parser + + parser_cls = get_parser(args.reasoning_parser) + server._reasoning_parser = parser_cls() + logger.info(f"Reasoning parser enabled: {args.reasoning_parser}") + except KeyError as e: + print(f"Error: {e}") + sys.exit(1) + except ImportError as e: + print(f"Error: Failed to import reasoning module: {e}") + sys.exit(1) + except Exception as e: + print( + f"Error: Failed to initialize reasoning parser " + 
f"'{args.reasoning_parser}': {e}" + ) + sys.exit(1) + else: + server._reasoning_parser = None + # Security summary at startup print("=" * 60) print("SECURITY CONFIGURATION") @@ -70,6 +93,10 @@ def serve_command(args): print(f" Tool calling: ENABLED (parser: {args.tool_call_parser})") else: print(" Tool calling: Use --enable-auto-tool-choice to enable") + if args.reasoning_parser: + print(f" Reasoning: ENABLED (parser: {args.reasoning_parser})") + else: + print(" Reasoning: Use --reasoning-parser to enable") print("=" * 60) print(f"Loading model: {args.model}") @@ -519,6 +546,21 @@ def main(): "Required for --enable-auto-tool-choice." ), ) + # Reasoning parser options - choices loaded dynamically from registry + from .reasoning import list_parsers + + reasoning_choices = list_parsers() + serve_parser.add_argument( + "--reasoning-parser", + type=str, + default=None, + choices=reasoning_choices, + help=( + "Enable reasoning content extraction with specified parser. " + "Extracts ... tags into reasoning_content field. " + f"Options: {', '.join(reasoning_choices)}." + ), + ) # Bench command bench_parser = subparsers.add_parser("bench", help="Run benchmark") diff --git a/vllm_mlx/reasoning/__init__.py b/vllm_mlx/reasoning/__init__.py new file mode 100644 index 00000000..eef9c62f --- /dev/null +++ b/vllm_mlx/reasoning/__init__.py @@ -0,0 +1,98 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Reasoning parser module for vllm-mlx. + +This module provides parsers for extracting reasoning/thinking content from +model outputs. Supports models like Qwen3, DeepSeek-R1, etc. that use special +tokens (e.g., ...) to separate reasoning from final responses. 
+ +Usage: + from vllm_mlx.reasoning import get_parser, list_parsers + + # Get a parser by name + parser = get_parser("qwen3")() + + # Extract reasoning from complete output + reasoning, content = parser.extract_reasoning(model_output) + + # For streaming + parser.reset_state() + for delta in stream: + msg = parser.extract_reasoning_streaming(prev, curr, delta) + if msg: + # msg.reasoning and/or msg.content will be populated + ... +""" + +from .base import DeltaMessage, ReasoningParser +from .think_parser import BaseThinkingReasoningParser + +# Parser registry +_REASONING_PARSERS: dict[str, type[ReasoningParser]] = {} + + +def register_parser(name: str, parser_class: type[ReasoningParser]) -> None: + """ + Register a reasoning parser. + + Args: + name: Name to register the parser under (e.g., "qwen3"). + parser_class: The parser class to register. + """ + _REASONING_PARSERS[name] = parser_class + + +def get_parser(name: str) -> type[ReasoningParser]: + """ + Get a reasoning parser class by name. + + Args: + name: Name of the parser (e.g., "qwen3", "deepseek_r1"). + + Returns: + The parser class (not an instance). + + Raises: + KeyError: If parser name is not found. + """ + if name not in _REASONING_PARSERS: + available = list(_REASONING_PARSERS.keys()) + raise KeyError( + f"Reasoning parser '{name}' not found. Available parsers: {available}" + ) + return _REASONING_PARSERS[name] + + +def list_parsers() -> list[str]: + """ + List available parser names. + + Returns: + List of registered parser names. 
+ """ + return list(_REASONING_PARSERS.keys()) + + +def _register_builtin_parsers(): + """Register built-in parsers.""" + from .deepseek_r1_parser import DeepSeekR1ReasoningParser + from .qwen3_parser import Qwen3ReasoningParser + + register_parser("qwen3", Qwen3ReasoningParser) + register_parser("deepseek_r1", DeepSeekR1ReasoningParser) + + +# Register built-in parsers on module load +_register_builtin_parsers() + + +__all__ = [ + # Base classes + "ReasoningParser", + "DeltaMessage", + "BaseThinkingReasoningParser", + # Registry functions + "register_parser", + "get_parser", + "list_parsers", +] diff --git a/vllm_mlx/reasoning/base.py b/vllm_mlx/reasoning/base.py new file mode 100644 index 00000000..aaefef9c --- /dev/null +++ b/vllm_mlx/reasoning/base.py @@ -0,0 +1,110 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Base classes for reasoning content extraction. + +This module provides the abstract base class for reasoning parsers that extract +thinking/reasoning content from model outputs (e.g., ... tags). +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any + + +@dataclass +class DeltaMessage: + """ + Delta message for streaming reasoning output. + + Contains either reasoning content, regular content, or both when + transitioning from reasoning to content phase. + + Note: reasoning and content should typically not both be non-None + except during the transition chunk. + """ + + role: str | None = None + content: str | None = None + reasoning: str | None = None + + @property + def reasoning_content(self) -> str | None: + """Deprecated: use reasoning instead. Maintained for backward compatibility.""" + return self.reasoning + + +class ReasoningParser(ABC): + """ + Abstract base class for reasoning content extraction. + + Reasoning parsers extract thinking/reasoning content from model outputs, + separating it from the final response content. This is useful for models + like DeepSeek-R1, Qwen3, etc. 
that use special tokens to denote reasoning. + + Example: + Input: "Let me solve this step by step...The answer is 42." + Output: reasoning="Let me solve this step by step...", content="The answer is 42." + """ + + def __init__(self, tokenizer: Any | None = None): + """ + Initialize parser with optional tokenizer. + + Args: + tokenizer: Optional tokenizer for token-based parsing. For vllm-mlx, + text-based parsing is sufficient, so this is optional. + """ + self.tokenizer = tokenizer + + @abstractmethod + def extract_reasoning( + self, + model_output: str, + ) -> tuple[str | None, str | None]: + """ + Extract reasoning content from complete model output. + + Args: + model_output: Complete text output from the model. + + Returns: + Tuple of (reasoning_content, final_content). + Either may be None if not present. + """ + pass + + @abstractmethod + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + ) -> DeltaMessage | None: + """ + Extract reasoning from streaming delta. + + Uses the "previous + delta = current" model where: + - previous_text: All text accumulated before this delta + - current_text: All text including this delta (previous + delta) + - delta_text: Just the new text in this chunk + + Args: + previous_text: Accumulated text before this delta. + current_text: Accumulated text including this delta. + delta_text: The new text in this streaming chunk. + + Returns: + DeltaMessage with reasoning and/or content populated, + or None if this delta should be skipped (e.g., special tokens). + """ + pass + + def reset_state(self): # noqa: B027 + """ + Reset any internal state for a new request. + + Called before starting to process a new streaming request. + Override in subclasses if stateful parsing is needed. + This is intentionally a default no-op implementation. 
+ """ + pass diff --git a/vllm_mlx/reasoning/deepseek_r1_parser.py b/vllm_mlx/reasoning/deepseek_r1_parser.py new file mode 100644 index 00000000..b633781d --- /dev/null +++ b/vllm_mlx/reasoning/deepseek_r1_parser.py @@ -0,0 +1,113 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Reasoning parser for DeepSeek-R1 models. + +DeepSeek-R1 uses ... tags for reasoning content. +The model may sometimes start outputting reasoning without the explicit + tag, so this parser is more lenient than Qwen3. +""" + +from .base import DeltaMessage +from .think_parser import BaseThinkingReasoningParser + + +class DeepSeekR1ReasoningParser(BaseThinkingReasoningParser): + """ + Reasoning parser for DeepSeek-R1 model. + + DeepSeek-R1 uses ... tokens to denote reasoning text. + This parser is more lenient than Qwen3: + - The tag may not be explicitly generated (model assumes it) + - If only is found, everything before it is reasoning + + Example: + Input: "Step 1: analyze...\nStep 2: solve...The answer is 42." + Output: reasoning="Step 1: analyze...\nStep 2: solve...", content="The answer is 42." + + Input: "reasoning contentfinal answer" # No opening tag + Output: reasoning="reasoning content", content="final answer" + """ + + @property + def start_token(self) -> str: + return "" + + @property + def end_token(self) -> str: + return "" + + def extract_reasoning( + self, + model_output: str, + ) -> tuple[str | None, str | None]: + """ + Extract reasoning from DeepSeek-R1 output. + + More lenient than Qwen3 - handles cases where start tag is implicit. + + Args: + model_output: Complete model output text. + + Returns: + (reasoning, content) tuple. 
+ """ + # If we have end token but no start token, treat beginning as reasoning + if self.end_token in model_output and self.start_token not in model_output: + reasoning, _, content = model_output.partition(self.end_token) + reasoning = reasoning.strip() or None + content = content.strip() or None + return reasoning, content + + # If neither token, return as pure content + if self.end_token not in model_output and self.start_token not in model_output: + return None, model_output + + # Use base class for standard case + return super().extract_reasoning(model_output) + + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + ) -> DeltaMessage | None: + """ + Extract reasoning from streaming delta. + + Handles DeepSeek-R1's pattern where may be implicit. + + Args: + previous_text: Text accumulated before this delta. + current_text: Text including this delta. + delta_text: Just the new text. + + Returns: + DeltaMessage with reasoning/content, or None to skip. + """ + # First try base class logic + result = super().extract_reasoning_streaming( + previous_text, current_text, delta_text + ) + + # Handle DeepSeek-R1 special case: no start token seen but end token appears + if result is not None: + start_in_prev = self.start_token in previous_text + start_in_delta = self.start_token in delta_text + end_in_delta = self.end_token in delta_text + + # If end token in delta but we never saw start token + if not start_in_prev and not start_in_delta and end_in_delta: + # Everything before end token is reasoning + idx = delta_text.find(self.end_token) + reasoning_part = delta_text[:idx] + content_part = delta_text[idx + len(self.end_token) :] + return DeltaMessage( + reasoning=reasoning_part if reasoning_part else None, + content=content_part if content_part else None, + ) + + # Note: DeepSeek-R1 may omit but still be in reasoning mode. 
+ # However, we can't reliably detect implicit reasoning without context, + # so we default to treating unmarked content as regular content. + + return result diff --git a/vllm_mlx/reasoning/qwen3_parser.py b/vllm_mlx/reasoning/qwen3_parser.py new file mode 100644 index 00000000..cb729d1f --- /dev/null +++ b/vllm_mlx/reasoning/qwen3_parser.py @@ -0,0 +1,64 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Reasoning parser for Qwen3 models. + +Qwen3 uses ... tags for reasoning content and supports +a strict switch via 'enable_thinking=False' in chat template kwargs. + +Supports implicit reasoning mode where is injected in the prompt +by AI agents (e.g., OpenCode) and only appears in the output. +""" + +from .think_parser import BaseThinkingReasoningParser + + +class Qwen3ReasoningParser(BaseThinkingReasoningParser): + """ + Reasoning parser for Qwen3 models. + + Qwen3 uses ... tokens to denote reasoning text. + + Supports three scenarios: + 1. Both tags in output: reasoningcontent + 2. Only closing tag (think in prompt): reasoningcontent + 3. No tags: pure content + + Example (normal): + Input: "Let me analyze this...The answer is 42." + Output: reasoning="Let me analyze this...", content="The answer is 42." + + Example (think in prompt): + Input: "Let me analyze this...The answer is 42." + Output: reasoning="Let me analyze this...", content="The answer is 42." + """ + + @property + def start_token(self) -> str: + return "" + + @property + def end_token(self) -> str: + return "" + + def extract_reasoning( + self, + model_output: str, + ) -> tuple[str | None, str | None]: + """ + Extract reasoning from Qwen3 output. + + Handles both explicit ... tags and implicit mode + where was in the prompt (only in output). + + Args: + model_output: Complete model output text. + + Returns: + (reasoning, content) tuple. 
+ """ + # If no end token at all, treat as pure content + if self.end_token not in model_output: + return None, model_output + + # Use base class implementation (handles both explicit and implicit) + return super().extract_reasoning(model_output) diff --git a/vllm_mlx/reasoning/think_parser.py b/vllm_mlx/reasoning/think_parser.py new file mode 100644 index 00000000..13634820 --- /dev/null +++ b/vllm_mlx/reasoning/think_parser.py @@ -0,0 +1,215 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Base parser for models using ... tags for reasoning. + +This module provides BaseThinkingReasoningParser, a concrete implementation +for extracting reasoning content from models that use thinking tags. + +Supports three scenarios: +1. Both tags in output: reasoningcontent +2. Only closing tag (think injected in prompt): reasoningcontent +3. No tags: pure content +""" + +from abc import abstractmethod + +from .base import DeltaMessage, ReasoningParser + + +class BaseThinkingReasoningParser(ReasoningParser): + """ + Base parser for models using ... style tags. + + This parser handles the common pattern where reasoning content is wrapped + in special tags. Subclasses define the specific start and end tokens. + + Supports "implicit reasoning mode" where is injected in the prompt + and only appears in the model output. This is common with AI agents + like OpenCode that force models to reason by injecting thinking tags. + + The parser tracks state during streaming to correctly separate reasoning + from content as tokens arrive incrementally. 
+ """ + + @property + @abstractmethod + def start_token(self) -> str: + """The token/tag that starts reasoning content (e.g., '').""" + + @property + @abstractmethod + def end_token(self) -> str: + """The token/tag that ends reasoning content (e.g., '').""" + + def __init__(self, tokenizer=None): + super().__init__(tokenizer) + + def extract_reasoning( + self, + model_output: str, + ) -> tuple[str | None, str | None]: + """ + Extract reasoning from complete output. + + Handles three cases: + 1. Both tags present: reasoningcontent + 2. Only closing tag: reasoningcontent (think in prompt) + 3. No tags: pure content + + Args: + model_output: Complete model output text. + + Returns: + (reasoning, content) tuple. Either may be None. + """ + text = model_output + + # Case 1: Both tags present (normal case) + if self.start_token in text and self.end_token in text: + # Get everything after start token + _, _, after_start = text.partition(self.start_token) + # Split on end token + reasoning, _, content = after_start.partition(self.end_token) + return reasoning.strip() or None, content.strip() or None + + # Case 2: Only closing tag (think was injected in prompt) + # Everything before is reasoning + if self.end_token in text: + reasoning, _, content = text.partition(self.end_token) + return reasoning.strip() or None, content.strip() or None + + # Case 3: Only start tag (incomplete reasoning, no end yet) + if self.start_token in text: + _, _, reasoning = text.partition(self.start_token) + return reasoning.strip() or None, None + + # Case 4: No tags at all - pure content + return None, model_output + + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + ) -> DeltaMessage | None: + """ + Extract reasoning from streaming delta using text-based detection. + + Handles implicit reasoning mode where was in the prompt + and only appears in the output. + + Args: + previous_text: Text accumulated before this delta. 
+ current_text: Text including this delta. + delta_text: Just the new text. + + Returns: + DeltaMessage with reasoning/content, or None to skip. + """ + # Skip if delta is just the special tokens themselves + stripped_delta = delta_text.strip() + if stripped_delta == self.start_token: + return None + if stripped_delta == self.end_token: + return None + + # Check token positions in text (stateless text-based detection) + start_in_prev = self.start_token in previous_text + start_in_current = self.start_token in current_text + end_in_prev = self.end_token in previous_text + end_in_delta = self.end_token in delta_text + + # Case 1: Explicit found in text - standard behavior + if start_in_current: + return self._handle_explicit_think( + previous_text, delta_text, start_in_prev, end_in_prev, end_in_delta + ) + + # Case 2: No but found - implicit reasoning mode + # This handles when was injected in the prompt + if self.end_token in current_text: + return self._handle_implicit_think(delta_text, end_in_prev, end_in_delta) + + # Case 3: No think tags seen yet + # We can't know if was in the prompt, so we must make a choice: + # - Treat as content (safe, but loses reasoning if think was in prompt) + # - Treat as reasoning (risky, wrong if no thinking at all) + # We choose to treat as reasoning IF we haven't seen yet, + # because if think was in prompt, we want to capture the reasoning. + # This will be corrected once is seen. 
+ return DeltaMessage(reasoning=delta_text) + + def _handle_explicit_think( + self, + previous_text: str, + delta_text: str, + start_in_prev: bool, + end_in_prev: bool, + end_in_delta: bool, + ) -> DeltaMessage | None: + """Handle case where tag is explicitly in the output.""" + start_in_delta = self.start_token in delta_text + + if start_in_prev: + # We're after the start token + if end_in_delta: + # Transition: end token in this delta + idx = delta_text.find(self.end_token) + reasoning_part = delta_text[:idx] + content_part = delta_text[idx + len(self.end_token) :] + return DeltaMessage( + reasoning=reasoning_part if reasoning_part else None, + content=content_part if content_part else None, + ) + elif end_in_prev: + # Already past reasoning phase - pure content + return DeltaMessage(content=delta_text) + else: + # Still in reasoning phase + return DeltaMessage(reasoning=delta_text) + + elif start_in_delta: + # Start token is in this delta + start_idx = delta_text.find(self.start_token) + + if end_in_delta: + # Both tokens in this delta + end_idx = delta_text.find(self.end_token) + reasoning_part = delta_text[start_idx + len(self.start_token) : end_idx] + content_part = delta_text[end_idx + len(self.end_token) :] + return DeltaMessage( + reasoning=reasoning_part if reasoning_part else None, + content=content_part if content_part else None, + ) + else: + # Only start token - beginning of reasoning + reasoning_part = delta_text[start_idx + len(self.start_token) :] + return DeltaMessage( + reasoning=reasoning_part if reasoning_part else None + ) + + # Fallback - treat as content + return DeltaMessage(content=delta_text) + + def _handle_implicit_think( + self, + delta_text: str, + end_in_prev: bool, + end_in_delta: bool, + ) -> DeltaMessage | None: + """Handle case where was in prompt (only in output).""" + if end_in_delta: + # Transition: end token in this delta + idx = delta_text.find(self.end_token) + reasoning_part = delta_text[:idx] + content_part = 
delta_text[idx + len(self.end_token) :] + return DeltaMessage( + reasoning=reasoning_part if reasoning_part else None, + content=content_part if content_part else None, + ) + elif end_in_prev: + # Already past reasoning phase - pure content + return DeltaMessage(content=delta_text) + else: + # Still in implicit reasoning phase + return DeltaMessage(reasoning=delta_text) diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py index cd37165f..1acb788d 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -1241,12 +1241,19 @@ def main(): default=0, help="Rate limit requests per minute per client (0 = disabled)", ) + # Reasoning parser options - choices loaded dynamically from registry + from .reasoning import list_parsers + + reasoning_choices = list_parsers() parser.add_argument( "--reasoning-parser", type=str, default=None, - choices=["qwen3", "deepseek_r1"], - help="Enable reasoning content extraction with specified parser", + choices=reasoning_choices, + help=( + "Enable reasoning content extraction with specified parser. " + f"Options: {', '.join(reasoning_choices)}." + ), ) args = parser.parse_args() diff --git a/vllm_mlx/tool_parsers/abstract_tool_parser.py b/vllm_mlx/tool_parsers/abstract_tool_parser.py index 8648b7d8..a76f487e 100644 --- a/vllm_mlx/tool_parsers/abstract_tool_parser.py +++ b/vllm_mlx/tool_parsers/abstract_tool_parser.py @@ -6,6 +6,7 @@ """ import importlib +import re from abc import ABC, abstractmethod from collections.abc import Sequence from dataclasses import dataclass @@ -14,6 +15,13 @@ from transformers import PreTrainedTokenizerBase +# Pattern to match and strip think tags +# Handles two cases: +# 1. Full tags: ... +# 2. Only closing tag: ...content before... 
(when <think> is in prompt)
+THINK_TAG_PATTERN = re.compile(r"<think>.*?</think>", re.DOTALL)
+IMPLICIT_THINK_PATTERN = re.compile(r"^.*?</think>", re.DOTALL)
+
 
 @dataclass
 class ExtractedToolCallInformation:
@@ -57,6 +65,35 @@ def supports_native_format(cls) -> bool:
         """
         return cls.SUPPORTS_NATIVE_TOOL_FORMAT
 
+    @staticmethod
+    def strip_think_tags(text: str) -> str:
+        """
+        Strip think tags from text.
+
+        Handles two scenarios:
+        1. Full tags: <think>...</think> in output
+        2. Only closing tag: ...</think> when <think> was in prompt
+
+        Used as fallback when no reasoning parser is configured but the model
+        produces thinking tags. This prevents tool parsing failures with
+        models that use thinking tags (e.g., Ring-Mini-Linear-2.0 with hermes).
+
+        Args:
+            text: Model output that may contain think tags
+
+        Returns:
+            Text with think tags removed
+        """
+        # First try to strip full <think>...</think> tags
+        result = THINK_TAG_PATTERN.sub("", text)
+
+        # If no full tags found but </think> exists, strip implicit think
+        # (when <think> was injected in the prompt)
+        if result == text and "</think>" in text:
+            result = IMPLICIT_THINK_PATTERN.sub("", text)
+
+        return result.strip()
+
     def __init__(self, tokenizer: PreTrainedTokenizerBase | None = None):
         """
         Initialize the tool parser.
diff --git a/vllm_mlx/tool_parsers/hermes_tool_parser.py b/vllm_mlx/tool_parsers/hermes_tool_parser.py
index c2ebcfe1..0605e640 100644
--- a/vllm_mlx/tool_parsers/hermes_tool_parser.py
+++ b/vllm_mlx/tool_parsers/hermes_tool_parser.py
@@ -31,14 +31,24 @@ class HermesToolParser(ToolParser):
 
     Supports Hermes tool call format:
     - <tool_call>{"name": "func", "arguments": {...}}</tool_call>
     - Sometimes with additional reasoning in <reasoning>
+    - Fallback: raw JSON {"name": "func", "arguments": {...}} (for models that omit <tool_call> tags)
 
     Used when --enable-auto-tool-choice --tool-call-parser hermes are set. 
    """

+    # Standard format: <tool_call>{"name": ..., "arguments": ...}</tool_call>
     TOOL_CALL_PATTERN = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)
+    # Lenient format: <tool_call> followed by JSON (handles malformed tags)
+    TOOL_CALL_LENIENT_PATTERN = re.compile(
+        r"<tool_call>(.*?)</tool_call>", re.DOTALL
+    )
+    # Fallback pattern for raw JSON tool calls (without tags)
+    RAW_JSON_TOOL_PATTERN = re.compile(
+        r'\{"name":\s*"([^"]+)",\s*"arguments":\s*(\{[^}]*\})\}', re.DOTALL
+    )

     def extract_tool_calls(
         self, model_output: str, request: dict[str, Any] | None = None
@@ -49,11 +59,14 @@ def extract_tool_calls(
         tool_calls = []
         cleaned_text = model_output

+        # Strip <think> tags first (fallback when no reasoning parser)
+        cleaned_text = self.strip_think_tags(cleaned_text)
+
         # Remove reasoning tags first (keep for content)
-        reasoning_matches = self.REASONING_PATTERN.findall(model_output)
+        reasoning_matches = self.REASONING_PATTERN.findall(cleaned_text)
         cleaned_text = self.REASONING_PATTERN.sub("", cleaned_text)

-        # Parse tool calls
+        # Parse tool calls with <tool_call> tags (primary format)
         matches = self.TOOL_CALL_PATTERN.findall(cleaned_text)
         for match in matches:
             try:
@@ -78,6 +91,66 @@ def extract_tool_calls(
         if matches:
             cleaned_text = self.TOOL_CALL_PATTERN.sub("", cleaned_text).strip()

+        # Fallback 1: try lenient pattern for malformed tags like <tool_call>
+        if not tool_calls:
+            lenient_matches = self.TOOL_CALL_LENIENT_PATTERN.findall(cleaned_text)
+            for match in lenient_matches[:1]:  # Only first to avoid hallucinations
+                try:
+                    data = json.loads(match)
+                    name = data.get("name", "")
+                    arguments = data.get("arguments", {})
+                    if name:
+                        tool_calls.append(
+                            {
+                                "id": generate_tool_id(),
+                                "name": name,
+                                "arguments": (
+                                    json.dumps(arguments, ensure_ascii=False)
+                                    if isinstance(arguments, dict)
+                                    else str(arguments)
+                                ),
+                            }
+                        )
+                        cleaned_text = self.TOOL_CALL_LENIENT_PATTERN.sub(
+                            "", cleaned_text, count=1
+                        ).strip()
+                except json.JSONDecodeError:
+                    continue
+
+        # Fallback 2: try raw JSON format if no tagged tool calls found
+        # Only parse the FIRST valid
tool call to avoid hallucinated multiple calls + if not tool_calls: + raw_matches = self.RAW_JSON_TOOL_PATTERN.findall(cleaned_text) + if raw_matches: + # Only take the first match to avoid hallucinated tool calls + name, args_str = raw_matches[0] + try: + arguments = json.loads(args_str) + # Validate: only accept if tool name exists in request tools + valid_tool = True + if request and "tools" in request: + tool_names = [ + t.get("function", {}).get("name", "") + for t in request.get("tools", []) + if isinstance(t, dict) + ] + valid_tool = name in tool_names + + if valid_tool and name: + tool_calls.append( + { + "id": generate_tool_id(), + "name": name, + "arguments": json.dumps(arguments, ensure_ascii=False), + } + ) + # Remove the matched tool call from text + cleaned_text = self.RAW_JSON_TOOL_PATTERN.sub( + "", cleaned_text, count=1 + ).strip() + except json.JSONDecodeError: + pass + # Include reasoning in content if present if reasoning_matches: reasoning_text = " ".join(reasoning_matches) @@ -94,7 +167,7 @@ def extract_tool_calls( ) else: return ExtractedToolCallInformation( - tools_called=False, tool_calls=[], content=model_output + tools_called=False, tool_calls=[], content=cleaned_text ) def extract_tool_calls_streaming( @@ -110,25 +183,48 @@ def extract_tool_calls_streaming( """ Extract tool calls from streaming Hermes model output. 
        """
-        if "<tool_call>" not in current_text:
-            return {"content": delta_text}
-
-        if "</tool_call>" in delta_text:
-            result = self.extract_tool_calls(current_text)
-            if result.tools_called:
-                return {
-                    "tool_calls": [
-                        {
-                            "index": i,
-                            "id": tc["id"],
-                            "type": "function",
-                            "function": {
-                                "name": tc["name"],
-                                "arguments": tc["arguments"],
-                            },
-                        }
-                        for i, tc in enumerate(result.tool_calls)
-                    ]
-                }
-
-        return None
+        # Check for <tool_call> tagged tool calls
+        if "<tool_call>" in current_text:
+            if "</tool_call>" in delta_text:
+                result = self.extract_tool_calls(current_text, request)
+                if result.tools_called:
+                    return {
+                        "tool_calls": [
+                            {
+                                "index": i,
+                                "id": tc["id"],
+                                "type": "function",
+                                "function": {
+                                    "name": tc["name"],
+                                    "arguments": tc["arguments"],
+                                },
+                            }
+                            for i, tc in enumerate(result.tool_calls)
+                        ]
+                    }
+            return None
+
+        # Fallback: check for raw JSON tool calls (detect closing brace pattern)
+        # Look for complete JSON object with "name" and "arguments"
+        if '{"name":' in current_text and '"arguments":' in current_text:
+            # Check if we have a complete JSON object (ends with }})
+            if delta_text.rstrip().endswith("}"):
+                result = self.extract_tool_calls(current_text, request)
+                if result.tools_called:
+                    return {
+                        "tool_calls": [
+                            {
+                                "index": i,
+                                "id": tc["id"],
+                                "type": "function",
+                                "function": {
+                                    "name": tc["name"],
+                                    "arguments": tc["arguments"],
+                                },
+                            }
+                            for i, tc in enumerate(result.tool_calls)
+                        ]
+                    }
+                return None
+
+        return {"content": delta_text}
diff --git a/vllm_mlx/tool_parsers/qwen_tool_parser.py b/vllm_mlx/tool_parsers/qwen_tool_parser.py
index 57a14d31..fd69b96c 100644
--- a/vllm_mlx/tool_parsers/qwen_tool_parser.py
+++ b/vllm_mlx/tool_parsers/qwen_tool_parser.py
@@ -50,10 +50,12 @@ def extract_tool_calls(
         Extract tool calls from a complete Qwen model response.
        """
         tool_calls = []
-        cleaned_text = model_output
+
+        # Strip <think> tags first (fallback when no reasoning parser)
+        cleaned_text = self.strip_think_tags(model_output)

         # Try bracket pattern first (Qwen3 style)
-        bracket_matches = self.BRACKET_PATTERN.findall(model_output)
+        bracket_matches = self.BRACKET_PATTERN.findall(cleaned_text)
         for name, args_str in bracket_matches:
             try:
                 arguments = json.loads(args_str)