Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions docs/integrations/databricks.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,27 @@ description: Guide to using instructor with Databricks models
First, install the required packages:

```bash
pip install instructor
uv pip install instructor openai
```

You'll need a Databricks API key and workspace URL which you can set as environment variables:
Set your Databricks workspace URL and token as environment variables:

```bash
export DATABRICKS_API_KEY=your_api_key_here
export DATABRICKS_HOST=your_workspace_url
export DATABRICKS_TOKEN="your_personal_access_token"
export DATABRICKS_HOST="https://your-workspace.cloud.databricks.com"
```

`DATABRICKS_API_KEY` and `DATABRICKS_WORKSPACE_URL` are also supported if you prefer those names. The provider appends `/serving-endpoints` automatically, so the host only needs the base workspace URL.

## Basic Example

Here's how to extract structured data from Databricks models:

```python
import os
import instructor
from openai import OpenAI
from pydantic import BaseModel

# Initialize the client with Databricks base URL
# Initialize the client; host and token are read from the environment
client = instructor.from_provider(
"databricks/dbrx-instruct",
mode=instructor.Mode.TOOLS,
Expand All @@ -55,6 +55,8 @@ print(user)
# Output: UserExtract(name='Jason', age=25)
```

If you need to point at a different workspace or testing endpoint, pass `base_url="https://alt-workspace.cloud.databricks.com/serving-endpoints"`. The helper will use that value as-is without adding another suffix.

### Async Example

