Skip to content

Commit 0902337

Browse files
authored
fix: OpenAI provider in from_provider ignores base_url kwarg (#1971)
1 parent 069699b commit 0902337

2 files changed

Lines changed: 94 additions & 2 deletions

File tree

instructor/auto_client.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,10 +160,29 @@ def from_provider(
160160
import openai
161161
from instructor import from_openai # type: ignore[attr-defined]
162162

163+
# Extract base_url and other OpenAI client parameters from kwargs
164+
base_url = kwargs.pop("base_url", None)
165+
openai_client_kwargs = {}
166+
for key in (
167+
"organization",
168+
"timeout",
169+
"max_retries",
170+
"default_headers",
171+
"http_client",
172+
"app_info",
173+
):
174+
if key in kwargs:
175+
openai_client_kwargs[key] = kwargs.pop(key)
176+
177+
# Build client kwargs, including base_url if provided
178+
client_kwargs = {"api_key": api_key, **openai_client_kwargs}
179+
if base_url is not None:
180+
client_kwargs["base_url"] = base_url
181+
163182
client = (
164-
openai.AsyncOpenAI(api_key=api_key)
183+
openai.AsyncOpenAI(**client_kwargs)
165184
if async_client
166-
else openai.OpenAI(api_key=api_key)
185+
else openai.OpenAI(**client_kwargs)
167186
)
168187
result = from_openai(
169188
client,

tests/test_auto_client.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,79 @@ def test_api_key_logging():
270270
)
271271

272272

273+
def test_openai_provider_respects_base_url():
274+
"""Ensure OpenAI provider passes base_url to client constructor."""
275+
from unittest.mock import patch, MagicMock
276+
277+
with patch("openai.OpenAI") as mock_openai_class:
278+
mock_client = MagicMock()
279+
mock_openai_class.return_value = mock_client
280+
281+
with patch("instructor.from_openai") as mock_from_openai:
282+
mock_instructor = MagicMock()
283+
mock_from_openai.return_value = mock_instructor
284+
285+
client = from_provider(
286+
"openai/gpt-4",
287+
base_url="https://api.example.com/v1",
288+
api_key="test-key",
289+
)
290+
291+
_, kwargs = mock_openai_class.call_args
292+
assert kwargs["base_url"] == "https://api.example.com/v1"
293+
assert kwargs["api_key"] == "test-key"
294+
mock_from_openai.assert_called_once()
295+
assert client is mock_instructor
296+
297+
298+
def test_openai_provider_async_client_with_base_url():
299+
"""Ensure OpenAI provider passes base_url to async client constructor."""
300+
from unittest.mock import patch, MagicMock
301+
302+
with patch("openai.AsyncOpenAI") as mock_async_openai_class:
303+
mock_client = MagicMock()
304+
mock_async_openai_class.return_value = mock_client
305+
306+
with patch("instructor.from_openai") as mock_from_openai:
307+
mock_instructor = MagicMock()
308+
mock_from_openai.return_value = mock_instructor
309+
310+
client = from_provider(
311+
"openai/gpt-4",
312+
async_client=True,
313+
base_url="https://api.example.com/v1",
314+
api_key="test-key",
315+
)
316+
317+
mock_async_openai_class.assert_called_once()
318+
_, kwargs = mock_async_openai_class.call_args
319+
assert kwargs["base_url"] == "https://api.example.com/v1"
320+
assert kwargs["api_key"] == "test-key"
321+
mock_from_openai.assert_called_once()
322+
assert client is mock_instructor
323+
324+
325+
def test_openai_provider_without_base_url():
326+
"""Ensure OpenAI provider works without base_url (defaults to api.openai.com)."""
327+
from unittest.mock import patch, MagicMock
328+
329+
with patch("openai.OpenAI") as mock_openai_class:
330+
mock_client = MagicMock()
331+
mock_openai_class.return_value = mock_client
332+
333+
with patch("instructor.from_openai") as mock_from_openai:
334+
mock_instructor = MagicMock()
335+
mock_from_openai.return_value = mock_instructor
336+
337+
client = from_provider("openai/gpt-4", api_key="test-key")
338+
339+
_, kwargs = mock_openai_class.call_args
340+
assert "base_url" not in kwargs
341+
assert kwargs["api_key"] == "test-key"
342+
mock_from_openai.assert_called_once()
343+
assert client is mock_instructor
344+
345+
273346
def test_databricks_provider_uses_environment_configuration():
274347
"""Ensure Databricks provider pulls host and token from the environment."""
275348
from unittest.mock import patch, MagicMock

0 commit comments

Comments
 (0)