From 39b7fcb5eab89ca6a75be4cb80ca9fb60d341bf4 Mon Sep 17 00:00:00 2001 From: mghh Date: Mon, 30 Jan 2023 09:44:03 +0100 Subject: [PATCH] Allow session rollbacks on HTTPExceptions By default fastapi handles HTTPExceptions and does not raise inside the fastapi-sqlalchemy middleware. We introduce two new parameters to control rollback behaviour: - rollback_on_client_error rolls back on 40x http exceptions - rollback_on_server_error rolls back on 50x http exceptions Additional we add support to force rollback if used outside a route in context manager by setting the attribute `force_rollback` on the session context (see tests/test_session.py). --- .gitignore | 8 +++- fastapi_sqlalchemy/middleware.py | 63 +++++++++++++++++++++++-- tests/test_session.py | 81 +++++++++++++++++++++++++++----- 3 files changed, 136 insertions(+), 16 deletions(-) diff --git a/.gitignore b/.gitignore index df43617..e5da237 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,10 @@ __pycache__/ htmlcov/ test.py *.egg-info -coverage.xml \ No newline at end of file +coverage.xml +bin/ +include/ +lib/ +lib64 +pyvenv.cfg + diff --git a/fastapi_sqlalchemy/middleware.py b/fastapi_sqlalchemy/middleware.py index 6422fc1..931c551 100644 --- a/fastapi_sqlalchemy/middleware.py +++ b/fastapi_sqlalchemy/middleware.py @@ -24,13 +24,34 @@ def __init__( engine_args: Dict = None, session_args: Dict = None, commit_on_exit: bool = False, + rollback_on_client_error: bool = False, + rollback_on_server_error: bool = False, ): + """Initialize middleware. + + Args: + rollback_on_client_error: + Fastapi does handle http client errors + (see https://httpwg.org/specs/rfc9110.html#status.4xx) + and returns a valid response without raising inside + the `DBSessionMiddleware.dispatch` method. + If `rollback_on_client_errors` is true the session + gets rolledback even if no Exception is raised inside + the contextmanager. + rollback_on_server_error: + See above `rollback_on_client_error`. The session + is rolled back on 5xx HTTP-Codes + (https://httpwg.org/specs/rfc9110.html#status.5xx). + """ super().__init__(app) global _Session engine_args = engine_args or {} self.commit_on_exit = commit_on_exit + self.rollback_on_client_error = rollback_on_client_error + self.rollback_on_server_error = rollback_on_server_error session_args = session_args or {} + if not custom_engine and not db_url: raise ValueError("You need to pass a db_url or a custom_engine parameter.") if not custom_engine: @@ -40,8 +61,24 @@ def __init__( _Session = sessionmaker(bind=engine, **session_args) async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): - with db(commit_on_exit=self.commit_on_exit): + with db(commit_on_exit=self.commit_on_exit) as dbsession_context: response = await call_next(request) + + if response and hasattr(response, "status_code") and response.status_code: + # I am not deep enough in fastapi. In allmost all cases + # status_code should be int. It may be possible that third + # party usage of custom HTTPException could set status codes + # as strings. + status_code = int(response.status_code) + is_client_error = status_code >= 400 and status_code < 500 + is_server_error = status_code >= 500 + + if is_client_error and self.rollback_on_client_error: + dbsession_context.force_rollback = True + + if is_server_error and self.rollback_on_server_error: + dbsession_context.force_rollback = True + return response @@ -67,18 +104,38 @@ def __init__(self, session_args: Dict = None, commit_on_exit: bool = False): self.session_args = session_args or {} self.commit_on_exit = commit_on_exit + # The code using this context could signal that + # the session should rolled back in case of + # conditions or errors which are not raised + # but handled internally. + self.force_rollback = False + def __enter__(self): if not isinstance(_Session, sessionmaker): raise SessionNotInitialisedError self.token = _session.set(_Session(**self.session_args)) - return type(self) + # We return `self` here to make the context + # available to inner code for enabling `force_rollback` + # on the session. + # Before this change the return value was `type(self)`. + # In allmost all examples the return was not used + # and this refactoring should not break existing code. + return self def __exit__(self, exc_type, exc_value, traceback): sess = _session.get() + + is_rolled_back = False + if exc_type is not None: sess.rollback() + is_rolled_back = True + + if self.force_rollback and not is_rolled_back: + sess.rollback() + is_rolled_back = True - if self.commit_on_exit: + if self.commit_on_exit and not is_rolled_back: sess.commit() sess.close() diff --git a/tests/test_session.py b/tests/test_session.py index 78d8d4a..50bea44 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -1,6 +1,8 @@ +import contextlib from unittest.mock import Mock, patch import pytest +from fastapi.exceptions import HTTPException from sqlalchemy import create_engine from sqlalchemy.orm import Session from starlette.middleware.base import BaseHTTPMiddleware @@ -10,6 +12,16 @@ db_url = "sqlite://" +@contextlib.contextmanager +def mock_session(): + patcher = patch("fastapi_sqlalchemy.middleware._session") + mock_session = Mock() + mocked = patcher.start() + mocked.get.return_value = mock_session + yield mock_session + patcher.stop() + + def test_init(app, DBSessionMiddleware): mw = DBSessionMiddleware(app, db_url=db_url) assert isinstance(mw, BaseHTTPMiddleware) @@ -103,24 +115,37 @@ def test_db_context_temporary_session_args(app, db, DBSessionMiddleware): assert not db.session.expire_on_commit -def test_rollback(app, db, DBSessionMiddleware): - # pytest-cov shows that the line in db.__exit__() rolling back the db session - # when there is an Exception is run correctly. However, it would be much better - # if we could demonstrate somehow that db.session.rollback() was called e.g. once +def test_rollback_on_exception(app, db, DBSessionMiddleware): app.add_middleware(DBSessionMiddleware, db_url=db_url) - with pytest.raises(Exception): - with db(): - raise Exception + with mock_session() as session: + with pytest.raises(Exception): + with db(): + raise Exception + + assert session.rollback.called is True + + +def test_rollback_could_be_forced(app, db, DBSessionMiddleware): + app.add_middleware(DBSessionMiddleware, db_url=db_url) + + with mock_session() as session: + with db() as session_context: + session_context.force_rollback = False + + assert session.rollback.called is False + + with mock_session() as session: + with db() as session_context: + session_context.force_rollback = True + + assert session.rollback.called is True @pytest.mark.parametrize("commit_on_exit", [True, False]) def test_commit_on_exit(app, client, db, DBSessionMiddleware, commit_on_exit): - with patch("fastapi_sqlalchemy.middleware._session") as session_var: - - mock_session = Mock() - session_var.get.return_value = mock_session + with mock_session() as session: app.add_middleware(DBSessionMiddleware, db_url=db_url, commit_on_exit=commit_on_exit) @@ -130,4 +155,36 @@ def test_get(): client.get("/") - assert mock_session.commit.called == commit_on_exit + assert session.commit.called == commit_on_exit + + +@pytest.mark.parametrize( + "client_error, server_error, status_code, expected", + [ + (True, False, 400, True), + (False, False, 400, False), + (True, False, 500, False), + (False, True, 502, True), + (False, True, 422, False), + ], +) +def test_rollback_on_http_exceptions( + app, client, db, DBSessionMiddleware, client_error, server_error, status_code, expected +): + + with mock_session() as session: + + app.add_middleware( + DBSessionMiddleware, + db_url=db_url, + rollback_on_client_error=client_error, + rollback_on_server_error=server_error, + ) + + @app.get("/") + def test_get(): + raise HTTPException(status_code=status_code) + + client.get("/") + + assert session.rollback.called is expected