Skip to content

Commit e8ec027

Browse files
committed
Refactor get_task_function code in tasks.py
All the logic was being handled in a single function with lots of conditional branches, which was hard to read and reason about. this splits it in more digestable auxiliary functions and tries to make the logic more clear.
1 parent e770559 commit e8ec027

File tree

1 file changed

+61
-30
lines changed

1 file changed

+61
-30
lines changed

pulpcore/tasking/tasks.py

Lines changed: 61 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import sys
77
import traceback
88
import tempfile
9+
from functools import partial
910
from gettext import gettext as _
1011
from contextlib import contextmanager
1112
from asgiref.sync import sync_to_async, async_to_sync
@@ -94,8 +95,8 @@ async def _aexecute_task(task):
9495
await sync_to_async(task.set_running)()
9596
domain = get_domain()
9697
try:
97-
coroutine = get_task_function(task, ensure_coroutine=True)
98-
result = await coroutine
98+
task_coroutine_fn = await aget_task_function(task)
99+
result = await task_coroutine_fn()
99100
except Exception:
100101
exc_type, exc, tb = sys.exc_info()
101102
await sync_to_async(task.set_failed)(exc, tb)
@@ -148,42 +149,72 @@ def log_task_failed(task, exc_type, exc, tb, domain):
148149
_logger.info("\n".join(traceback.format_list(traceback.extract_tb(tb))))
149150

150151

151-
def get_task_function(task, ensure_coroutine=False):
152+
async def aget_task_function(task):
153+
"""Get and handle task function running from ASYNC context.
154+
155+
This exists to handle the combinations of:
156+
* context: sync | async
157+
* Task.function: regular-function | coroutine-function
158+
* Task.immediate: True | False
159+
"""
160+
func, is_coroutine_fn = _load_function(task)
161+
162+
if task.immediate and not is_coroutine_fn:
163+
raise ValueError("Immediate tasks must be async functions.")
164+
elif not task.immediate:
165+
raise ValueError("Non-immediate tasks can't run in async context.")
166+
167+
return _add_timeout_to(func, task.pk)
168+
169+
170+
def get_task_function(task):
171+
"""Get and handle task function running from SYNC context.
172+
173+
This exists to handle the combinations of:
174+
* context: sync | async
175+
* Task.function: regular-function | coroutine-function
176+
* Task.immediate: True | False
177+
"""
178+
func, is_coroutine_fn = _load_function(task)
179+
180+
if task.immediate and not is_coroutine_fn:
181+
raise ValueError("Immediate tasks must be async functions.")
182+
183+
# no sync wrapper required
184+
if not is_coroutine_fn:
185+
return func
186+
187+
# async function in sync context requires wrapper
188+
if task.immediate:
189+
coro_fn_with_timeout = _add_timeout_to(func, task.pk)
190+
return async_to_sync(coro_fn_with_timeout)
191+
return async_to_sync(func)
192+
193+
194+
def _load_function(task):
152195
module_name, function_name = task.name.rsplit(".", 1)
153196
module = importlib.import_module(module_name)
154197
func = getattr(module, function_name)
155198
args = task.enc_args or ()
156199
kwargs = task.enc_kwargs or {}
157-
immediate = task.immediate
200+
201+
func_with_args = partial(func, *args, **kwargs)
158202
is_coroutine_fn = asyncio.iscoroutinefunction(func)
203+
return func_with_args, is_coroutine_fn
159204

160-
if immediate and not is_coroutine_fn:
161-
raise ValueError("Immediate tasks must be async functions.")
162205

163-
if ensure_coroutine:
164-
if not is_coroutine_fn:
165-
return sync_to_async(func)(*args, **kwargs)
166-
coro = func(*args, **kwargs)
167-
if immediate:
168-
coro = asyncio.wait_for(coro, timeout=IMMEDIATE_TIMEOUT)
169-
return coro
170-
else: # ensure normal function
171-
if not is_coroutine_fn:
172-
return lambda: func(*args, **kwargs)
173-
174-
async def task_wrapper(): # asyncio.wait_for + async_to_sync requires wrapping
175-
coro = func(*args, **kwargs)
176-
if immediate:
177-
coro = asyncio.wait_for(coro, timeout=IMMEDIATE_TIMEOUT)
178-
try:
179-
return await coro
180-
except asyncio.TimeoutError:
181-
msg_template = "Immediate task %s timed out after %s seconds."
182-
error_msg = msg_template % (task.pk, IMMEDIATE_TIMEOUT)
183-
_logger.info(error_msg)
184-
raise RuntimeError(error_msg)
185-
186-
return async_to_sync(task_wrapper)
206+
def _add_timeout_to(coro_fn, task_pk):
207+
208+
async def _wrapper():
209+
try:
210+
return await asyncio.wait_for(coro_fn(), timeout=IMMEDIATE_TIMEOUT)
211+
except asyncio.TimeoutError:
212+
msg_template = "Immediate task %s timed out after %s seconds."
213+
error_msg = msg_template % (task_pk, IMMEDIATE_TIMEOUT)
214+
_logger.info(error_msg)
215+
raise RuntimeError(error_msg)
216+
217+
return _wrapper
187218

188219

189220
def dispatch(

0 commit comments

Comments
 (0)