Skip to content

Commit 3b51f50

Browse files
committed
use ClientPool to prevent race conditions when using pylibmc as memcached package
1 parent eafb40b commit 3b51f50

1 file changed

Lines changed: 58 additions & 14 deletions

File tree

src/cachelib/memcached.py

Lines changed: 58 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,15 @@ def __init__(
4949
servers: _t.Any = None,
5050
default_timeout: int = 300,
5151
key_prefix: _t.Optional[str] = None,
52+
threads: int = 1,
53+
blocking: bool = False,
5254
):
5355
BaseCache.__init__(self, default_timeout)
5456
if servers is None or isinstance(servers, (list, tuple)):
5557
if servers is None:
5658
servers = ["127.0.0.1:11211"]
57-
self._client = self.import_preferred_memcache_lib(servers)
59+
self.pylibmc_used = False
60+
self._client = self.import_preferred_memcache_lib(servers, threads)
5861
if self._client is None:
5962
raise RuntimeError("no memcache module found")
6063
else:
@@ -81,7 +84,11 @@ def get(self, key: str) -> _t.Any:
8184
# checks for so long keys can occur because it's tested from user
8285
# submitted data etc we fail silently for getting.
8386
if _test_memcached_key(key):
84-
return self._client.get(key)
87+
if self.pylibmc_used:
88+
with self._client.reserve(block=self.blocking) as mc:
89+
return mc.get(self._normalize_key(key))
90+
else:
91+
return self._client.get(key)
8592

8693
def get_dict(self, *keys: str) -> _t.Dict[str, _t.Any]:
8794
key_mapping = {}
@@ -90,7 +97,11 @@ def get_dict(self, *keys: str) -> _t.Dict[str, _t.Any]:
9097
if _test_memcached_key(key):
9198
key_mapping[encoded_key] = key
9299
_keys = list(key_mapping)
93-
d = rv = self._client.get_multi(_keys) # type: _t.Dict[str, _t.Any]
100+
if self.pylibmc_used:
101+
with self._client.reserve(block=self.blocking) as mc:
102+
d = rv = mc.get_multi(_keys) # type: _t.Dict[str, _t.Any]
103+
else:
104+
d = rv = self._client.get_multi(_keys) # type: _t.Dict[str, _t.Any]
94105
if self.key_prefix:
95106
rv = {}
96107
for key, value in d.items():
@@ -104,14 +115,22 @@ def get_dict(self, *keys: str) -> _t.Dict[str, _t.Any]:
104115
def add(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> bool:
105116
key = self._normalize_key(key)
106117
timeout = self._normalize_timeout(timeout)
107-
return bool(self._client.add(key, value, timeout))
118+
if self.pylibmc_used:
119+
with self._client.reserve(block=self.blocking) as mc:
120+
return bool(mc.add(key, value, timeout))
121+
else:
122+
return bool(self._client.add(key, value, timeout))
108123

109124
def set(
110125
self, key: str, value: _t.Any, timeout: _t.Optional[int] = None
111126
) -> _t.Optional[bool]:
112127
key = self._normalize_key(key)
113128
timeout = self._normalize_timeout(timeout)
114-
return bool(self._client.set(key, value, timeout))
129+
if self.pylibmc_used:
130+
with self._client.reserve(block=self.blocking) as mc:
131+
return bool(mc.set(key, value, timeout))
132+
else:
133+
return bool(self._client.set(key, value, timeout))
115134

116135
def get_many(self, *keys: str) -> _t.List[_t.Any]:
117136
d = self.get_dict(*keys)
@@ -126,16 +145,26 @@ def set_many(
126145
new_mapping[key] = value
127146

128147
timeout = self._normalize_timeout(timeout)
129-
failed_keys = self._client.set_multi(
130-
new_mapping, timeout
131-
) # type: _t.List[_t.Any]
148+
if self.pylibmc_used:
149+
with self._client.reserve(block=self.blocking) as mc:
150+
failed_keys = mc.set_multi(
151+
new_mapping, timeout
152+
) # type: _t.List[_t.Any]
153+
else:
154+
failed_keys = self._client.set_multi(
155+
new_mapping, timeout
156+
) # type: _t.List[_t.Any]
132157
k_normkey = zip(mapping.keys(), new_mapping.keys()) # noqa: B905
133158
return [k for k, nkey in k_normkey if nkey not in failed_keys]
134159

135160
def delete(self, key: str) -> bool:
136161
key = self._normalize_key(key)
137162
if _test_memcached_key(key):
138-
return bool(self._client.delete(key))
163+
if self.pylibmc_used:
164+
with self._client.reserve(block=self.blocking) as mc:
165+
return bool(mc.delete(key))
166+
else:
167+
return bool(self._client.delete(key))
139168
return False
140169

141170
def delete_many(self, *keys: str) -> _t.List[_t.Any]:
@@ -144,17 +173,29 @@ def delete_many(self, *keys: str) -> _t.List[_t.Any]:
144173
key = self._normalize_key(key)
145174
if _test_memcached_key(key):
146175
new_keys.append(key)
147-
self._client.delete_multi(new_keys)
176+
if self.pylibmc_used:
177+
with self._client.reserve(block=self.blocking) as mc:
178+
mc.delete_multi(new_keys)
179+
else:
180+
self._client.delete_multi(new_keys)
148181
return [k for k in new_keys if not self.has(k)]
149182

150183
def has(self, key: str) -> bool:
151184
key = self._normalize_key(key)
152185
if _test_memcached_key(key):
153-
return bool(self._client.append(key, ""))
186+
if self.pylibmc_used:
187+
with self._client.reserve(block=self.blocking) as mc:
188+
return bool(mc.append(key, ""))
189+
else:
190+
return bool(self._client.append(key, ""))
154191
return False
155192

156193
def clear(self) -> bool:
157-
return bool(self._client.flush_all())
194+
if self.pylibmc_used:
195+
with self._client.reserve(block=self.blocking) as mc:
196+
return bool(mc.flush_all())
197+
else:
198+
return bool(self._client.flush_all())
158199

159200
def inc(self, key: str, delta: int = 1) -> _t.Optional[int]:
160201
key = self._normalize_key(key)
@@ -166,14 +207,17 @@ def dec(self, key: str, delta: int = 1) -> _t.Optional[int]:
166207
value = (self._client.get(key) or 0) - delta
167208
return value if self.set(key, value) else None
168209

169-
def import_preferred_memcache_lib(self, servers: _t.Any) -> _t.Any:
210+
def import_preferred_memcache_lib(self, servers: _t.Any, threads: int) -> _t.Any:
170211
"""Returns an initialized memcache client. Used by the constructor."""
171212
try:
172213
import pylibmc # type: ignore
173214
except ImportError:
174215
pass
175216
else:
176-
return pylibmc.Client(servers)
217+
self.pylibmc_used = True
218+
_client_pool = pylibmc.ClientPool()
219+
_client_pool.fill(pylibmc.Client(servers), threads)
220+
return _client_pool
177221

178222
try:
179223
from google.appengine.api import memcache # type: ignore

0 commit comments

Comments
 (0)