Skip to content

Commit d552134

Browse files
committed
wip: try splitting the functions
1 parent d0f38e6 commit d552134

File tree

1 file changed

+45
-33
lines changed

1 file changed

+45
-33
lines changed

pulpcore/tasking/tasks.py

Lines changed: 45 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def _execute_task(task):
7373
domain = get_domain()
7474
try:
7575
log_task_start(task, domain)
76-
task_function = get_task_function(task, in_async_context=False)
76+
task_function = get_task_function(task)
7777
result = task_function()
7878
except Exception:
7979
exc_type, exc, tb = sys.exc_info()
@@ -94,7 +94,7 @@ async def _aexecute_task(task):
9494
await sync_to_async(task.set_running)()
9595
domain = get_domain()
9696
try:
97-
coroutine = get_task_function(task, in_async_context=True)
97+
coroutine = await aget_task_function(task)
9898
result = await coroutine
9999
except Exception:
100100
exc_type, exc, tb = sys.exc_info()
@@ -148,44 +148,56 @@ def log_task_failed(task, exc_type, exc, tb, domain):
148148
_logger.info("\n".join(traceback.format_list(traceback.extract_tb(tb))))
149149

150150

151-
def get_task_function(task, in_async_context):
152-
"""Wraps task function according to the sync or async context we are in."""
153-
module_name, function_name = task.name.rsplit(".", 1)
154-
module = importlib.import_module(module_name)
155-
func = getattr(module, function_name)
151+
async def aget_task_function(task):
152+
"""Get and handle task function running from async context."""
153+
func, is_coroutine_fn = _load_function(task.name)
156154
args = task.enc_args or ()
157155
kwargs = task.enc_kwargs or {}
158156
immediate = task.immediate
159-
is_coroutine_fn = asyncio.iscoroutinefunction(func)
160157

161-
if immediate and not is_coroutine_fn:
158+
if not immediate:
159+
raise ValueError("Non-immediate tasks can't run in async context.")
160+
elif not is_coroutine_fn:
162161
raise ValueError("Immediate tasks must be async functions.")
163162

164-
if not immediate and in_async_context:
165-
raise ValueError("Non-immediate tasks can't run in async context.")
163+
coro = func(*args, **kwargs)
164+
return _add_timeout_to(coro, task.pk)
166165

167-
if in_async_context: # has to be immediate
168-
coro = func(*args, **kwargs)
169-
coro = asyncio.wait_for(coro, timeout=IMMEDIATE_TIMEOUT)
170-
return coro
171-
else:
172-
if not is_coroutine_fn:
173-
return lambda: func(*args, **kwargs)
174-
175-
# async function in sync context needs adapter
176-
async def task_wrapper(): # asyncio.wait_for + async_to_sync requires wrapping
177-
coro = func(*args, **kwargs)
178-
if immediate:
179-
coro = asyncio.wait_for(coro, timeout=IMMEDIATE_TIMEOUT)
180-
try:
181-
return await coro
182-
except asyncio.TimeoutError:
183-
msg_template = "Immediate task %s timed out after %s seconds."
184-
error_msg = msg_template % (task.pk, IMMEDIATE_TIMEOUT)
185-
_logger.info(error_msg)
186-
raise RuntimeError(error_msg)
187-
188-
return async_to_sync(task_wrapper)
166+
167+
def get_task_function(task):
168+
"""Get and handle task function running from sync context."""
169+
func, is_coroutine_fn = _load_function(task.name)
170+
args = task.enc_args or ()
171+
kwargs = task.enc_kwargs or {}
172+
immediate = task.immediate
173+
174+
# no wrapper required
175+
if not is_coroutine_fn:
176+
return lambda: func(*args, **kwargs)
177+
178+
# async function in sync context requires wrapper
179+
coro = func(*args, **kwargs)
180+
final_coro = _add_timeout_to(coro, task.pk) if immediate else coro
181+
return async_to_sync(final_coro)
182+
183+
184+
def _load_function(task_name):
185+
module_name, function_name = task_name.rsplit(".", 1)
186+
module = importlib.import_module(module_name)
187+
func = getattr(module, function_name)
188+
is_coroutine_fn = asyncio.iscoroutinefunction(func)
189+
return func, is_coroutine_fn
190+
191+
192+
async def _add_timeout_to(coro, task_pk):
193+
coro = asyncio.wait_for(coro, timeout=IMMEDIATE_TIMEOUT)
194+
try:
195+
return await coro
196+
except asyncio.TimeoutError:
197+
msg_template = "Immediate task %s timed out after %s seconds."
198+
error_msg = msg_template % (task_pk, IMMEDIATE_TIMEOUT)
199+
_logger.info(error_msg)
200+
raise RuntimeError(error_msg)
189201

190202

191203
def dispatch(

0 commit comments

Comments
 (0)