Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions libs/agno/agno/knowledge/embedder/azure_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def _response(self, text: str) -> CreateEmbeddingResponse:
}
if self.user is not None:
_request_params["user"] = self.user
if self.id.startswith("text-embedding-3"):
if self.dimensions is not None:
_request_params["dimensions"] = self.dimensions
if self.request_params:
_request_params.update(self.request_params)
Expand Down Expand Up @@ -131,7 +131,7 @@ async def _aresponse(self, text: str) -> CreateEmbeddingResponse:
}
if self.user is not None:
_request_params["user"] = self.user
if self.id.startswith("text-embedding-3"):
if self.dimensions is not None:
_request_params["dimensions"] = self.dimensions
if self.request_params:
_request_params.update(self.request_params)
Expand Down Expand Up @@ -181,7 +181,7 @@ async def async_get_embeddings_batch_and_usage(
}
if self.user is not None:
req["user"] = self.user
if self.id.startswith("text-embedding-3"):
if self.dimensions is not None:
req["dimensions"] = self.dimensions
if self.request_params:
req.update(self.request_params)
Expand Down
8 changes: 4 additions & 4 deletions libs/agno/agno/knowledge/embedder/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def response(self, text: str) -> CreateEmbeddingResponse:
}
if self.user is not None:
_request_params["user"] = self.user
if self.id.startswith("text-embedding-3"):
if self.dimensions is not None:
_request_params["dimensions"] = self.dimensions
if self.request_params:
_request_params.update(self.request_params)
Expand Down Expand Up @@ -106,7 +106,7 @@ async def async_get_embedding(self, text: str) -> List[float]:
}
if self.user is not None:
req["user"] = self.user
if self.id.startswith("text-embedding-3"):
if self.dimensions is not None:
req["dimensions"] = self.dimensions
if self.request_params:
req.update(self.request_params)
Expand All @@ -126,7 +126,7 @@ async def async_get_embedding_and_usage(self, text: str):
}
if self.user is not None:
req["user"] = self.user
if self.id.startswith("text-embedding-3"):
if self.dimensions is not None:
req["dimensions"] = self.dimensions
if self.request_params:
req.update(self.request_params)
Expand Down Expand Up @@ -166,7 +166,7 @@ async def async_get_embeddings_batch_and_usage(
}
if self.user is not None:
req["user"] = self.user
if self.id.startswith("text-embedding-3"):
if self.dimensions is not None:
req["dimensions"] = self.dimensions
if self.request_params:
req.update(self.request_params)
Expand Down
106 changes: 106 additions & 0 deletions libs/agno/tests/unit/embedder/test_openai_dimensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import importlib.util
from types import SimpleNamespace

import pytest

from agno.knowledge.embedder.openai import OpenAIEmbedder


@pytest.mark.skipif(not importlib.util.find_spec("openai"), reason="openai package not installed")
def test_dimensions_propagated_when_explicitly_set():
"""Ensure dimensions parameter is passed when explicitly set, regardless of model."""

class DummyEmbeddings:
def __init__(self):
self.last_kwargs = None

def create(self, **kwargs):
self.last_kwargs = kwargs
dims = kwargs.get("dimensions", 1)
return SimpleNamespace(
data=[SimpleNamespace(embedding=[0.0] * dims)],
usage=None,
)

class DummyClient:
def __init__(self):
self.embeddings = DummyEmbeddings()

embedder = OpenAIEmbedder(id="text-embedding-v4", dimensions=512)
embedder.openai_client = DummyClient()

_ = embedder.get_embedding("hello world")

assert embedder.openai_client.embeddings.last_kwargs is not None, "Embeddings request not captured"
assert (
embedder.openai_client.embeddings.last_kwargs.get("dimensions") == 512
), "dimensions parameter not propagated when explicitly set"


@pytest.mark.skipif(not importlib.util.find_spec("openai"), reason="openai package not installed")
def test_dimensions_propagated_for_any_model():
"""Ensure dimensions parameter is passed for ANY model when explicitly set (future-proof)."""

class DummyEmbeddings:
def __init__(self):
self.last_kwargs = None

def create(self, **kwargs):
self.last_kwargs = kwargs
dims = kwargs.get("dimensions", 1536)
return SimpleNamespace(
data=[SimpleNamespace(embedding=[0.0] * dims)],
usage=None,
)

class DummyClient:
def __init__(self):
self.embeddings = DummyEmbeddings()

embedder = OpenAIEmbedder(id="text-embedding-ada-002", dimensions=256)
embedder.openai_client = DummyClient()

_ = embedder.get_embedding("test")

assert embedder.openai_client.embeddings.last_kwargs.get("dimensions") == 256, (
"dimensions should be passed for legacy models too"
)

embedder2 = OpenAIEmbedder(id="text-embedding-v5-ultra", dimensions=2048)
embedder2.openai_client = DummyClient()

_ = embedder2.get_embedding("test")

assert embedder2.openai_client.embeddings.last_kwargs.get("dimensions") == 2048, (
"dimensions should be passed for any future models"
)


@pytest.mark.skipif(not importlib.util.find_spec("openai"), reason="openai package not installed")
def test_dimensions_not_passed_when_none():
"""Ensure dimensions parameter is NOT passed when set to None (respects model defaults)."""

class DummyEmbeddings:
def __init__(self):
self.last_kwargs = None

def create(self, **kwargs):
self.last_kwargs = kwargs
return SimpleNamespace(
data=[SimpleNamespace(embedding=[0.0] * 1536)],
usage=None,
)

class DummyClient:
def __init__(self):
self.embeddings = DummyEmbeddings()

embedder = OpenAIEmbedder(id="text-embedding-3-small")
embedder.dimensions = None
embedder.openai_client = DummyClient()

_ = embedder.get_embedding("test")

assert "dimensions" not in embedder.openai_client.embeddings.last_kwargs, (
"dimensions should NOT be passed when None (use model default)"
)