|
6 | 6 | import sys |
7 | 7 | import traceback |
8 | 8 | import tempfile |
| 9 | +from functools import partial |
9 | 10 | from gettext import gettext as _ |
10 | 11 | from contextlib import contextmanager |
11 | 12 | from asgiref.sync import sync_to_async, async_to_sync |
@@ -94,8 +95,8 @@ async def _aexecute_task(task): |
94 | 95 | await sync_to_async(task.set_running)() |
95 | 96 | domain = get_domain() |
96 | 97 | try: |
97 | | - coroutine = get_task_function(task, ensure_coroutine=True) |
98 | | - result = await coroutine |
| 98 | + task_coroutine_fn = await aget_task_function(task) |
| 99 | + result = await task_coroutine_fn() |
99 | 100 | except Exception: |
100 | 101 | exc_type, exc, tb = sys.exc_info() |
101 | 102 | await sync_to_async(task.set_failed)(exc, tb) |
@@ -148,42 +149,72 @@ def log_task_failed(task, exc_type, exc, tb, domain): |
148 | 149 | _logger.info("\n".join(traceback.format_list(traceback.extract_tb(tb)))) |
149 | 150 |
|
150 | 151 |
|
151 | | -def get_task_function(task, ensure_coroutine=False): |
| 152 | +async def aget_task_function(task): |
| 153 | + """Get and handle task function running from ASYNC context. |
| 154 | +
|
| 155 | + This exists to handle the combinations of: |
| 156 | + * context: sync | async |
| 157 | + * Task.function: regular-function | coroutine-function |
| 158 | + * Task.immediate: True | False |
| 159 | + """ |
| 160 | + func, is_coroutine_fn = _load_function(task) |
| 161 | + |
| 162 | + if task.immediate and not is_coroutine_fn: |
| 163 | + raise ValueError("Immediate tasks must be async functions.") |
| 164 | + elif not task.immediate: |
| 165 | + raise ValueError("Non-immediate tasks can't run in async context.") |
| 166 | + |
| 167 | + return _add_timeout_to(func, task.pk) |
| 168 | + |
| 169 | + |
| 170 | +def get_task_function(task): |
| 171 | + """Get and handle task function running from SYNC context. |
| 172 | +
|
| 173 | + This exists to handle the combinations of: |
| 174 | + * context: sync | async |
| 175 | + * Task.function: regular-function | coroutine-function |
| 176 | + * Task.immediate: True | False |
| 177 | + """ |
| 178 | + func, is_coroutine_fn = _load_function(task) |
| 179 | + |
| 180 | + if task.immediate and not is_coroutine_fn: |
| 181 | + raise ValueError("Immediate tasks must be async functions.") |
| 182 | + |
| 183 | + # no sync wrapper required |
| 184 | + if not is_coroutine_fn: |
| 185 | + return func |
| 186 | + |
| 187 | + # async function in sync context requires wrapper |
| 188 | + if task.immediate: |
| 189 | + coro_fn_with_timeout = _add_timeout_to(func, task.pk) |
| 190 | + return async_to_sync(coro_fn_with_timeout) |
| 191 | + return async_to_sync(func) |
| 192 | + |
| 193 | + |
| 194 | +def _load_function(task): |
152 | 195 | module_name, function_name = task.name.rsplit(".", 1) |
153 | 196 | module = importlib.import_module(module_name) |
154 | 197 | func = getattr(module, function_name) |
155 | 198 | args = task.enc_args or () |
156 | 199 | kwargs = task.enc_kwargs or {} |
157 | | - immediate = task.immediate |
| 200 | + |
| 201 | + func_with_args = partial(func, *args, **kwargs) |
158 | 202 | is_coroutine_fn = asyncio.iscoroutinefunction(func) |
| 203 | + return func_with_args, is_coroutine_fn |
159 | 204 |
|
160 | | - if immediate and not is_coroutine_fn: |
161 | | - raise ValueError("Immediate tasks must be async functions.") |
162 | 205 |
|
163 | | - if ensure_coroutine: |
164 | | - if not is_coroutine_fn: |
165 | | - return sync_to_async(func)(*args, **kwargs) |
166 | | - coro = func(*args, **kwargs) |
167 | | - if immediate: |
168 | | - coro = asyncio.wait_for(coro, timeout=IMMEDIATE_TIMEOUT) |
169 | | - return coro |
170 | | - else: # ensure normal function |
171 | | - if not is_coroutine_fn: |
172 | | - return lambda: func(*args, **kwargs) |
173 | | - |
174 | | - async def task_wrapper(): # asyncio.wait_for + async_to_sync requires wrapping |
175 | | - coro = func(*args, **kwargs) |
176 | | - if immediate: |
177 | | - coro = asyncio.wait_for(coro, timeout=IMMEDIATE_TIMEOUT) |
178 | | - try: |
179 | | - return await coro |
180 | | - except asyncio.TimeoutError: |
181 | | - msg_template = "Immediate task %s timed out after %s seconds." |
182 | | - error_msg = msg_template % (task.pk, IMMEDIATE_TIMEOUT) |
183 | | - _logger.info(error_msg) |
184 | | - raise RuntimeError(error_msg) |
185 | | - |
186 | | - return async_to_sync(task_wrapper) |
| 206 | +def _add_timeout_to(coro_fn, task_pk): |
| 207 | + |
| 208 | + async def _wrapper(): |
| 209 | + try: |
| 210 | + return await asyncio.wait_for(coro_fn(), timeout=IMMEDIATE_TIMEOUT) |
| 211 | + except asyncio.TimeoutError: |
| 212 | + msg_template = "Immediate task %s timed out after %s seconds." |
| 213 | + error_msg = msg_template % (task_pk, IMMEDIATE_TIMEOUT) |
| 214 | + _logger.info(error_msg) |
| 215 | + raise RuntimeError(error_msg) |
| 216 | + |
| 217 | + return _wrapper |
187 | 218 |
|
188 | 219 |
|
189 | 220 | def dispatch( |
|
0 commit comments