@@ -216,10 +216,25 @@ def __init__(self, database_path: str = ".langchain.db"):
216216class RedisCache (BaseCache ):
217217 """Cache that uses Redis as a backend."""
218218
219- # TODO - implement a TTL policy in Redis
220-
221- def __init__ (self , redis_ : Any ):
222- """Initialize by passing in Redis instance."""
219+ def __init__ (self , redis_ : Any , * , ttl : Optional [int ] = None ):
220+ """
221+ Initialize an instance of RedisCache.
222+
223+ This method initializes an object with Redis caching capabilities.
224+ It takes a `redis_` parameter, which should be an instance of a Redis
225+ client class, allowing the object to interact with a Redis
226+ server for caching purposes.
227+
228+ Parameters:
229+ redis_ (Any): An instance of a Redis client class
230+ (e.g., redis.Redis) used for caching.
231+ This allows the object to communicate with a
232+ Redis server for caching operations.
233+ ttl (int, optional): Time-to-live (TTL) for cached items in seconds.
234+ If provided, it sets the time duration for how long cached
235+ items will remain valid. If not provided, cached items will not
236+ have an automatic expiration.
237+ """
223238 try :
224239 from redis import Redis
225240 except ImportError :
@@ -230,6 +245,7 @@ def __init__(self, redis_: Any):
230245 if not isinstance (redis_ , Redis ):
231246 raise ValueError ("Please pass in Redis object." )
232247 self .redis = redis_
248+ self .ttl = ttl
233249
234250 def _key (self , prompt : str , llm_string : str ) -> str :
235251 """Compute key from prompt and llm_string"""
@@ -261,12 +277,19 @@ def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> N
261277 return
262278 # Write to a Redis HASH
263279 key = self ._key (prompt , llm_string )
264- self .redis .hset (
265- key ,
266- mapping = {
267- str (idx ): generation .text for idx , generation in enumerate (return_val )
268- },
269- )
280+
281+ with self .redis .pipeline () as pipe :
282+ pipe .hset (
283+ key ,
284+ mapping = {
285+ str (idx ): generation .text
286+ for idx , generation in enumerate (return_val )
287+ },
288+ )
289+ if self .ttl is not None :
290+ pipe .expire (key , self .ttl )
291+
292+ pipe .execute ()
270293
271294 def clear (self , ** kwargs : Any ) -> None :
272295 """Clear cache. If `asynchronous` is True, flush asynchronously."""
0 commit comments