Skip to content

Commit 736880e

Browse files
authored
feat: support configurable redis key prefix (#35139)
1 parent bd7a9b5 commit 736880e

19 files changed

Lines changed: 522 additions & 74 deletions

File tree

api/.env.example

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ REDIS_SSL_CERTFILE=
5757
REDIS_SSL_KEYFILE=
5858
# Path to client private key file for SSL authentication
5959
REDIS_DB=0
60+
# Optional global prefix for Redis keys, topics, streams, and Celery Redis transport artifacts.
61+
# Leave empty to preserve current unprefixed behavior.
62+
REDIS_KEY_PREFIX=
6063

6164
# redis Sentinel configuration.
6265
REDIS_USE_SENTINEL=false

api/configs/middleware/cache/redis_config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ class RedisConfig(BaseSettings):
3232
default=0,
3333
)
3434

35+
REDIS_KEY_PREFIX: str = Field(
36+
description="Optional global prefix for Redis keys, topics, and transport artifacts",
37+
default="",
38+
)
39+
3540
REDIS_USE_SSL: bool = Field(
3641
description="Enable SSL/TLS for the Redis connection",
3742
default=False,

api/extensions/ext_celery.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,18 @@
99

1010
from configs import dify_config
1111
from dify_app import DifyApp
12+
from extensions.redis_names import normalize_redis_key_prefix
1213

1314

1415
class _CelerySentinelKwargsDict(TypedDict):
1516
socket_timeout: float | None
1617
password: str | None
1718

1819

19-
class CelerySentinelTransportDict(TypedDict):
20+
class CelerySentinelTransportDict(TypedDict, total=False):
2021
master_name: str | None
2122
sentinel_kwargs: _CelerySentinelKwargsDict
23+
global_keyprefix: str
2224

2325

2426
class CelerySSLOptionsDict(TypedDict):
@@ -61,15 +63,31 @@ def get_celery_ssl_options() -> CelerySSLOptionsDict | None:
6163

6264
def get_celery_broker_transport_options() -> CelerySentinelTransportDict | dict[str, Any]:
6365
"""Get broker transport options (e.g. Redis Sentinel) for Celery connections."""
66+
transport_options: CelerySentinelTransportDict | dict[str, Any]
6467
if dify_config.CELERY_USE_SENTINEL:
65-
return CelerySentinelTransportDict(
68+
transport_options = CelerySentinelTransportDict(
6669
master_name=dify_config.CELERY_SENTINEL_MASTER_NAME,
6770
sentinel_kwargs=_CelerySentinelKwargsDict(
6871
socket_timeout=dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT,
6972
password=dify_config.CELERY_SENTINEL_PASSWORD,
7073
),
7174
)
72-
return {}
75+
else:
76+
transport_options = {}
77+
78+
global_keyprefix = get_celery_redis_global_keyprefix()
79+
if global_keyprefix:
80+
transport_options["global_keyprefix"] = global_keyprefix
81+
82+
return transport_options
83+
84+
85+
def get_celery_redis_global_keyprefix() -> str | None:
86+
"""Return the Redis transport prefix for Celery when namespace isolation is enabled."""
87+
normalized_prefix = normalize_redis_key_prefix(dify_config.REDIS_KEY_PREFIX)
88+
if not normalized_prefix:
89+
return None
90+
return f"{normalized_prefix}:"
7391

7492

7593
def init_app(app: DifyApp) -> Celery:

api/extensions/ext_redis.py

Lines changed: 153 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import ssl
44
from collections.abc import Callable
55
from datetime import timedelta
6-
from typing import TYPE_CHECKING, Any, Union
6+
from typing import Any, Union, cast
77

88
import redis
99
from redis import RedisError
@@ -18,17 +18,26 @@
1818

1919
from configs import dify_config
2020
from dify_app import DifyApp
21+
from extensions.redis_names import (
22+
normalize_redis_key_prefix,
23+
serialize_redis_name,
24+
serialize_redis_name_arg,
25+
serialize_redis_name_args,
26+
)
2127
from libs.broadcast_channel.channel import BroadcastChannel as BroadcastChannelProtocol
2228
from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel
2329
from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel
2430
from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel
2531

26-
if TYPE_CHECKING:
27-
from redis.lock import Lock
28-
2932
logger = logging.getLogger(__name__)
3033

3134

35+
_normalize_redis_key_prefix = normalize_redis_key_prefix
36+
_serialize_redis_name = serialize_redis_name
37+
_serialize_redis_name_arg = serialize_redis_name_arg
38+
_serialize_redis_name_args = serialize_redis_name_args
39+
40+
3241
class RedisClientWrapper:
3342
"""
3443
A wrapper class for the Redis client that addresses the issue where the global
@@ -59,68 +68,148 @@ def initialize(self, client: Union[redis.Redis, RedisCluster]) -> None:
5968
if self._client is None:
6069
self._client = client
6170

62-
if TYPE_CHECKING:
63-
# Type hints for IDE support and static analysis
64-
# These are not executed at runtime but provide type information
65-
def get(self, name: str | bytes) -> Any: ...
66-
67-
def set(
68-
self,
69-
name: str | bytes,
70-
value: Any,
71-
ex: int | None = None,
72-
px: int | None = None,
73-
nx: bool = False,
74-
xx: bool = False,
75-
keepttl: bool = False,
76-
get: bool = False,
77-
exat: int | None = None,
78-
pxat: int | None = None,
79-
) -> Any: ...
80-
81-
def setex(self, name: str | bytes, time: int | timedelta, value: Any) -> Any: ...
82-
def setnx(self, name: str | bytes, value: Any) -> Any: ...
83-
def delete(self, *names: str | bytes) -> Any: ...
84-
def incr(self, name: str | bytes, amount: int = 1) -> Any: ...
85-
def expire(
86-
self,
87-
name: str | bytes,
88-
time: int | timedelta,
89-
nx: bool = False,
90-
xx: bool = False,
91-
gt: bool = False,
92-
lt: bool = False,
93-
) -> Any: ...
94-
def lock(
95-
self,
96-
name: str,
97-
timeout: float | None = None,
98-
sleep: float = 0.1,
99-
blocking: bool = True,
100-
blocking_timeout: float | None = None,
101-
thread_local: bool = True,
102-
) -> Lock: ...
103-
def zadd(
104-
self,
105-
name: str | bytes,
106-
mapping: dict[str | bytes | int | float, float | int | str | bytes],
107-
nx: bool = False,
108-
xx: bool = False,
109-
ch: bool = False,
110-
incr: bool = False,
111-
gt: bool = False,
112-
lt: bool = False,
113-
) -> Any: ...
114-
def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any: ...
115-
def zcard(self, name: str | bytes) -> Any: ...
116-
def getdel(self, name: str | bytes) -> Any: ...
117-
def pubsub(self) -> PubSub: ...
118-
def pipeline(self, transaction: bool = True, shard_hint: str | None = None) -> Any: ...
119-
120-
def __getattr__(self, item: str) -> Any:
71+
def _require_client(self) -> redis.Redis | RedisCluster:
12172
if self._client is None:
12273
raise RuntimeError("Redis client is not initialized. Call init_app first.")
123-
return getattr(self._client, item)
74+
return self._client
75+
76+
def _get_prefix(self) -> str:
77+
return dify_config.REDIS_KEY_PREFIX
78+
79+
def get(self, name: str | bytes) -> Any:
80+
return self._require_client().get(_serialize_redis_name_arg(name, self._get_prefix()))
81+
82+
def set(
83+
self,
84+
name: str | bytes,
85+
value: Any,
86+
ex: int | None = None,
87+
px: int | None = None,
88+
nx: bool = False,
89+
xx: bool = False,
90+
keepttl: bool = False,
91+
get: bool = False,
92+
exat: int | None = None,
93+
pxat: int | None = None,
94+
) -> Any:
95+
return self._require_client().set(
96+
_serialize_redis_name_arg(name, self._get_prefix()),
97+
value,
98+
ex=ex,
99+
px=px,
100+
nx=nx,
101+
xx=xx,
102+
keepttl=keepttl,
103+
get=get,
104+
exat=exat,
105+
pxat=pxat,
106+
)
107+
108+
def setex(self, name: str | bytes, time: int | timedelta, value: Any) -> Any:
109+
return self._require_client().setex(_serialize_redis_name_arg(name, self._get_prefix()), time, value)
110+
111+
def setnx(self, name: str | bytes, value: Any) -> Any:
112+
return self._require_client().setnx(_serialize_redis_name_arg(name, self._get_prefix()), value)
113+
114+
def delete(self, *names: str | bytes) -> Any:
115+
return self._require_client().delete(*_serialize_redis_name_args(names, self._get_prefix()))
116+
117+
def incr(self, name: str | bytes, amount: int = 1) -> Any:
118+
return self._require_client().incr(_serialize_redis_name_arg(name, self._get_prefix()), amount)
119+
120+
def expire(
121+
self,
122+
name: str | bytes,
123+
time: int | timedelta,
124+
nx: bool = False,
125+
xx: bool = False,
126+
gt: bool = False,
127+
lt: bool = False,
128+
) -> Any:
129+
return self._require_client().expire(
130+
_serialize_redis_name_arg(name, self._get_prefix()),
131+
time,
132+
nx=nx,
133+
xx=xx,
134+
gt=gt,
135+
lt=lt,
136+
)
137+
138+
def exists(self, *names: str | bytes) -> Any:
139+
return self._require_client().exists(*_serialize_redis_name_args(names, self._get_prefix()))
140+
141+
def ttl(self, name: str | bytes) -> Any:
142+
return self._require_client().ttl(_serialize_redis_name_arg(name, self._get_prefix()))
143+
144+
def getdel(self, name: str | bytes) -> Any:
145+
return self._require_client().getdel(_serialize_redis_name_arg(name, self._get_prefix()))
146+
147+
def lock(
148+
self,
149+
name: str,
150+
timeout: float | None = None,
151+
sleep: float = 0.1,
152+
blocking: bool = True,
153+
blocking_timeout: float | None = None,
154+
thread_local: bool = True,
155+
) -> Any:
156+
return self._require_client().lock(
157+
_serialize_redis_name(name, self._get_prefix()),
158+
timeout=timeout,
159+
sleep=sleep,
160+
blocking=blocking,
161+
blocking_timeout=blocking_timeout,
162+
thread_local=thread_local,
163+
)
164+
165+
def hset(self, name: str | bytes, *args: Any, **kwargs: Any) -> Any:
166+
return self._require_client().hset(_serialize_redis_name_arg(name, self._get_prefix()), *args, **kwargs)
167+
168+
def hgetall(self, name: str | bytes) -> Any:
169+
return self._require_client().hgetall(_serialize_redis_name_arg(name, self._get_prefix()))
170+
171+
def hdel(self, name: str | bytes, *keys: str | bytes) -> Any:
172+
return self._require_client().hdel(_serialize_redis_name_arg(name, self._get_prefix()), *keys)
173+
174+
def hlen(self, name: str | bytes) -> Any:
175+
return self._require_client().hlen(_serialize_redis_name_arg(name, self._get_prefix()))
176+
177+
def zadd(
178+
self,
179+
name: str | bytes,
180+
mapping: dict[str | bytes | int | float, float | int | str | bytes],
181+
nx: bool = False,
182+
xx: bool = False,
183+
ch: bool = False,
184+
incr: bool = False,
185+
gt: bool = False,
186+
lt: bool = False,
187+
) -> Any:
188+
return self._require_client().zadd(
189+
_serialize_redis_name_arg(name, self._get_prefix()),
190+
cast(Any, mapping),
191+
nx=nx,
192+
xx=xx,
193+
ch=ch,
194+
incr=incr,
195+
gt=gt,
196+
lt=lt,
197+
)
198+
199+
def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any:
200+
return self._require_client().zremrangebyscore(_serialize_redis_name_arg(name, self._get_prefix()), min, max)
201+
202+
def zcard(self, name: str | bytes) -> Any:
203+
return self._require_client().zcard(_serialize_redis_name_arg(name, self._get_prefix()))
204+
205+
def pubsub(self) -> PubSub:
206+
return self._require_client().pubsub()
207+
208+
def pipeline(self, transaction: bool = True, shard_hint: str | None = None) -> Any:
209+
return self._require_client().pipeline(transaction=transaction, shard_hint=shard_hint)
210+
211+
def __getattr__(self, item: str) -> Any:
212+
return getattr(self._require_client(), item)
124213

125214

126215
redis_client: RedisClientWrapper = RedisClientWrapper()

api/extensions/redis_names.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from configs import dify_config
2+
3+
4+
def normalize_redis_key_prefix(prefix: str | None) -> str:
5+
"""Normalize the configured Redis key prefix for consistent runtime use."""
6+
if prefix is None:
7+
return ""
8+
return prefix.strip()
9+
10+
11+
def get_redis_key_prefix() -> str:
12+
"""Read and normalize the current Redis key prefix from config."""
13+
return normalize_redis_key_prefix(dify_config.REDIS_KEY_PREFIX)
14+
15+
16+
def serialize_redis_name(name: str, prefix: str | None = None) -> str:
17+
"""Convert a logical Redis name into the physical name used in Redis."""
18+
normalized_prefix = get_redis_key_prefix() if prefix is None else normalize_redis_key_prefix(prefix)
19+
if not normalized_prefix:
20+
return name
21+
return f"{normalized_prefix}:{name}"
22+
23+
24+
def serialize_redis_name_arg(name: str | bytes, prefix: str | None = None) -> str | bytes:
25+
"""Prefix string Redis names while preserving bytes inputs unchanged."""
26+
if isinstance(name, bytes):
27+
return name
28+
return serialize_redis_name(name, prefix)
29+
30+
31+
def serialize_redis_name_args(names: tuple[str | bytes, ...], prefix: str | None = None) -> tuple[str | bytes, ...]:
32+
return tuple(serialize_redis_name_arg(name, prefix) for name in names)

api/libs/broadcast_channel/redis/channel.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from typing import Any
44

5+
from extensions.redis_names import serialize_redis_name
56
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
67
from redis import Redis, RedisCluster
78

@@ -32,12 +33,13 @@ class Topic:
3233
def __init__(self, redis_client: Redis | RedisCluster, topic: str):
3334
self._client = redis_client
3435
self._topic = topic
36+
self._redis_topic = serialize_redis_name(topic)
3537

3638
def as_producer(self) -> Producer:
3739
return self
3840

3941
def publish(self, payload: bytes) -> None:
40-
self._client.publish(self._topic, payload)
42+
self._client.publish(self._redis_topic, payload)
4143

4244
def as_subscriber(self) -> Subscriber:
4345
return self
@@ -46,7 +48,7 @@ def subscribe(self) -> Subscription:
4648
return _RedisSubscription(
4749
client=self._client,
4850
pubsub=self._client.pubsub(),
49-
topic=self._topic,
51+
topic=self._redis_topic,
5052
)
5153

5254

0 commit comments

Comments
 (0)