Skip to content

Commit 89be10f

Browse files
add ttl to RedisCache (#9068)
Add `ttl` (time to live) to `RedisCache`
1 parent 04bc5f3 commit 89be10f

File tree

2 files changed

+42
-10
lines changed

2 files changed

+42
-10
lines changed

libs/langchain/langchain/cache.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -216,10 +216,25 @@ def __init__(self, database_path: str = ".langchain.db"):
216216
class RedisCache(BaseCache):
217217
"""Cache that uses Redis as a backend."""
218218

219-
# TODO - implement a TTL policy in Redis
220-
221-
def __init__(self, redis_: Any):
222-
"""Initialize by passing in Redis instance."""
219+
def __init__(self, redis_: Any, *, ttl: Optional[int] = None):
220+
"""
221+
Initialize an instance of RedisCache.
222+
223+
This method initializes an object with Redis caching capabilities.
224+
It takes a `redis_` parameter, which should be an instance of a Redis
225+
client class, allowing the object to interact with a Redis
226+
server for caching purposes.
227+
228+
Parameters:
229+
redis_ (Any): An instance of a Redis client class
230+
(e.g., redis.Redis) used for caching.
231+
This allows the object to communicate with a
232+
Redis server for caching operations.
233+
ttl (int, optional): Time-to-live (TTL) for cached items in seconds.
234+
If provided, it sets the time duration for how long cached
235+
items will remain valid. If not provided, cached items will not
236+
have an automatic expiration.
237+
"""
223238
try:
224239
from redis import Redis
225240
except ImportError:
@@ -230,6 +245,7 @@ def __init__(self, redis_: Any):
230245
if not isinstance(redis_, Redis):
231246
raise ValueError("Please pass in Redis object.")
232247
self.redis = redis_
248+
self.ttl = ttl
233249

234250
def _key(self, prompt: str, llm_string: str) -> str:
235251
"""Compute key from prompt and llm_string"""
@@ -261,12 +277,19 @@ def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> N
261277
return
262278
# Write to a Redis HASH
263279
key = self._key(prompt, llm_string)
264-
self.redis.hset(
265-
key,
266-
mapping={
267-
str(idx): generation.text for idx, generation in enumerate(return_val)
268-
},
269-
)
280+
281+
with self.redis.pipeline() as pipe:
282+
pipe.hset(
283+
key,
284+
mapping={
285+
str(idx): generation.text
286+
for idx, generation in enumerate(return_val)
287+
},
288+
)
289+
if self.ttl is not None:
290+
pipe.expire(key, self.ttl)
291+
292+
pipe.execute()
270293

271294
def clear(self, **kwargs: Any) -> None:
272295
"""Clear cache. If `asynchronous` is True, flush asynchronously."""

libs/langchain/tests/integration_tests/cache/test_redis_cache.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,15 @@
1111
REDIS_TEST_URL = "redis://localhost:6379"
1212

1313

14+
def test_redis_cache_ttl() -> None:
15+
import redis
16+
17+
langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL), ttl=1)
18+
langchain.llm_cache.update("foo", "bar", [Generation(text="fizz")])
19+
key = langchain.llm_cache._key("foo", "bar")
20+
assert langchain.llm_cache.redis.pttl(key) > 0
21+
22+
1423
def test_redis_cache() -> None:
1524
import redis
1625

0 commit comments

Comments
 (0)