From 3f9e49e2e621032b8bbf14d1c06a27ea0419ba2d Mon Sep 17 00:00:00 2001 From: Rennan Cockles Date: Wed, 11 May 2022 12:54:30 -0300 Subject: [PATCH 1/5] async middleware --- fastapi_sqlalchemy/__init__.py | 5 +- fastapi_sqlalchemy/async_middleware.py | 89 ++++++++++++++++ tests/conftest.py | 33 ++++++ tests/test_async_session.py | 141 +++++++++++++++++++++++++ tests/test_session.py | 2 +- 5 files changed, 267 insertions(+), 3 deletions(-) create mode 100644 fastapi_sqlalchemy/async_middleware.py create mode 100644 tests/test_async_session.py diff --git a/fastapi_sqlalchemy/__init__.py b/fastapi_sqlalchemy/__init__.py index b0d7961..7c9a6fb 100644 --- a/fastapi_sqlalchemy/__init__.py +++ b/fastapi_sqlalchemy/__init__.py @@ -1,5 +1,6 @@ from fastapi_sqlalchemy.middleware import DBSessionMiddleware, db +from fastapi_sqlalchemy.async_middleware import AsyncDBSessionMiddleware, async_db -__all__ = ["db", "DBSessionMiddleware"] +__all__ = ["db", "DBSessionMiddleware", "async_db", "AsyncDBSessionMiddleware"] -__version__ = "0.2.1" +__version__ = "0.3.0" diff --git a/fastapi_sqlalchemy/async_middleware.py b/fastapi_sqlalchemy/async_middleware.py new file mode 100644 index 0000000..fd80cfc --- /dev/null +++ b/fastapi_sqlalchemy/async_middleware.py @@ -0,0 +1,89 @@ +from contextvars import ContextVar +from typing import Dict, Optional, Union + +from sqlalchemy.engine.url import URL +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine +from sqlalchemy.orm import sessionmaker +from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint +from starlette.requests import Request +from starlette.types import ASGIApp + +from fastapi_sqlalchemy.exceptions import ( + MissingSessionError, SessionNotInitialisedError +) + +_Session: sessionmaker = None +_session: ContextVar[Optional[AsyncSession]] = ContextVar("_session", default=None) + + +class AsyncDBSessionMiddleware(BaseHTTPMiddleware): + def __init__( + self, + app: ASGIApp, + db_url: Optional[Union[str, URL]] = None, + custom_engine: Optional[AsyncEngine] = None, + engine_args: Dict = None, + session_args: Dict = None, + commit_on_exit: bool = False, + ): + super().__init__(app) + global _Session + engine_args = engine_args or {} + self.commit_on_exit = commit_on_exit + + session_args = session_args or {} + if not custom_engine and not db_url: + raise ValueError("You need to pass a db_url or a custom_engine parameter.") + if not custom_engine: + engine = create_async_engine(db_url, future=True, **engine_args) + else: + engine = custom_engine + _Session = sessionmaker(bind=engine, class_=AsyncSession, **session_args) + + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): + async with async_db(commit_on_exit=self.commit_on_exit): + response = await call_next(request) + return response + + +class AsyncDBSessionMeta(type): + # using this metaclass means that we can access db.session as a property at a class level, + # rather than db().session + @property + def session(self) -> AsyncSession: + """Return an instance of Session local to the current async context.""" + if _Session is None: + raise SessionNotInitialisedError + + session = _session.get() + if session is None: + raise MissingSessionError + + return session + + +class AsyncDBSession(metaclass=AsyncDBSessionMeta): + def __init__(self, session_args: Dict = None, commit_on_exit: bool = False): + self.token = None + self.session_args = session_args or {} + self.commit_on_exit = commit_on_exit + + async def __aenter__(self): + if not isinstance(_Session, sessionmaker): + raise SessionNotInitialisedError + self.token = _session.set(_Session(**self.session_args)) + return type(self) + + async def __aexit__(self, exc_type, exc_value, traceback): + sess = _session.get() + if exc_type is not None: + await sess.rollback() + + if self.commit_on_exit: + await sess.commit() + + await sess.close() + _session.reset(self.token) + + +async_db: AsyncDBSessionMeta = AsyncDBSession diff --git a/tests/conftest.py b/tests/conftest.py index 3e56090..315c057 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ import pytest from fastapi import FastAPI +from httpx import AsyncClient from starlette.testclient import TestClient @@ -16,6 +17,12 @@ def client(app): yield c +@pytest.fixture +async def async_client(app): + async with AsyncClient(app=app, base_url="http://test") as c: + yield c + + @pytest.fixture def DBSessionMiddleware(): from fastapi_sqlalchemy import DBSessionMiddleware @@ -23,6 +30,13 @@ def DBSessionMiddleware(): yield DBSessionMiddleware +@pytest.fixture +def AsyncDBSessionMiddleware(): + from fastapi_sqlalchemy import AsyncDBSessionMiddleware + + yield AsyncDBSessionMiddleware + + @pytest.fixture def db(): from fastapi_sqlalchemy import db @@ -40,3 +54,22 @@ def db(): del sys.modules["fastapi_sqlalchemy.middleware"] except KeyError: pass + + +@pytest.fixture +def async_db(): + from fastapi_sqlalchemy import async_db + + yield async_db + + # force reloading of module to clear global state + + try: + del sys.modules["fastapi_sqlalchemy"] + except KeyError: + pass + + try: + del sys.modules["fastapi_sqlalchemy.async_middleware"] + except KeyError: + pass diff --git a/tests/test_async_session.py b/tests/test_async_session.py new file mode 100644 index 0000000..2e0e92a --- /dev/null +++ b/tests/test_async_session.py @@ -0,0 +1,141 @@ +from unittest.mock import AsyncMock, patch + +import pytest +from sqlalchemy import create_engine +from sqlalchemy.ext.asyncio import AsyncSession +from starlette.middleware.base import BaseHTTPMiddleware + +from fastapi_sqlalchemy.exceptions import MissingSessionError, SessionNotInitialisedError + +db_url = "sqlite+aiosqlite://" + + +@pytest.mark.anyio +async def test_init(app, AsyncDBSessionMiddleware): + mw = AsyncDBSessionMiddleware(app, db_url=db_url) + assert isinstance(mw, BaseHTTPMiddleware) + + +@pytest.mark.anyio +async def test_init_required_args(app, AsyncDBSessionMiddleware): + with pytest.raises(ValueError) as exc_info: + AsyncDBSessionMiddleware(app) + + assert exc_info.value.args[0] == "You need to pass a db_url or a custom_engine parameter." + + +@pytest.mark.anyio +async def test_init_required_args_custom_engine(app, async_db, AsyncDBSessionMiddleware): + custom_engine = create_engine(db_url) + AsyncDBSessionMiddleware(app, custom_engine=custom_engine) + + +@pytest.mark.anyio +async def test_init_correct_optional_args(app, async_db, AsyncDBSessionMiddleware): + engine_args = {"echo": True} + session_args = {"autoflush": False} + + AsyncDBSessionMiddleware(app, db_url, engine_args=engine_args, session_args=session_args) + + async with async_db(): + assert not async_db.session.autoflush + + engine = async_db.session.bind + assert engine.echo + + +@pytest.mark.anyio +async def test_init_incorrect_optional_args(app, AsyncDBSessionMiddleware): + with pytest.raises(TypeError) as exc_info: + AsyncDBSessionMiddleware(app, db_url=db_url, invalid_args="test") + + assert exc_info.value.args[0] == "AsyncDBSessionMiddleware.__init__() got an unexpected keyword argument 'invalid_args'" + + +@pytest.mark.anyio +async def test_inside_route(app, async_client, async_db, AsyncDBSessionMiddleware): + app.add_middleware(AsyncDBSessionMiddleware, db_url=db_url) + + @app.get("/") + async def test_get(): + assert isinstance(async_db.session, AsyncSession) + + await async_client.get("/") + + +@pytest.mark.anyio +async def test_inside_route_without_middleware_fails(app, async_client, async_db): + @app.get("/") + async def test_get(): + with pytest.raises(SessionNotInitialisedError): + async_db.session + + await async_client.get("/") + + +@pytest.mark.anyio +async def test_outside_of_route(app, async_db, AsyncDBSessionMiddleware): + app.add_middleware(AsyncDBSessionMiddleware, db_url=db_url) + + async with async_db(): + assert isinstance(async_db.session, AsyncSession) + + +@pytest.mark.anyio +async def test_outside_of_route_without_middleware_fails(async_db): + with pytest.raises(SessionNotInitialisedError): + async_db.session + + with pytest.raises(SessionNotInitialisedError): + async with async_db(): + pass + + +@pytest.mark.anyio +async def test_outside_of_route_without_context_fails(app, async_db, AsyncDBSessionMiddleware): + app.add_middleware(AsyncDBSessionMiddleware, db_url=db_url) + + with pytest.raises(MissingSessionError): + async_db.session + + +@pytest.mark.anyio +async def test_db_context_temporary_session_args(app, async_db, AsyncDBSessionMiddleware): + app.add_middleware(AsyncDBSessionMiddleware, db_url=db_url) + + session_args = {} + async with async_db(session_args=session_args): + assert isinstance(async_db.session, AsyncSession) + + session_args = {"autoflush": False} + async with async_db(session_args=session_args): + assert not async_db.session.autoflush + + +@pytest.mark.anyio +async def test_rollback(app, async_db, AsyncDBSessionMiddleware): + app.add_middleware(AsyncDBSessionMiddleware, db_url=db_url) + + with pytest.raises(Exception): + async with async_db(): + raise Exception + + +@pytest.mark.anyio +@pytest.mark.parametrize("commit_on_exit", [True, False]) +async def test_commit_on_exit(app, async_client, async_db, AsyncDBSessionMiddleware, commit_on_exit): + + with patch("fastapi_sqlalchemy.async_middleware._session") as session_var: + + mock_session = AsyncMock() + session_var.get.return_value = mock_session + + app.add_middleware(AsyncDBSessionMiddleware, db_url=db_url, commit_on_exit=commit_on_exit) + + @app.get("/") + async def test_get(): + pass + + await async_client.get("/") + + assert mock_session.commit.called == commit_on_exit diff --git a/tests/test_session.py b/tests/test_session.py index 78d8d4a..d63c296 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -44,7 +44,7 @@ def test_init_incorrect_optional_args(app, DBSessionMiddleware): with pytest.raises(TypeError) as exc_info: DBSessionMiddleware(app, db_url=db_url, invalid_args="test") - assert exc_info.value.args[0] == "__init__() got an unexpected keyword argument 'invalid_args'" + assert exc_info.value.args[0] == "DBSessionMiddleware.__init__() got an unexpected keyword argument 'invalid_args'" def test_inside_route(app, client, db, DBSessionMiddleware): From c2adcc86ac7c82b2747ebbb98d2a1f7de5ad2154 Mon Sep 17 00:00:00 2001 From: Rennan Cockles Date: Wed, 11 May 2022 12:57:55 -0300 Subject: [PATCH 2/5] requirements.txt --- requirements.txt | 72 ++++++++++++++++++++++++------------------------ 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/requirements.txt b/requirements.txt index 4fbefbb..f073dc7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,37 +1,37 @@ -appdirs==1.4.3 -atomicwrites==1.3.0 -attrs==19.3.0 -black==19.10b0 -certifi==2019.9.11 -chardet==3.0.4 -click==7.1.1 -coverage==4.5.4 -entrypoints==0.3 -fastapi==0.52.0 -flake8==3.7.9 -idna==2.8 -importlib-metadata==1.5.0 -isort==4.3.21 -mccabe==0.6.1 -more-itertools==7.2.0 -packaging==19.2 -pathspec==0.7.0 -pluggy==0.13.0 -py==1.8.0 -pycodestyle==2.5.0 -pydantic==0.32.2 -pyflakes==2.1.1 -pyparsing==2.4.2 -pytest==5.2.2 -pytest-cov==2.8.1 -PyYAML==5.3.1 -regex==2020.2.20 -requests==2.22.0 -six==1.12.0 -SQLAlchemy==1.3.10 -starlette==0.13.2 -toml==0.10.0 -typed-ast==1.4.1 -urllib3==1.25.6 +aiosqlite==0.17.0 +anyio==3.5.0 +async-generator==1.10 +atomicwrites==1.4.0 +attrs==21.4.0 +certifi==2021.10.8 +cffi==1.15.0 +charset-normalizer==2.0.12 +colorama==0.4.4 +coverage==6.3.2 +fastapi==0.77.1 +greenlet==1.1.2 +h11==0.12.0 +httpcore==0.14.7 +httpx==0.22.0 +idna==3.3 +iniconfig==1.1.1 +outcome==1.1.0 +packaging==21.3 +pluggy==1.0.0 +py==1.11.0 +pycparser==2.21 +pydantic==1.9.0 +pyparsing==3.0.9 +pytest==7.1.2 +pytest-cov==3.0.0 +requests==2.27.1 +rfc3986==1.5.0 +sniffio==1.2.0 +sortedcontainers==2.4.0 +SQLAlchemy==1.4.36 +starlette==0.19.1 +tomli==2.0.1 +trio==0.20.0 +typing_extensions==4.2.0 +urllib3==1.26.9 wcwidth==0.1.7 -zipp==3.1.0 From 55b14de161f5aa436bece4abbb89b3cc6354a42f Mon Sep 17 00:00:00 2001 From: Rennan Cockles Date: Wed, 11 May 2022 18:36:16 -0300 Subject: [PATCH 3/5] flake8 fixes --- tests/test_async_session.py | 9 +++++++-- tests/test_session.py | 5 ++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/tests/test_async_session.py b/tests/test_async_session.py index 2e0e92a..8f267a1 100644 --- a/tests/test_async_session.py +++ b/tests/test_async_session.py @@ -49,7 +49,10 @@ async def test_init_incorrect_optional_args(app, AsyncDBSessionMiddleware): with pytest.raises(TypeError) as exc_info: AsyncDBSessionMiddleware(app, db_url=db_url, invalid_args="test") - assert exc_info.value.args[0] == "AsyncDBSessionMiddleware.__init__() got an unexpected keyword argument 'invalid_args'" + assert exc_info.value.args[0] == ( + "AsyncDBSessionMiddleware.__init__() got an unexpected keyword " + "argument 'invalid_args'" + ) @pytest.mark.anyio @@ -123,7 +126,9 @@ async def test_rollback(app, async_db, AsyncDBSessionMiddleware): @pytest.mark.anyio @pytest.mark.parametrize("commit_on_exit", [True, False]) -async def test_commit_on_exit(app, async_client, async_db, AsyncDBSessionMiddleware, commit_on_exit): +async def test_commit_on_exit( + app, async_client, async_db, AsyncDBSessionMiddleware, commit_on_exit +): with patch("fastapi_sqlalchemy.async_middleware._session") as session_var: diff --git a/tests/test_session.py b/tests/test_session.py index d63c296..b935a03 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -44,7 +44,10 @@ def test_init_incorrect_optional_args(app, DBSessionMiddleware): with pytest.raises(TypeError) as exc_info: DBSessionMiddleware(app, db_url=db_url, invalid_args="test") - assert exc_info.value.args[0] == "DBSessionMiddleware.__init__() got an unexpected keyword argument 'invalid_args'" + assert exc_info.value.args[0] == ( + "DBSessionMiddleware.__init__() got an unexpected keyword " + "argument 'invalid_args'" + ) def test_inside_route(app, client, db, DBSessionMiddleware): From 766fb1a7aa54f053af14dd801bf3cd164308e213 Mon Sep 17 00:00:00 2001 From: Rennan Cockles Date: Wed, 11 May 2022 18:44:28 -0300 Subject: [PATCH 4/5] pytest fixes --- tests/test_async_session.py | 3 +-- tests/test_session.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/test_async_session.py b/tests/test_async_session.py index 8f267a1..c5c1899 100644 --- a/tests/test_async_session.py +++ b/tests/test_async_session.py @@ -50,8 +50,7 @@ async def test_init_incorrect_optional_args(app, AsyncDBSessionMiddleware): AsyncDBSessionMiddleware(app, db_url=db_url, invalid_args="test") assert exc_info.value.args[0] == ( - "AsyncDBSessionMiddleware.__init__() got an unexpected keyword " - "argument 'invalid_args'" + "__init__() got an unexpected keyword argument 'invalid_args'" ) diff --git a/tests/test_session.py b/tests/test_session.py index b935a03..7f680da 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -45,8 +45,7 @@ def test_init_incorrect_optional_args(app, DBSessionMiddleware): DBSessionMiddleware(app, db_url=db_url, invalid_args="test") assert exc_info.value.args[0] == ( - "DBSessionMiddleware.__init__() got an unexpected keyword " - "argument 'invalid_args'" + "__init__() got an unexpected keyword argument 'invalid_args'" ) From 0a5c5cf911ffa99e428099a7f930e6247c75875d Mon Sep 17 00:00:00 2001 From: Rennan Cockles Date: Wed, 11 May 2022 18:45:41 -0300 Subject: [PATCH 5/5] upgrade minimum python version for async --- .github/workflows/ci.yml | 32 ++++++++++++++++---------------- .github/workflows/release.yml | 14 +++++++------- README.rst | 8 ++++---- requirements.txt | 10 ++++++++++ setup.py | 5 ++--- 5 files changed, 39 insertions(+), 30 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c975c72..b28162c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -6,15 +6,18 @@ on: push: branches: - master - + jobs: test: name: test runs-on: ${{ matrix.os }} strategy: matrix: - build: [linux_3.8, windows_3.8, mac_3.8, linux_3.7] + build: [linux_3.9, linux_3.8, windows_3.8, mac_3.8] include: + - build: linux_3.9 + os: ubuntu-latest + python: 3.9 - build: linux_3.8 os: ubuntu-latest python: 3.8 @@ -24,9 +27,6 @@ jobs: - build: mac_3.8 os: macos-latest python: 3.8 - - build: linux_3.7 - os: ubuntu-latest - python: 3.7 steps: - name: Checkout repository uses: actions/checkout@v2 @@ -35,39 +35,39 @@ jobs: uses: actions/setup-python@v1 with: python-version: ${{ matrix.python }} - + - name: Install dependencies run: | python -m pip install --upgrade pip wheel pip install -r requirements.txt - + # test all the builds apart from linux_3.8... - name: Test with pytest if: matrix.build != 'linux_3.8' run: pytest - + # only do the test coverage for linux_3.8 - name: Produce coverage report if: matrix.build == 'linux_3.8' run: pytest --cov=fastapi_sqlalchemy --cov-report=xml - + - name: Upload coverage report if: matrix.build == 'linux_3.8' uses: codecov/codecov-action@v1 with: file: ./coverage.xml - + lint: name: lint runs-on: ubuntu-latest steps: - name: Checkout repository uses: actions/checkout@v2 - + - name: Set up Python uses: actions/setup-python@v1 with: - python-version: 3.7 + python-version: 3.8 - name: Install dependencies run: pip install flake8 @@ -85,13 +85,13 @@ jobs: - name: Set up Python uses: actions/setup-python@v1 with: - python-version: 3.7 - + python-version: 3.8 + - name: Install dependencies - # isort needs all of the packages to be installed so it can + # isort needs all of the packages to be installed so it can # tell which are third party and which are first party run: pip install -r requirements.txt - + - name: Check formatting of imports run: isort --check-only --diff --verbose diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 3c5aea7..eb25f18 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -11,18 +11,18 @@ jobs: steps: - name: Checkout repository uses: actions/checkout@v2 - - - name: Set up Python 3.7 + + - name: Set up Python 3.8 uses: actions/setup-python@v1 with: - python-version: 3.7 + python-version: 3.8 - name: Install the dependencies run: pip install --upgrade pip wheel setuptools - + - name: Build the distributions run: python setup.py sdist bdist_wheel - + - name: Upload to PyPI uses: pypa/gh-action-pypi-publish@master with: @@ -36,10 +36,10 @@ jobs: steps: - name: Checkout repository uses: actions/checkout@v2 - + - name: Get the version run: echo ::set-env name=VERSION::${GITHUB_REF#refs/tags/} - + - name: Create release env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/README.rst b/README.rst index 9dfb6c5..096dd87 100644 --- a/README.rst +++ b/README.rst @@ -40,7 +40,7 @@ Usage inside of a route app.add_middleware(DBSessionMiddleware, db_url="sqlite://") - # once the middleware is applied, any route can then access the database session + # once the middleware is applied, any route can then access the database session # from the global ``db`` @app.get("/users") @@ -49,7 +49,7 @@ Usage inside of a route return users -Note that the session object provided by ``db.session`` is based on the Python3.7+ ``ContextVar``. This means that +Note that the session object provided by ``db.session`` is based on the Python3.8+ ``ContextVar``. This means that each session is linked to the individual request context in which it was created. Usage outside of a route @@ -82,7 +82,7 @@ Sometimes it is useful to be able to access the database outside the context of """Count the number of users in the database and save it into the user_counts table.""" # we are outside of a request context, therefore we cannot rely on ``DBSessionMiddleware`` - # to create a database session for us. Instead, we can use the same ``db`` object and + # to create a database session for us. Instead, we can use the same ``db`` object and # use it as a context manager, like so: with db(): @@ -90,7 +90,7 @@ Sometimes it is useful to be able to access the database outside the context of db.session.add(UserCount(user_count)) db.session.commit() - + # no longer able to access a database session once the db() context manager has ended return users diff --git a/requirements.txt b/requirements.txt index f073dc7..ee2b992 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,24 +3,34 @@ anyio==3.5.0 async-generator==1.10 atomicwrites==1.4.0 attrs==21.4.0 +black==22.3.0 certifi==2021.10.8 cffi==1.15.0 charset-normalizer==2.0.12 +click==8.1.3 colorama==0.4.4 coverage==6.3.2 fastapi==0.77.1 +flake8==4.0.1 greenlet==1.1.2 h11==0.12.0 httpcore==0.14.7 httpx==0.22.0 idna==3.3 iniconfig==1.1.1 +isort==5.10.1 +mccabe==0.6.1 +mypy-extensions==0.4.3 outcome==1.1.0 packaging==21.3 +pathspec==0.9.0 +platformdirs==2.5.2 pluggy==1.0.0 py==1.11.0 +pycodestyle==2.8.0 pycparser==2.21 pydantic==1.9.0 +pyflakes==2.4.0 pyparsing==3.0.9 pytest==7.1.2 pytest-cov==3.0.0 diff --git a/setup.py b/setup.py index 2b97731..919c3a0 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ packages=["fastapi_sqlalchemy"], package_data={"fastapi_sqlalchemy": ["py.typed"]}, zip_safe=False, - python_requires=">=3.7", + python_requires=">=3.8", install_requires=["starlette>=0.12.9", "SQLAlchemy>=1.2"], classifiers=[ "Development Status :: 4 - Beta", @@ -34,9 +34,8 @@ "Intended Audience :: Developers", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", - "Programming Language :: Python :: 3.6", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: Implementation :: CPython", "Topic :: Internet :: WWW/HTTP :: HTTP Servers",