2424import botocore
2525from botocore .config import Config
2626
27- from twisted .internet import defer , reactor , threads
27+ from twisted .internet import defer , reactor
2828from twisted .python .failure import Failure
2929from twisted .python .threadpool import ThreadPool
3030
31- from synapse .logging .context import LoggingContext , make_deferred_yieldable
31+ from synapse .logging .context import make_deferred_yieldable
32+ from synapse .module_api import ModuleApi
3233from synapse .rest .media .v1 ._base import Responder
3334from synapse .rest .media .v1 .storage_provider import StorageProvider
3435
35- # Synapse 1.13.0 moved current_context to a module-level function.
36- try :
37- from synapse .logging .context import current_context
38- except ImportError :
39- current_context = LoggingContext .current_context
40-
4136logger = logging .getLogger ("synapse.s3" )
4237
4338
@@ -61,6 +56,7 @@ class S3StorageProviderBackend(StorageProvider):
6156 """
6257
6358 def __init__ (self , hs , config ):
59+ self ._module_api : ModuleApi = hs .get_module_api ()
6460 self .cache_directory = hs .config .media .media_store_path
6561 self .bucket = config ["bucket" ]
6662 self .prefix = config ["prefix" ]
@@ -124,37 +120,45 @@ def _get_s3_client(self):
124120 self ._s3_client = s3 = b3_session .client ("s3" , ** self .api_kwargs )
125121 return s3
126122
127- def store_file (self , path , file_info ):
123+ async def store_file (self , path , file_info ):
128124 """See StorageProvider.store_file"""
129125
130- parent_logcontext = current_context ()
131-
132- def _store_file ():
133- with LoggingContext (parent_context = parent_logcontext ):
134- self ._get_s3_client ().upload_file (
135- Filename = os .path .join (self .cache_directory , path ),
136- Bucket = self .bucket ,
137- Key = self .prefix + path ,
138- ExtraArgs = self .extra_args ,
139- )
140-
141- return make_deferred_yieldable (
142- threads .deferToThreadPool (reactor , self ._s3_pool , _store_file )
126+ return await self ._module_api .defer_to_threadpool (
127+ self ._s3_pool ,
128+ self ._get_s3_client ().upload_file ,
129+ Filename = os .path .join (self .cache_directory , path ),
130+ Bucket = self .bucket ,
131+ Key = self .prefix + path ,
132+ ExtraArgs = self .extra_args ,
143133 )
144134
145- def fetch (self , path , file_info ):
135+ async def fetch (self , path , file_info ):
146136 """See StorageProvider.fetch"""
147- logcontext = current_context ()
148-
149137 d = defer .Deferred ()
150138
151- def _get_file ():
152- s3_download_task (
153- self ._get_s3_client (), self .bucket , self .prefix + path , self .extra_args , d , logcontext
139+ # Don't await this directly, as it will resolve only once the streaming
140+ # download from S3 is concluded. Before that happens, we want to pass
141+ # execution back to Synapse to stream the file's chunks.
142+ #
143+ # We do, however, need to wrap in `run_in_background` to ensure that the
144+ # coroutine returned by `defer_to_threadpool` is used, and therefore
145+ # actually run.
146+ self ._module_api .run_in_background (
147+ self ._module_api .defer_to_threadpool (
148+ self ._s3_pool ,
149+ s3_download_task ,
150+ self ._get_s3_client (),
151+ self .bucket ,
152+ self .prefix + path ,
153+ self .extra_args ,
154+ d ,
154155 )
156+ )
155157
156- self ._s3_pool .callInThread (_get_file )
157- return make_deferred_yieldable (d )
158+ # DO await on `d`, as it will resolve once a connection to S3 has been
159+ # opened. We only want to return to Synapse once we can start streaming
160+ # chunks.
161+ return await make_deferred_yieldable (d )
158162
159163 @staticmethod
160164 def parse_config (config ):
@@ -202,7 +206,7 @@ def parse_config(config):
202206 return result
203207
204208
def s3_download_task(s3_client, bucket, key, extra_args, deferred):
    """Attempts to download a file from S3.

    Runs on the S3 threadpool, so every interaction with `deferred` must be
    bounced back onto the reactor thread via `reactor.callFromThread`.

    Args:
        s3_client: boto3 S3 client to fetch with.
        bucket (str): Name of the bucket to fetch from.
        key (str): Object key (prefix + media path) to fetch.
        extra_args (dict): Extra boto3 arguments (e.g. SSE-C settings).
        deferred (Deferred[_S3Responder|None]): Resolved with an _S3Responder
            once streaming can begin, resolved with None if the object does
            not exist, or errbacked on any other S3 error.
    """
    logger.info("Fetching %s from S3", key)

    try:
        if "SSECustomerKey" in extra_args and "SSECustomerAlgorithm" in extra_args:
            resp = s3_client.get_object(
                Bucket=bucket,
                Key=key,
                SSECustomerKey=extra_args["SSECustomerKey"],
                SSECustomerAlgorithm=extra_args["SSECustomerAlgorithm"],
            )
        else:
            resp = s3_client.get_object(Bucket=bucket, Key=key)
    except botocore.exceptions.ClientError as e:
        if e.response["Error"]["Code"] in ("404", "NoSuchKey"):
            logger.info("Media %s not found in S3", key)
            # BUG FIX: the deferred must still be resolved here, otherwise
            # the caller awaiting it hangs forever.  None signals "file not
            # found" per the StorageProvider.fetch contract.
            reactor.callFromThread(deferred.callback, None)
            return

        # Any other S3 error: propagate the failure to the caller.
        reactor.callFromThread(deferred.errback, Failure())
        return

    producer = _S3Responder()
    # Resolve the deferred as soon as a responder exists so the caller can
    # begin consuming chunks while we continue streaming from S3.
    reactor.callFromThread(deferred.callback, producer)
    _stream_to_producer(reactor, producer, resp["Body"], timeout=90.0)
244248
245249
246250def _stream_to_producer (reactor , producer , body , status = None , timeout = None ):
0 commit comments