
Commit fff398f

Replace manual LoggingContext usage with ModuleApi.defer_to_threadpool (#134)
1 parent a5d1d50 commit fff398f

File tree

1 file changed: +60 -56 lines


s3_storage_provider.py

Lines changed: 60 additions & 56 deletions
@@ -24,20 +24,15 @@
 import botocore
 from botocore.config import Config
 
-from twisted.internet import defer, reactor, threads
+from twisted.internet import defer, reactor
 from twisted.python.failure import Failure
 from twisted.python.threadpool import ThreadPool
 
-from synapse.logging.context import LoggingContext, make_deferred_yieldable
+from synapse.logging.context import make_deferred_yieldable
+from synapse.module_api import ModuleApi
 from synapse.rest.media.v1._base import Responder
 from synapse.rest.media.v1.storage_provider import StorageProvider
 
-# Synapse 1.13.0 moved current_context to a module-level function.
-try:
-    from synapse.logging.context import current_context
-except ImportError:
-    current_context = LoggingContext.current_context
-
 logger = logging.getLogger("synapse.s3")
 
 
@@ -61,6 +56,7 @@ class S3StorageProviderBackend(StorageProvider):
     """
 
     def __init__(self, hs, config):
+        self._module_api: ModuleApi = hs.get_module_api()
        self.cache_directory = hs.config.media.media_store_path
        self.bucket = config["bucket"]
        self.prefix = config["prefix"]
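
Storing a `ModuleApi` handle in `__init__` is what lets the rest of the class hand blocking boto3 work back to Synapse. Below is a minimal sketch of the same pattern, assuming only what this diff shows (`hs.get_module_api()`, and `defer_to_threadpool` taking a pool followed by the callable and its arguments); `ExampleProvider` and `_blocking_io` are illustrative names, not part of this module.

from twisted.python.threadpool import ThreadPool

from synapse.module_api import ModuleApi


class ExampleProvider:
    def __init__(self, hs, config):
        # Handle onto Synapse's module API, as in the hunk above.
        self._module_api: ModuleApi = hs.get_module_api()
        # A dedicated pool, so slow I/O doesn't compete with Synapse's main
        # thread pool. (Illustrative; the real provider manages its own
        # pool elsewhere in this file.)
        self._pool = ThreadPool(name="example-io")
        self._pool.start()

    async def do_work(self, arg):
        # Run the blocking call on the pool; ModuleApi now does the
        # logging-context bookkeeping that this commit deletes.
        return await self._module_api.defer_to_threadpool(
            self._pool, _blocking_io, arg
        )


def _blocking_io(arg):
    # Anything that would otherwise block the reactor, e.g. a boto3 call.
    return arg.upper()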
@@ -124,37 +120,45 @@ def _get_s3_client(self):
         self._s3_client = s3 = b3_session.client("s3", **self.api_kwargs)
         return s3
 
-    def store_file(self, path, file_info):
+    async def store_file(self, path, file_info):
         """See StorageProvider.store_file"""
 
-        parent_logcontext = current_context()
-
-        def _store_file():
-            with LoggingContext(parent_context=parent_logcontext):
-                self._get_s3_client().upload_file(
-                    Filename=os.path.join(self.cache_directory, path),
-                    Bucket=self.bucket,
-                    Key=self.prefix + path,
-                    ExtraArgs=self.extra_args,
-                )
-
-        return make_deferred_yieldable(
-            threads.deferToThreadPool(reactor, self._s3_pool, _store_file)
+        return await self._module_api.defer_to_threadpool(
+            self._s3_pool,
+            self._get_s3_client().upload_file,
+            Filename=os.path.join(self.cache_directory, path),
+            Bucket=self.bucket,
+            Key=self.prefix + path,
+            ExtraArgs=self.extra_args,
         )
 
-    def fetch(self, path, file_info):
+    async def fetch(self, path, file_info):
         """See StorageProvider.fetch"""
-        logcontext = current_context()
-
         d = defer.Deferred()
 
-        def _get_file():
-            s3_download_task(
-                self._get_s3_client(), self.bucket, self.prefix + path, self.extra_args, d, logcontext
+        # Don't await this directly, as it will resolve only once the streaming
+        # download from S3 is concluded. Before that happens, we want to pass
+        # execution back to Synapse to stream the file's chunks.
+        #
+        # We do, however, need to wrap in `run_in_background` to ensure that the
+        # coroutine returned by `defer_to_threadpool` is used, and therefore
+        # actually run.
+        self._module_api.run_in_background(
+            self._module_api.defer_to_threadpool(
+                self._s3_pool,
+                s3_download_task,
+                self._get_s3_client(),
+                self.bucket,
+                self.prefix + path,
+                self.extra_args,
+                d,
             )
+        )
 
-        self._s3_pool.callInThread(_get_file)
-        return make_deferred_yieldable(d)
+        # DO await on `d`, as it will resolve once a connection to S3 has been
+        # opened. We only want to return to Synapse once we can start streaming
+        # chunks.
+        return await make_deferred_yieldable(d)
 
     @staticmethod
     def parse_config(config):
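
The comments in `fetch` above describe a handshake: the thread-pool task fires a plain Deferred via `reactor.callFromThread` as soon as the S3 connection is open, while the task itself only completes once the whole download has streamed. The caller awaits the Deferred, not the task. A self-contained Twisted sketch of that handshake, with the illustrative names `_worker` and `main` standing in for `s3_download_task` and `fetch`:

from twisted.internet import defer, reactor, threads


def _worker(d):
    # Runs in a pool thread. A Deferred may only be fired from the reactor
    # thread, so hand the callback over with callFromThread.
    reactor.callFromThread(d.callback, "connection open")
    # ... the long-running part (streaming the body) would continue here ...


@defer.inlineCallbacks
def main():
    d = defer.Deferred()
    # Start the worker, but don't wait for it to finish; we only wait for
    # it to fire `d`.
    threads.deferToThread(_worker, d)
    result = yield d
    print("resumed as soon as:", result)
    reactor.stop()


reactor.callWhenRunning(main)
reactor.run()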
@@ -202,7 +206,7 @@ def parse_config(config):
     return result
 
 
-def s3_download_task(s3_client, bucket, key, extra_args, deferred, parent_logcontext):
+def s3_download_task(s3_client, bucket, key, extra_args, deferred):
     """Attempts to download a file from S3.
 
     Args:
@@ -212,35 +216,35 @@ def s3_download_task(s3_client, bucket, key, extra_args, deferred, parent_logcontext):
         deferred (Deferred[_S3Responder|None]): If file exists
             resolved with an _S3Responder instance, if it doesn't
             exist then resolves with None.
-        parent_logcontext (LoggingContext): the logcontext to report logs and metrics
-            against.
+
+    Returns:
+        A deferred which resolves to an _S3Responder if the file exists.
+        Otherwise the deferred fails.
     """
-    with LoggingContext(parent_context=parent_logcontext):
-        logger.info("Fetching %s from S3", key)
-
-        try:
-            if "SSECustomerKey" in extra_args and "SSECustomerAlgorithm" in extra_args:
-                resp = s3_client.get_object(
-                    Bucket=bucket,
-                    Key=key,
-                    SSECustomerKey=extra_args["SSECustomerKey"],
-                    SSECustomerAlgorithm=extra_args["SSECustomerAlgorithm"],
-                )
-            else:
-                resp = s3_client.get_object(Bucket=bucket, Key=key)
-
-        except botocore.exceptions.ClientError as e:
-            if e.response["Error"]["Code"] in ("404", "NoSuchKey",):
-                logger.info("Media %s not found in S3", key)
-                reactor.callFromThread(deferred.callback, None)
-                return
+    logger.info("Fetching %s from S3", key)
+
+    try:
+        if "SSECustomerKey" in extra_args and "SSECustomerAlgorithm" in extra_args:
+            resp = s3_client.get_object(
+                Bucket=bucket,
+                Key=key,
+                SSECustomerKey=extra_args["SSECustomerKey"],
+                SSECustomerAlgorithm=extra_args["SSECustomerAlgorithm"],
+            )
+        else:
+            resp = s3_client.get_object(Bucket=bucket, Key=key)
 
-            reactor.callFromThread(deferred.errback, Failure())
+    except botocore.exceptions.ClientError as e:
+        if e.response["Error"]["Code"] in ("404", "NoSuchKey",):
+            logger.info("Media %s not found in S3", key)
             return
 
-        producer = _S3Responder()
-        reactor.callFromThread(deferred.callback, producer)
-        _stream_to_producer(reactor, producer, resp["Body"], timeout=90.0)
+        reactor.callFromThread(deferred.errback, Failure())
+        return
+
+    producer = _S3Responder()
+    reactor.callFromThread(deferred.callback, producer)
+    _stream_to_producer(reactor, producer, resp["Body"], timeout=90.0)
 
 
 def _stream_to_producer(reactor, producer, body, status=None, timeout=None):
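
For reference, the 404 handling above hinges on the error code botocore attaches to a ClientError. A standalone sketch of that check, outside Synapse and Twisted entirely; the bucket and key are placeholders, and configured AWS credentials are assumed:

import boto3
import botocore

s3 = boto3.client("s3")

try:
    resp = s3.get_object(Bucket="example-bucket", Key="some/missing/key")
except botocore.exceptions.ClientError as e:
    code = e.response["Error"]["Code"]
    if code in ("404", "NoSuchKey"):
        print("object does not exist")  # the "media not found" path above
    else:
        raise  # any other S3 failure propagates
else:
    print("found, %d bytes" % resp["ContentLength"])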
