@@ -73,7 +73,7 @@ def _execute_task(task):
7373 domain = get_domain ()
7474 try :
7575 log_task_start (task , domain )
76- task_function = get_task_function (task , in_async_context = False )
76+ task_function = get_task_function (task )
7777 result = task_function ()
7878 except Exception :
7979 exc_type , exc , tb = sys .exc_info ()
@@ -94,7 +94,7 @@ async def _aexecute_task(task):
9494 await sync_to_async (task .set_running )()
9595 domain = get_domain ()
9696 try :
97- coroutine = get_task_function (task , in_async_context = True )
97+ coroutine = await aget_task_function (task )
9898 result = await coroutine
9999 except Exception :
100100 exc_type , exc , tb = sys .exc_info ()
@@ -148,44 +148,56 @@ def log_task_failed(task, exc_type, exc, tb, domain):
148148 _logger .info ("\n " .join (traceback .format_list (traceback .extract_tb (tb ))))
149149
150150
151- def get_task_function (task , in_async_context ):
152- """Wraps task function according to the sync or async context we are in."""
153- module_name , function_name = task .name .rsplit ("." , 1 )
154- module = importlib .import_module (module_name )
155- func = getattr (module , function_name )
151+ async def aget_task_function (task ):
152+ """Get and handle task function running from async context."""
153+ func , is_coroutine_fn = _load_function (task .name )
156154 args = task .enc_args or ()
157155 kwargs = task .enc_kwargs or {}
158156 immediate = task .immediate
159- is_coroutine_fn = asyncio .iscoroutinefunction (func )
160157
161- if immediate and not is_coroutine_fn :
158+ if not immediate :
159+ raise ValueError ("Non-immediate tasks can't run in async context." )
160+ elif not is_coroutine_fn :
162161 raise ValueError ("Immediate tasks must be async functions." )
163162
164- if not immediate and in_async_context :
165- raise ValueError ( "Non-immediate tasks can't run in async context." )
163+ coro = func ( * args , ** kwargs )
164+ return _add_timeout_to ( coro , task . pk )
166165
167- if in_async_context : # has to be immediate
168- coro = func (* args , ** kwargs )
169- coro = asyncio .wait_for (coro , timeout = IMMEDIATE_TIMEOUT )
170- return coro
171- else :
172- if not is_coroutine_fn :
173- return lambda : func (* args , ** kwargs )
174-
175- # async function in sync context needs adapter
176- async def task_wrapper (): # asyncio.wait_for + async_to_sync requires wrapping
177- coro = func (* args , ** kwargs )
178- if immediate :
179- coro = asyncio .wait_for (coro , timeout = IMMEDIATE_TIMEOUT )
180- try :
181- return await coro
182- except asyncio .TimeoutError :
183- msg_template = "Immediate task %s timed out after %s seconds."
184- error_msg = msg_template % (task .pk , IMMEDIATE_TIMEOUT )
185- _logger .info (error_msg )
186- raise RuntimeError (error_msg )
187-
188- return async_to_sync (task_wrapper )
166+
167+ def get_task_function (task ):
168+ """Get and handle task function running from sync context."""
169+ func , is_coroutine_fn = _load_function (task .name )
170+ args = task .enc_args or ()
171+ kwargs = task .enc_kwargs or {}
172+ immediate = task .immediate
173+
174+ # no wrapper required
175+ if not is_coroutine_fn :
176+ return lambda : func (* args , ** kwargs )
177+
178+ # async function in sync context requires wrapper
179+ coro = func (* args , ** kwargs )
180+ final_coro = _add_timeout_to (coro , task .pk ) if immediate else coro
181+ return async_to_sync (final_coro )
182+
183+
184+ def _load_function (task_name ):
185+ module_name , function_name = task_name .rsplit ("." , 1 )
186+ module = importlib .import_module (module_name )
187+ func = getattr (module , function_name )
188+ is_coroutine_fn = asyncio .iscoroutinefunction (func )
189+ return func , is_coroutine_fn
190+
191+
192+ async def _add_timeout_to (coro , task_pk ):
193+ coro = asyncio .wait_for (coro , timeout = IMMEDIATE_TIMEOUT )
194+ try :
195+ return await coro
196+ except asyncio .TimeoutError :
197+ msg_template = "Immediate task %s timed out after %s seconds."
198+ error_msg = msg_template % (task_pk , IMMEDIATE_TIMEOUT )
199+ _logger .info (error_msg )
200+ raise RuntimeError (error_msg )
189201
190202
191203def dispatch (
0 commit comments