```python
Expand Down
88 changes: 88 additions & 0 deletions instructor/auto_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
supported_providers = [
"openai",
"azure_openai",
"databricks",
"anthropic",
"google",
"generative-ai",
Expand Down Expand Up @@ -262,6 +263,93 @@ def from_provider(
)
raise

elif provider == "databricks":
try:
import os
import openai
from instructor import from_openai # type: ignore[attr-defined]

api_key = api_key or os.environ.get("DATABRICKS_TOKEN") or os.environ.get(
"DATABRICKS_API_KEY"
)
if not api_key:
from .core.exceptions import ConfigurationError

raise ConfigurationError(
"DATABRICKS_TOKEN is not set. "
"Set it with `export DATABRICKS_TOKEN=<your-token>` or `export DATABRICKS_API_KEY=<your-token>` "
"or pass it as kwarg `api_key=<your-token>`."
)

base_url = kwargs.pop("base_url", None)
if base_url is None:
base_url = (
os.environ.get("DATABRICKS_BASE_URL")
or os.environ.get("DATABRICKS_HOST")
or os.environ.get("DATABRICKS_WORKSPACE_URL")
)

if not base_url:
from .core.exceptions import ConfigurationError

raise ConfigurationError(
"DATABRICKS_HOST is not set. "
"Set it with `export DATABRICKS_HOST=<your-workspace-url>` or `export DATABRICKS_WORKSPACE_URL=<your-workspace-url>` "
"or pass `base_url=<your-workspace-url>`."
)

base_url = str(base_url).rstrip("/")
if not base_url.endswith("/serving-endpoints"):
base_url = f"{base_url}/serving-endpoints"

openai_client_kwargs = {}
for key in (
"organization",
"timeout",
"max_retries",
"default_headers",
"http_client",
"app_info",
):
if key in kwargs:
openai_client_kwargs[key] = kwargs.pop(key)

client = (
openai.AsyncOpenAI(
api_key=api_key, base_url=base_url, **openai_client_kwargs
)
if async_client
else openai.OpenAI(
api_key=api_key, base_url=base_url, **openai_client_kwargs
)
)
result = from_openai(
client,
model=model_name,
mode=mode if mode else instructor.Mode.TOOLS,
**kwargs,
)
logger.info(
"Client initialized",
extra={**provider_info, "status": "success"},
)
return result
except ImportError:
from .core.exceptions import ConfigurationError

raise ConfigurationError(
"The openai package is required to use the Databricks provider. "
"Install it with `pip install openai`."
) from None
except Exception as e:
logger.error(
"Error initializing %s client: %s",
provider,
e,
exc_info=True,
extra={**provider_info, "status": "error"},
)
raise
elif provider == "anthropic":
try:
import anthropic
Expand Down
145 changes: 145 additions & 0 deletions tests/test_auto_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,151 @@ def test_api_key_logging():
)


def test_databricks_provider_uses_environment_configuration():
"""Ensure Databricks provider pulls host and token from the environment."""
from unittest.mock import patch, MagicMock
import os

with patch("openai.OpenAI") as mock_openai_class:
mock_client = MagicMock()
mock_openai_class.return_value = mock_client

with patch("instructor.from_openai") as mock_from_openai:
mock_instructor = MagicMock()
mock_from_openai.return_value = mock_instructor

with patch.dict(
os.environ,
{
"DATABRICKS_HOST": "https://example.cloud.databricks.com",
"DATABRICKS_TOKEN": "secret-token",
},
clear=True,
):
client = from_provider("databricks/dbrx-instruct")

mock_openai_class.assert_called_once()
_, kwargs = mock_openai_class.call_args
assert kwargs["api_key"] == "secret-token"
assert (
kwargs["base_url"]
== "https://example.cloud.databricks.com/serving-endpoints"
)
mock_from_openai.assert_called_once()
assert client is mock_instructor


def test_databricks_provider_respects_custom_base_url():
"""Ensure Databricks provider does not duplicate serving-endpoints suffix."""
from unittest.mock import patch, MagicMock
import os

with patch("openai.OpenAI") as mock_openai_class:
mock_client = MagicMock()
mock_openai_class.return_value = mock_client

with patch("instructor.from_openai") as mock_from_openai:
mock_instructor = MagicMock()
mock_from_openai.return_value = mock_instructor

with patch.dict(
os.environ,
{
"DATABRICKS_TOKEN": "secret-token",
},
clear=True,
):
client = from_provider(
"databricks/dbrx-instruct",
base_url="https://example.cloud.databricks.com/serving-endpoints",
)

_, kwargs = mock_openai_class.call_args
assert (
kwargs["base_url"]
== "https://example.cloud.databricks.com/serving-endpoints"
)
mock_from_openai.assert_called_once()
assert client is mock_instructor


def test_databricks_provider_async_client():
"""Ensure Databricks provider returns async client when requested."""
from unittest.mock import patch, MagicMock
import os

with patch("openai.AsyncOpenAI") as mock_async_openai_class:
mock_client = MagicMock()
mock_async_openai_class.return_value = mock_client

with patch("instructor.from_openai") as mock_from_openai:
mock_instructor = MagicMock()
mock_from_openai.return_value = mock_instructor

with patch.dict(
os.environ,
{
"DATABRICKS_HOST": "https://example.cloud.databricks.com",
"DATABRICKS_TOKEN": "secret-token",
},
clear=True,
):
client = from_provider(
"databricks/dbrx-instruct", async_client=True
)

mock_async_openai_class.assert_called_once()
_, kwargs = mock_async_openai_class.call_args
assert (
kwargs["base_url"]
== "https://example.cloud.databricks.com/serving-endpoints"
)
assert kwargs["api_key"] == "secret-token"
mock_from_openai.assert_called_once()
assert client is mock_instructor


def test_databricks_provider_requires_token():
"""Ensure Databricks provider raises when no token is available."""
from instructor.core.exceptions import ConfigurationError
from unittest.mock import patch, MagicMock
import os

with patch("openai.OpenAI") as mock_openai_class:
mock_openai_class.return_value = MagicMock()
with patch("instructor.from_openai") as mock_from_openai:
mock_from_openai.return_value = MagicMock()
with patch.dict(
os.environ,
{
"DATABRICKS_HOST": "https://example.cloud.databricks.com",
},
clear=True,
):
with pytest.raises(ConfigurationError):
from_provider("databricks/dbrx-instruct")


def test_databricks_provider_requires_host():
"""Ensure Databricks provider raises when no host is available."""
from instructor.core.exceptions import ConfigurationError
from unittest.mock import patch, MagicMock
import os

with patch("openai.OpenAI") as mock_openai_class:
mock_openai_class.return_value = MagicMock()
with patch("instructor.from_openai") as mock_from_openai:
mock_from_openai.return_value = MagicMock()
with patch.dict(
os.environ,
{
"DATABRICKS_TOKEN": "secret-token",
},
clear=True,
):
with pytest.raises(ConfigurationError):
from_provider("databricks/dbrx-instruct")

def test_genai_mode_parameter_passed_to_provider():
"""Test that mode parameter is correctly passed to provider functions."""
from unittest.mock import patch, MagicMock
Expand Down
Loading