|
6 | 6 | import inspect |
7 | 7 | import logging |
8 | 8 | import warnings |
| 9 | +import functools |
9 | 10 | import contextlib |
10 | | -from typing import TYPE_CHECKING, Any, Dict, Union, TypeVar, Iterator, NoReturn, Coroutine |
| 11 | +from types import TracebackType |
| 12 | +from typing import TYPE_CHECKING, Any, Dict, Type, Union, TypeVar, Callable, Iterator, NoReturn, Optional, Coroutine |
11 | 13 | from importlib.util import find_spec |
12 | 14 |
|
13 | | -from ._types import CoroType, FuncType, TypeGuard |
| 15 | +from ._types import CoroType, FuncType, TypeGuard, ExcMapping |
14 | 16 |
|
15 | 17 | if TYPE_CHECKING: |
16 | 18 | from typing_extensions import TypeGuard |
@@ -139,3 +141,56 @@ def make_optional(value: _T) -> _T | None: |
139 | 141 |
|
140 | 142 | def is_dict(obj: object) -> TypeGuard[dict[object, object]]: |
141 | 143 | return isinstance(obj, dict) |
| 144 | + |
| 145 | + |
| 146 | +# TODO: improve typing |
| 147 | +class MaybeAsyncContextDecorator(contextlib.ContextDecorator): |
| 148 | + """`ContextDecorator` compatible with sync/async functions.""" |
| 149 | + |
| 150 | + def __call__(self, func: Callable[..., Any]) -> Callable[..., Any]: # type: ignore[override] |
| 151 | + @functools.wraps(func) |
| 152 | + async def async_inner(*args: Any, **kwargs: Any) -> object: |
| 153 | + async with self._recreate_cm(): # type: ignore[attr-defined] |
| 154 | + return await func(*args, **kwargs) |
| 155 | + |
| 156 | + @functools.wraps(func) |
| 157 | + def sync_inner(*args: Any, **kwargs: Any) -> object: |
| 158 | + with self._recreate_cm(): # type: ignore[attr-defined] |
| 159 | + return func(*args, **kwargs) |
| 160 | + |
| 161 | + if is_coroutine(func): |
| 162 | + return async_inner |
| 163 | + else: |
| 164 | + return sync_inner |
| 165 | + |
| 166 | + |
| 167 | +class ExcConverter(MaybeAsyncContextDecorator): |
| 168 | + """`MaybeAsyncContextDecorator` to convert exceptions.""" |
| 169 | + |
| 170 | + def __init__(self, exc_mapping: ExcMapping) -> None: |
| 171 | + self._exc_mapping = exc_mapping |
| 172 | + |
| 173 | + def __enter__(self) -> 'ExcConverter': |
| 174 | + return self |
| 175 | + |
| 176 | + def __exit__( |
| 177 | + self, |
| 178 | + exc_type: Optional[Type[BaseException]], |
| 179 | + exc: Optional[BaseException], |
| 180 | + exc_tb: Optional[TracebackType], |
| 181 | + ) -> None: |
| 182 | + if exc is not None and exc_type is not None: |
| 183 | + target_exc_type = self._exc_mapping.get(exc_type) |
| 184 | + if target_exc_type is not None: |
| 185 | + raise target_exc_type() from exc |
| 186 | + |
| 187 | + async def __aenter__(self) -> 'ExcConverter': |
| 188 | + return self.__enter__() |
| 189 | + |
| 190 | + async def __aexit__( |
| 191 | + self, |
| 192 | + exc_type: Optional[Type[BaseException]], |
| 193 | + exc: Optional[BaseException], |
| 194 | + exc_tb: Optional[TracebackType], |
| 195 | + ) -> None: |
| 196 | + self.__exit__(exc_type, exc, exc_tb) |
0 commit comments