|
3 | 3 | import ssl |
4 | 4 | from collections.abc import Callable |
5 | 5 | from datetime import timedelta |
6 | | -from typing import TYPE_CHECKING, Any, Union |
| 6 | +from typing import Any, Union, cast |
7 | 7 |
|
8 | 8 | import redis |
9 | 9 | from redis import RedisError |
|
18 | 18 |
|
19 | 19 | from configs import dify_config |
20 | 20 | from dify_app import DifyApp |
| 21 | +from extensions.redis_names import ( |
| 22 | + normalize_redis_key_prefix, |
| 23 | + serialize_redis_name, |
| 24 | + serialize_redis_name_arg, |
| 25 | + serialize_redis_name_args, |
| 26 | +) |
21 | 27 | from libs.broadcast_channel.channel import BroadcastChannel as BroadcastChannelProtocol |
22 | 28 | from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel |
23 | 29 | from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel |
24 | 30 | from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel |
25 | 31 |
|
26 | | -if TYPE_CHECKING: |
27 | | - from redis.lock import Lock |
28 | | - |
29 | 32 | logger = logging.getLogger(__name__) |
30 | 33 |
|
31 | 34 |
|
| 35 | +_normalize_redis_key_prefix = normalize_redis_key_prefix |
| 36 | +_serialize_redis_name = serialize_redis_name |
| 37 | +_serialize_redis_name_arg = serialize_redis_name_arg |
| 38 | +_serialize_redis_name_args = serialize_redis_name_args |
| 39 | + |
| 40 | + |
32 | 41 | class RedisClientWrapper: |
33 | 42 | """ |
34 | 43 | A wrapper class for the Redis client that addresses the issue where the global |
@@ -59,68 +68,148 @@ def initialize(self, client: Union[redis.Redis, RedisCluster]) -> None: |
59 | 68 | if self._client is None: |
60 | 69 | self._client = client |
61 | 70 |
|
62 | | - if TYPE_CHECKING: |
63 | | - # Type hints for IDE support and static analysis |
64 | | - # These are not executed at runtime but provide type information |
65 | | - def get(self, name: str | bytes) -> Any: ... |
66 | | - |
67 | | - def set( |
68 | | - self, |
69 | | - name: str | bytes, |
70 | | - value: Any, |
71 | | - ex: int | None = None, |
72 | | - px: int | None = None, |
73 | | - nx: bool = False, |
74 | | - xx: bool = False, |
75 | | - keepttl: bool = False, |
76 | | - get: bool = False, |
77 | | - exat: int | None = None, |
78 | | - pxat: int | None = None, |
79 | | - ) -> Any: ... |
80 | | - |
81 | | - def setex(self, name: str | bytes, time: int | timedelta, value: Any) -> Any: ... |
82 | | - def setnx(self, name: str | bytes, value: Any) -> Any: ... |
83 | | - def delete(self, *names: str | bytes) -> Any: ... |
84 | | - def incr(self, name: str | bytes, amount: int = 1) -> Any: ... |
85 | | - def expire( |
86 | | - self, |
87 | | - name: str | bytes, |
88 | | - time: int | timedelta, |
89 | | - nx: bool = False, |
90 | | - xx: bool = False, |
91 | | - gt: bool = False, |
92 | | - lt: bool = False, |
93 | | - ) -> Any: ... |
94 | | - def lock( |
95 | | - self, |
96 | | - name: str, |
97 | | - timeout: float | None = None, |
98 | | - sleep: float = 0.1, |
99 | | - blocking: bool = True, |
100 | | - blocking_timeout: float | None = None, |
101 | | - thread_local: bool = True, |
102 | | - ) -> Lock: ... |
103 | | - def zadd( |
104 | | - self, |
105 | | - name: str | bytes, |
106 | | - mapping: dict[str | bytes | int | float, float | int | str | bytes], |
107 | | - nx: bool = False, |
108 | | - xx: bool = False, |
109 | | - ch: bool = False, |
110 | | - incr: bool = False, |
111 | | - gt: bool = False, |
112 | | - lt: bool = False, |
113 | | - ) -> Any: ... |
114 | | - def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any: ... |
115 | | - def zcard(self, name: str | bytes) -> Any: ... |
116 | | - def getdel(self, name: str | bytes) -> Any: ... |
117 | | - def pubsub(self) -> PubSub: ... |
118 | | - def pipeline(self, transaction: bool = True, shard_hint: str | None = None) -> Any: ... |
119 | | - |
120 | | - def __getattr__(self, item: str) -> Any: |
| 71 | + def _require_client(self) -> redis.Redis | RedisCluster: |
121 | 72 | if self._client is None: |
122 | 73 | raise RuntimeError("Redis client is not initialized. Call init_app first.") |
123 | | - return getattr(self._client, item) |
| 74 | + return self._client |
| 75 | + |
| 76 | + def _get_prefix(self) -> str: |
| 77 | + return dify_config.REDIS_KEY_PREFIX |
| 78 | + |
| 79 | + def get(self, name: str | bytes) -> Any: |
| 80 | + return self._require_client().get(_serialize_redis_name_arg(name, self._get_prefix())) |
| 81 | + |
| 82 | + def set( |
| 83 | + self, |
| 84 | + name: str | bytes, |
| 85 | + value: Any, |
| 86 | + ex: int | None = None, |
| 87 | + px: int | None = None, |
| 88 | + nx: bool = False, |
| 89 | + xx: bool = False, |
| 90 | + keepttl: bool = False, |
| 91 | + get: bool = False, |
| 92 | + exat: int | None = None, |
| 93 | + pxat: int | None = None, |
| 94 | + ) -> Any: |
| 95 | + return self._require_client().set( |
| 96 | + _serialize_redis_name_arg(name, self._get_prefix()), |
| 97 | + value, |
| 98 | + ex=ex, |
| 99 | + px=px, |
| 100 | + nx=nx, |
| 101 | + xx=xx, |
| 102 | + keepttl=keepttl, |
| 103 | + get=get, |
| 104 | + exat=exat, |
| 105 | + pxat=pxat, |
| 106 | + ) |
| 107 | + |
| 108 | + def setex(self, name: str | bytes, time: int | timedelta, value: Any) -> Any: |
| 109 | + return self._require_client().setex(_serialize_redis_name_arg(name, self._get_prefix()), time, value) |
| 110 | + |
| 111 | + def setnx(self, name: str | bytes, value: Any) -> Any: |
| 112 | + return self._require_client().setnx(_serialize_redis_name_arg(name, self._get_prefix()), value) |
| 113 | + |
| 114 | + def delete(self, *names: str | bytes) -> Any: |
| 115 | + return self._require_client().delete(*_serialize_redis_name_args(names, self._get_prefix())) |
| 116 | + |
| 117 | + def incr(self, name: str | bytes, amount: int = 1) -> Any: |
| 118 | + return self._require_client().incr(_serialize_redis_name_arg(name, self._get_prefix()), amount) |
| 119 | + |
| 120 | + def expire( |
| 121 | + self, |
| 122 | + name: str | bytes, |
| 123 | + time: int | timedelta, |
| 124 | + nx: bool = False, |
| 125 | + xx: bool = False, |
| 126 | + gt: bool = False, |
| 127 | + lt: bool = False, |
| 128 | + ) -> Any: |
| 129 | + return self._require_client().expire( |
| 130 | + _serialize_redis_name_arg(name, self._get_prefix()), |
| 131 | + time, |
| 132 | + nx=nx, |
| 133 | + xx=xx, |
| 134 | + gt=gt, |
| 135 | + lt=lt, |
| 136 | + ) |
| 137 | + |
| 138 | + def exists(self, *names: str | bytes) -> Any: |
| 139 | + return self._require_client().exists(*_serialize_redis_name_args(names, self._get_prefix())) |
| 140 | + |
| 141 | + def ttl(self, name: str | bytes) -> Any: |
| 142 | + return self._require_client().ttl(_serialize_redis_name_arg(name, self._get_prefix())) |
| 143 | + |
| 144 | + def getdel(self, name: str | bytes) -> Any: |
| 145 | + return self._require_client().getdel(_serialize_redis_name_arg(name, self._get_prefix())) |
| 146 | + |
| 147 | + def lock( |
| 148 | + self, |
| 149 | + name: str, |
| 150 | + timeout: float | None = None, |
| 151 | + sleep: float = 0.1, |
| 152 | + blocking: bool = True, |
| 153 | + blocking_timeout: float | None = None, |
| 154 | + thread_local: bool = True, |
| 155 | + ) -> Any: |
| 156 | + return self._require_client().lock( |
| 157 | + _serialize_redis_name(name, self._get_prefix()), |
| 158 | + timeout=timeout, |
| 159 | + sleep=sleep, |
| 160 | + blocking=blocking, |
| 161 | + blocking_timeout=blocking_timeout, |
| 162 | + thread_local=thread_local, |
| 163 | + ) |
| 164 | + |
| 165 | + def hset(self, name: str | bytes, *args: Any, **kwargs: Any) -> Any: |
| 166 | + return self._require_client().hset(_serialize_redis_name_arg(name, self._get_prefix()), *args, **kwargs) |
| 167 | + |
| 168 | + def hgetall(self, name: str | bytes) -> Any: |
| 169 | + return self._require_client().hgetall(_serialize_redis_name_arg(name, self._get_prefix())) |
| 170 | + |
| 171 | + def hdel(self, name: str | bytes, *keys: str | bytes) -> Any: |
| 172 | + return self._require_client().hdel(_serialize_redis_name_arg(name, self._get_prefix()), *keys) |
| 173 | + |
| 174 | + def hlen(self, name: str | bytes) -> Any: |
| 175 | + return self._require_client().hlen(_serialize_redis_name_arg(name, self._get_prefix())) |
| 176 | + |
| 177 | + def zadd( |
| 178 | + self, |
| 179 | + name: str | bytes, |
| 180 | + mapping: dict[str | bytes | int | float, float | int | str | bytes], |
| 181 | + nx: bool = False, |
| 182 | + xx: bool = False, |
| 183 | + ch: bool = False, |
| 184 | + incr: bool = False, |
| 185 | + gt: bool = False, |
| 186 | + lt: bool = False, |
| 187 | + ) -> Any: |
| 188 | + return self._require_client().zadd( |
| 189 | + _serialize_redis_name_arg(name, self._get_prefix()), |
| 190 | + cast(Any, mapping), |
| 191 | + nx=nx, |
| 192 | + xx=xx, |
| 193 | + ch=ch, |
| 194 | + incr=incr, |
| 195 | + gt=gt, |
| 196 | + lt=lt, |
| 197 | + ) |
| 198 | + |
| 199 | + def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any: |
| 200 | + return self._require_client().zremrangebyscore(_serialize_redis_name_arg(name, self._get_prefix()), min, max) |
| 201 | + |
| 202 | + def zcard(self, name: str | bytes) -> Any: |
| 203 | + return self._require_client().zcard(_serialize_redis_name_arg(name, self._get_prefix())) |
| 204 | + |
| 205 | + def pubsub(self) -> PubSub: |
| 206 | + return self._require_client().pubsub() |
| 207 | + |
| 208 | + def pipeline(self, transaction: bool = True, shard_hint: str | None = None) -> Any: |
| 209 | + return self._require_client().pipeline(transaction=transaction, shard_hint=shard_hint) |
| 210 | + |
| 211 | + def __getattr__(self, item: str) -> Any: |
| 212 | + return getattr(self._require_client(), item) |
124 | 213 |
|
125 | 214 |
|
126 | 215 | redis_client: RedisClientWrapper = RedisClientWrapper() |
|
0 commit comments