Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -222,3 +222,4 @@ libs/redis/docs/.Trash*
.idea/*
.vscode/settings.json
.python-version
tests/data
50 changes: 19 additions & 31 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ numpy = [
{ version = ">=1.26.0,<3", python = ">=3.12" },
]
pyyaml = ">=5.4,<7.0"
redis = "^5.0"
redis = "^6.0"
pydantic = "^2"
tenacity = ">=8.2.2"
ml-dtypes = ">=0.4.0,<1.0.0"
Expand Down Expand Up @@ -68,8 +68,8 @@ pytest-xdist = {extras = ["psutil"], version = "^3.6.1"}
pre-commit = "^4.1.0"
mypy = "1.9.0"
nbval = "^0.11.0"
types-redis = "*"
types-pyyaml = "*"
types-pyopenssl = "*"
testcontainers = "^4.3.1"
cryptography = { version = ">=44.0.1", markers = "python_version > '3.9.1'" }

Expand Down
76 changes: 46 additions & 30 deletions redisvl/extensions/cache/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
specific cache types such as LLM caches and embedding caches.
"""

from typing import Any, Dict, Optional
from collections.abc import Mapping
from typing import Any, Dict, Optional, Union

from redis import Redis
from redis.asyncio import Redis as AsyncRedis
from redis import Redis # For backwards compatibility in type checking
from redis.cluster import RedisCluster

from redisvl.redis.connection import RedisConnectionFactory
from redisvl.types import AsyncRedisClient, SyncRedisClient, SyncRedisCluster


class BaseCache:
Expand All @@ -19,14 +21,15 @@ class BaseCache:
including TTL management, connection handling, and basic cache operations.
"""

_redis_client: Optional[Redis]
_async_redis_client: Optional[AsyncRedis]
_redis_client: Optional[SyncRedisClient]
_async_redis_client: Optional[AsyncRedisClient]

def __init__(
self,
name: str,
ttl: Optional[int] = None,
redis_client: Optional[Redis] = None,
redis_client: Optional[SyncRedisClient] = None,
async_redis_client: Optional[AsyncRedisClient] = None,
redis_url: str = "redis://localhost:6379",
connection_kwargs: Dict[str, Any] = {},
):
Expand All @@ -36,7 +39,7 @@ def __init__(
name (str): The name of the cache.
ttl (Optional[int], optional): The time-to-live for records cached
in Redis. Defaults to None.
redis_client (Optional[Redis], optional): A redis client connection instance.
redis_client (Optional[SyncRedisClient], optional): A redis client connection instance.
Defaults to None.
redis_url (str, optional): The redis url. Defaults to redis://localhost:6379.
connection_kwargs (Dict[str, Any]): The connection arguments
Expand All @@ -53,14 +56,13 @@ def __init__(
}

# Initialize Redis clients
self._async_redis_client = None
self._async_redis_client = async_redis_client
self._redis_client = redis_client

if redis_client:
if redis_client or async_redis_client:
self._owns_redis_client = False
self._redis_client = redis_client
else:
self._owns_redis_client = True
self._redis_client = None # type: ignore

def _get_prefix(self) -> str:
"""Get the key prefix for Redis keys.
Expand Down Expand Up @@ -103,11 +105,11 @@ def set_ttl(self, ttl: Optional[int] = None) -> None:
else:
self._ttl = None

def _get_redis_client(self) -> Redis:
def _get_redis_client(self) -> SyncRedisClient:
"""Get or create a Redis client.

Returns:
Redis: A Redis client instance.
SyncRedisClient: A Redis client instance.
"""
if self._redis_client is None:
# Create new Redis client
Expand All @@ -116,22 +118,29 @@ def _get_redis_client(self) -> Redis:
self._redis_client = Redis.from_url(url, **kwargs) # type: ignore
return self._redis_client

async def _get_async_redis_client(self) -> AsyncRedis:
async def _get_async_redis_client(self) -> AsyncRedisClient:
"""Get or create an async Redis client.

Returns:
AsyncRedis: An async Redis client instance.
AsyncRedisClient: An async Redis client instance.
"""
if not hasattr(self, "_async_redis_client") or self._async_redis_client is None:
client = self.redis_kwargs.get("redis_client")
if isinstance(client, Redis):

if client and isinstance(client, (Redis, RedisCluster)):
self._async_redis_client = RedisConnectionFactory.sync_to_async_redis(
client
)
else:
url = self.redis_kwargs["redis_url"]
kwargs = self.redis_kwargs["connection_kwargs"]
self._async_redis_client = RedisConnectionFactory.get_async_redis_connection(url, **kwargs) # type: ignore
url = str(self.redis_kwargs["redis_url"])
kwargs = self.redis_kwargs.get("connection_kwargs", {})
if not isinstance(kwargs, Mapping):
raise ValueError(
f"connection_kwargs must be a mapping, got {type(kwargs)}"
)
self._async_redis_client = (
RedisConnectionFactory.get_async_redis_connection(url, **kwargs)
)
return self._async_redis_client

def expire(self, key: str, ttl: Optional[int] = None) -> None:
Expand Down Expand Up @@ -183,7 +192,14 @@ def clear(self) -> None:
client.delete(*keys)
if cursor_int == 0: # Redis returns 0 when scan is complete
break
cursor = cursor_int # Update cursor for next iteration
# Cluster returns a dict of cursor values. We need to stop if these all
# come back as 0.
elif isinstance(cursor_int, Mapping):
cursor_values = list(cursor_int.values())
if all(v == 0 for v in cursor_values):
break
else:
cursor = cursor_int # Update cursor for next iteration

async def aclear(self) -> None:
"""Async clear the cache of all keys."""
Expand All @@ -193,7 +209,9 @@ async def aclear(self) -> None:
# Scan for all keys with our prefix
cursor = 0 # Start with cursor 0
while True:
cursor_int, keys = await client.scan(cursor=cursor, match=f"{prefix}*", count=100) # type: ignore
cursor_int, keys = await client.scan(
cursor=cursor, match=f"{prefix}*", count=100
) # type: ignore
if keys:
await client.delete(*keys)
if cursor_int == 0: # Redis returns 0 when scan is complete
Expand All @@ -207,12 +225,10 @@ def disconnect(self) -> None:

if self._redis_client:
self._redis_client.close()
self._redis_client = None # type: ignore

if hasattr(self, "_async_redis_client") and self._async_redis_client:
# Use synchronous close for async client in synchronous context
self._async_redis_client.close() # type: ignore
self._async_redis_client = None # type: ignore
self._redis_client = None
# Async clients don't have a sync close method, so we just
# zero them out to allow garbage collection.
self._async_redis_client = None

async def adisconnect(self) -> None:
"""Async disconnect from Redis."""
Expand All @@ -221,9 +237,9 @@ async def adisconnect(self) -> None:

if self._redis_client:
self._redis_client.close()
self._redis_client = None # type: ignore
self._redis_client = None

if hasattr(self, "_async_redis_client") and self._async_redis_client:
# Use proper async close method
await self._async_redis_client.aclose() # type: ignore
self._async_redis_client = None # type: ignore
await self._async_redis_client.aclose()
self._async_redis_client = None
20 changes: 11 additions & 9 deletions redisvl/extensions/cache/embeddings/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
"""Embeddings cache implementation for RedisVL."""

from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

from redis import Redis
from redis.asyncio import Redis as AsyncRedis
from typing import Any, Awaitable, Dict, List, Optional, Tuple, cast

from redisvl.extensions.cache.base import BaseCache
from redisvl.extensions.cache.embeddings.schema import CacheEntry
from redisvl.redis.utils import convert_bytes, hashify
from redisvl.types import AsyncRedisClient, SyncRedisClient
from redisvl.utils.log import get_logger

logger = get_logger(__name__)


class EmbeddingsCache(BaseCache):
Expand All @@ -17,7 +18,8 @@ def __init__(
self,
name: str = "embedcache",
ttl: Optional[int] = None,
redis_client: Optional[Redis] = None,
redis_client: Optional[SyncRedisClient] = None,
async_redis_client: Optional[AsyncRedisClient] = None,
redis_url: str = "redis://localhost:6379",
connection_kwargs: Dict[str, Any] = {},
):
Expand All @@ -26,7 +28,7 @@ def __init__(
Args:
name (str): The name of the cache. Defaults to "embedcache".
ttl (Optional[int]): The time-to-live for cached embeddings. Defaults to None.
redis_client (Optional[Redis]): Redis client instance. Defaults to None.
redis_client (Optional[SyncRedisClient]): Redis client instance. Defaults to None.
redis_url (str): Redis URL for connection. Defaults to "redis://localhost:6379".
connection_kwargs (Dict[str, Any]): Redis connection arguments. Defaults to {}.

Expand Down Expand Up @@ -173,7 +175,7 @@ def get_by_key(self, key: str) -> Optional[Dict[str, Any]]:
if data:
self.expire(key)

return self._process_cache_data(data)
return self._process_cache_data(data) # type: ignore

def mget_by_keys(self, keys: List[str]) -> List[Optional[Dict[str, Any]]]:
"""Get multiple embeddings by their Redis keys.
Expand Down Expand Up @@ -570,7 +572,7 @@ async def aget_by_key(self, key: str) -> Optional[Dict[str, Any]]:
client = await self._get_async_redis_client()

# Get all fields
data = await client.hgetall(key)
data = await client.hgetall(key) # type: ignore

# Refresh TTL if data exists
if data:
Expand Down Expand Up @@ -608,7 +610,7 @@ async def amget_by_keys(self, keys: List[str]) -> List[Optional[Dict[str, Any]]]
async with client.pipeline(transaction=False) as pipeline:
# Queue all hgetall operations
for key in keys:
await pipeline.hgetall(key)
pipeline.hgetall(key)
results = await pipeline.execute()

# Process results and refresh TTLs separately
Expand Down
Loading
Loading