
Commit 86734af

fix: logic for determining reasoning model
1 parent 681edb7 commit 86734af

File tree

2 files changed: +213, -1 lines changed

marimo/_server/ai/providers.py

Lines changed: 33 additions & 1 deletion
@@ -439,7 +439,39 @@ class OpenAIProvider(
     DEFAULT_REASONING_EFFORT = "medium"
 
     def _is_reasoning_model(self, model: str) -> bool:
-        return model.startswith("o") or model.startswith("gpt-5")
+        """
+        Check if reasoning_effort should be added to the request.
+        Only add for actual OpenAI reasoning models, not for OpenAI-compatible APIs.
+
+        OpenAI-compatible APIs (identified by custom base_url) may not support
+        the reasoning_effort parameter even if the model name suggests it's a
+        reasoning model.
+        """
+        import re
+
+        # Check for reasoning model patterns: o{digit} or gpt-5, with optional openai/ prefix
+        reasoning_patterns = [
+            r"^openai/o\d",  # openai/o1, openai/o3, etc.
+            r"^o\d",  # o1, o3, etc.
+            r"^openai/gpt-5",  # openai/gpt-5*
+            r"^gpt-5",  # gpt-5*
+        ]
+
+        is_reasoning_model_name = any(
+            re.match(pattern, model) for pattern in reasoning_patterns
+        )
+
+        if not is_reasoning_model_name:
+            return False
+
+        # If using a custom base_url that's not OpenAI, don't assume reasoning is supported
+        if (
+            self.config.base_url
+            and "api.openai.com" not in self.config.base_url
+        ):
+            return False
+
+        return True
 
     def get_client(self, config: AnyProviderConfig) -> AsyncOpenAI:
         DependencyManager.openai.require(why="for AI assistance with OpenAI")
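
For illustration, a minimal standalone sketch of the new check, using example inputs drawn from the tests added below. The function name is_reasoning_model and the plain base_url argument are hypothetical stand-ins for the provider method and its config object; only the pattern matching and base_url guard mirror the diff above. The previous check, model.startswith("o") or model.startswith("gpt-5"), matched any model name beginning with "o" and ignored the endpoint, so names like "olive-model" and reasoning-named models served through OpenAI-compatible proxies were misclassified.

    import re

    # Hypothetical standalone version of the check; in marimo it lives on
    # OpenAIProvider._is_reasoning_model and reads base_url from self.config.
    def is_reasoning_model(model: str, base_url: str | None) -> bool:
        reasoning_patterns = [
            r"^openai/o\d",   # openai/o1, openai/o3, ...
            r"^o\d",          # o1, o3, ...
            r"^openai/gpt-5",
            r"^gpt-5",
        ]
        if not any(re.match(p, model) for p in reasoning_patterns):
            return False
        # A custom (non-OpenAI) endpoint may not accept reasoning_effort.
        if base_url and "api.openai.com" not in base_url:
            return False
        return True

    print(is_reasoning_model("o1-mini", None))                         # True
    print(is_reasoning_model("o1-mini", "https://custom.api.com/v1"))  # False
    print(is_reasoning_model("olive-model", None))                     # False (old check: True)
    print(is_reasoning_model("gpt-4o", None))                          # False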

tests/_server/ai/test_providers.py

Lines changed: 180 additions & 0 deletions
@@ -320,3 +320,183 @@ def test_anthropic_extract_content_tool_call_id_mapping() -> None:
     tool_data, _ = result[0]
     assert isinstance(tool_data, dict)
     assert tool_data["toolCallId"] == "toolu_123"
+
+
+@pytest.mark.parametrize(
+    ("model_name", "base_url", "expected"),
+    [
+        pytest.param(
+            "o1-mini",
+            None,
+            True,
+            id="o1_mini_no_base_url",
+        ),
+        pytest.param(
+            "o1-preview",
+            None,
+            True,
+            id="o1_preview_no_base_url",
+        ),
+        pytest.param(
+            "o1",
+            None,
+            True,
+            id="o1_no_base_url",
+        ),
+        pytest.param(
+            "o1-2024-12-17",
+            "https://api.openai.com/v1",
+            True,
+            id="o1_dated_openai_base_url",
+        ),
+        pytest.param(
+            "o3-mini",
+            None,
+            True,
+            id="o3_mini_no_base_url",
+        ),
+        pytest.param(
+            "gpt-5-turbo",
+            None,
+            True,
+            id="gpt5_turbo_no_base_url",
+        ),
+        pytest.param(
+            "gpt-5-preview",
+            None,
+            True,
+            id="gpt5_preview_no_base_url",
+        ),
+        pytest.param(
+            "openai/o1-mini",
+            None,
+            True,
+            id="openai_prefix_o1_mini_no_base_url",
+        ),
+        pytest.param(
+            "openai/o1-preview",
+            None,
+            True,
+            id="openai_prefix_o1_preview_no_base_url",
+        ),
+        pytest.param(
+            "openai/gpt-5-turbo",
+            None,
+            True,
+            id="openai_prefix_gpt5_no_base_url",
+        ),
+        pytest.param(
+            "o1-mini",
+            "https://custom.api.com/v1",
+            False,
+            id="o1_custom_base_url",
+        ),
+        pytest.param(
+            "o1-preview",
+            "https://litellm.proxy.com/api/v1",
+            False,
+            id="o1_litellm_proxy",
+        ),
+        pytest.param(
+            "gpt-4",
+            None,
+            False,
+            id="gpt4_no_base_url",
+        ),
+        pytest.param(
+            "gpt-4o",
+            None,
+            False,
+            id="gpt4o_no_base_url",
+        ),
+        pytest.param(
+            "gpt-4",
+            "https://custom.api.com/v1",
+            False,
+            id="gpt4_custom_base_url",
+        ),
+        pytest.param(
+            "olive-model",
+            None,
+            False,
+            id="model_starting_with_o_but_not_reasoning",
+        ),
+        pytest.param(
+            "openrouter/o1-mini",
+            None,
+            False,
+            id="openrouter_prefix_not_openai",
+        ),
+    ],
+)
+def test_is_reasoning_model(
+    model_name: str, base_url: str | None, expected: bool
+) -> None:
+    """Test that _is_reasoning_model correctly identifies reasoning models."""
+    config = AnyProviderConfig(api_key="test-key", base_url=base_url)
+    provider = OpenAIProvider(model_name, config)
+    assert provider._is_reasoning_model(model_name) == expected
+
+
+@pytest.mark.parametrize(
+    ("model_name", "base_url", "expected_params"),
+    [
+        pytest.param(
+            "o1-mini",
+            "https://custom-openai-compatible.com/v1",
+            {"max_tokens": 1000},
+            id="reasoning_model_name_custom_api_no_reasoning",
+        ),
+        pytest.param(
+            "o1-preview",
+            "https://litellm.proxy.com/api/v1",
+            {"max_tokens": 1000},
+            id="o1_preview_litellm_proxy_no_reasoning",
+        ),
+        pytest.param(
+            "o3-mini",
+            "https://corporate-llm.internal/api",
+            {"max_tokens": 1000},
+            id="o3_mini_corporate_proxy_no_reasoning",
+        ),
+    ],
+)
+@patch("openai.AsyncOpenAI")
+async def test_openai_compatible_api_no_reasoning_effort(
+    mock_openai_class, model_name: str, base_url: str, expected_params: dict
+) -> None:
+    """Test that OpenAI-compatible APIs don't get reasoning_effort even with reasoning model names."""
+    # Setup mock
+    mock_client = AsyncMock()
+    mock_openai_class.return_value = mock_client
+    mock_stream = AsyncMock()
+    mock_client.chat.completions.create.return_value = mock_stream
+
+    # Create provider with custom base_url (simulating OpenAI-compatible API)
+    config = AnyProviderConfig(api_key="test-key", base_url=base_url)
+    provider = OpenAIProvider(model_name, config)
+
+    # Call stream_completion
+    messages = [ChatMessage(role="user", content="test message")]
+    await provider.stream_completion(messages, "system prompt", 1000, [])
+
+    # Verify the correct parameters were passed
+    mock_client.chat.completions.create.assert_called_once()
+    call_kwargs = mock_client.chat.completions.create.call_args[1]
+
+    # Check that reasoning_effort is NOT present
+    assert "reasoning_effort" not in call_kwargs, (
+        "reasoning_effort should not be present for OpenAI-compatible APIs"
+    )
+
+    # Check that the expected parameters are present
+    for param_name, param_value in expected_params.items():
+        assert param_name in call_kwargs, (
+            f"Expected parameter {param_name} not found"
+        )
+        assert call_kwargs[param_name] == param_value
+
+    # Ensure max_completion_tokens is not present when reasoning_effort is not used
+    assert "max_completion_tokens" not in call_kwargs, (
+        "max_completion_tokens should not be present when reasoning_effort is not used"
+    )
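
To run just the new tests locally, a pytest keyword filter along these lines should work (the exact invocation may differ depending on the project's test runner setup):

    pytest tests/_server/ai/test_providers.py -k "is_reasoning_model or openai_compatible"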
