Skip to content

Commit 1420fb0

Browse files
authored
Improve truncate_models for cascade delete (#2100)
* Improve truncate_models for cascade delete * Fix
1 parent 0b516dd commit 1420fb0

File tree

5 files changed

+222
-12
lines changed

5 files changed

+222
-12
lines changed

docs/contrib/unittest.rst

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,12 @@ Utility Functions
246246
truncate_all_models
247247
-------------------
248248

249-
Truncate all model tables in the current context:
249+
Truncate all model tables in the current context. The function handles foreign key
250+
constraints automatically:
251+
252+
- **PostgreSQL**: Uses a single ``TRUNCATE ... CASCADE`` statement (fast, single round-trip).
253+
- **Other databases**: Deletes in topological order — child tables are emptied before the
254+
parent tables they reference, avoiding FK constraint violations.
250255

251256
.. code-block:: python
252257
@@ -257,7 +262,7 @@ Truncate all model tables in the current context:
257262
# Create some data
258263
await User.create(name="Test")
259264
260-
# Truncate all tables
265+
# Truncate all tables (FK-safe)
261266
await truncate_all_models()
262267
263268
# Tables are now empty

tests/cli/test_cli.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ async def fake_migrate(**kwargs) -> None:
203203

204204
@pytest.mark.asyncio
205205
async def test_upgrade_alias(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
206+
_write_package(tmp_path, "cli_app")
206207
module_name = _write_settings(
207208
tmp_path,
208209
"""
@@ -234,6 +235,7 @@ async def fake_migrate(**kwargs) -> None:
234235

235236
@pytest.mark.asyncio
236237
async def test_downgrade_defaults_to_first(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
238+
_write_package(tmp_path, "cli_app")
237239
module_name = _write_settings(
238240
tmp_path,
239241
"""

tests/fields/test_time.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,8 @@ async def test_datetime_create(db):
8080
obj = await model.get(id=obj0.id)
8181
assert obj.datetime == now
8282
assert obj.datetime_null is None
83-
assert obj.datetime_auto - now < timedelta(microseconds=20000)
84-
assert obj.datetime_add - now < timedelta(microseconds=20000)
83+
assert obj.datetime_auto - now < timedelta(seconds=1)
84+
assert obj.datetime_add - now < timedelta(seconds=1)
8585
datetime_auto = obj.datetime_auto
8686
sleep(0.012)
8787
await obj.save()

tests/test_truncate.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
"""
2+
Tests for truncate_all_models() and _topological_sort_models().
3+
4+
Verifies FK-aware truncation ordering and the PostgreSQL TRUNCATE CASCADE path.
5+
"""
6+
7+
import pytest
8+
9+
from tests.testmodels import Employee, Event, MinRelation, Reporter, Team, Tournament
10+
from tortoise import Tortoise
11+
from tortoise.contrib.test import _topological_sort_models, truncate_all_models
12+
13+
# ---------------------------------------------------------------------------
14+
# _topological_sort_models — unit tests on real model metadata
15+
# ---------------------------------------------------------------------------
16+
17+
18+
@pytest.mark.asyncio
19+
async def test_topological_sort_children_before_parents(db):
20+
"""Event (FK→Tournament) must come before Tournament in delete order."""
21+
sorted_models = _topological_sort_models([Tournament, Event])
22+
assert sorted_models.index(Event) < sorted_models.index(Tournament)
23+
24+
25+
@pytest.mark.asyncio
26+
async def test_topological_sort_input_order_independent(db):
27+
"""Result must be the same regardless of input order."""
28+
order_a = _topological_sort_models([Tournament, Event])
29+
order_b = _topological_sort_models([Event, Tournament])
30+
assert order_a == order_b
31+
32+
33+
@pytest.mark.asyncio
34+
async def test_topological_sort_self_referential_fk(db):
35+
"""Self-referential FK (Employee→Employee) must not cause infinite loop."""
36+
result = _topological_sort_models([Employee])
37+
assert result == [Employee]
38+
39+
40+
@pytest.mark.asyncio
41+
async def test_topological_sort_no_fk_models(db):
42+
"""Models without FK relationships are still included."""
43+
result = _topological_sort_models([Team])
44+
assert result == [Team]
45+
46+
47+
@pytest.mark.asyncio
48+
async def test_topological_sort_all_models(db):
49+
"""Sorting all registered models succeeds and includes every model."""
50+
all_models = list(Tortoise.apps.get_models_iterable())
51+
sorted_models = _topological_sort_models(all_models)
52+
assert set(sorted_models) == set(all_models)
53+
54+
55+
@pytest.mark.asyncio
56+
async def test_topological_sort_multi_level_chain(db):
57+
"""MinRelation→Tournament and MinRelation→Team: MinRelation before both parents."""
58+
sorted_models = _topological_sort_models([Tournament, Team, MinRelation])
59+
assert sorted_models.index(MinRelation) < sorted_models.index(Tournament)
60+
assert sorted_models.index(MinRelation) < sorted_models.index(Team)
61+
62+
63+
@pytest.mark.asyncio
64+
async def test_topological_sort_multiple_fks_on_one_model(db):
65+
"""Event has FKs to both Tournament and Reporter — must come before both."""
66+
sorted_models = _topological_sort_models([Tournament, Reporter, Event])
67+
assert sorted_models.index(Event) < sorted_models.index(Tournament)
68+
assert sorted_models.index(Event) < sorted_models.index(Reporter)
69+
70+
71+
# ---------------------------------------------------------------------------
72+
# truncate_all_models — integration tests against real DB
73+
# ---------------------------------------------------------------------------
74+
75+
76+
@pytest.mark.asyncio
77+
async def test_truncate_empty_db(db):
78+
"""Truncating when tables are empty should succeed without error."""
79+
await truncate_all_models()
80+
81+
82+
@pytest.mark.asyncio
83+
async def test_truncate_clears_data(db):
84+
"""Data created before truncation is gone after truncation."""
85+
tournament = await Tournament.create(name="Test Tournament")
86+
await Event.create(name="Test Event", tournament=tournament)
87+
88+
await truncate_all_models()
89+
90+
assert await Tournament.all().count() == 0
91+
assert await Event.all().count() == 0
92+
93+
94+
@pytest.mark.asyncio
95+
async def test_truncate_with_fk_constraints(db):
96+
"""Truncation succeeds even with FK constraints (child→parent)."""
97+
t = await Tournament.create(name="T1")
98+
await Event.create(name="E1", tournament=t)
99+
await Event.create(name="E2", tournament=t)
100+
101+
# This would fail with arbitrary order on strict FK enforcement
102+
await truncate_all_models()
103+
104+
assert await Event.all().count() == 0
105+
assert await Tournament.all().count() == 0
106+
107+
108+
@pytest.mark.asyncio
109+
async def test_truncate_with_self_referential_fk(db):
110+
"""Self-referential FK (Employee→Employee) doesn't break truncation."""
111+
boss = await Employee.create(name="Boss")
112+
await Employee.create(name="Worker", manager=boss)
113+
114+
await truncate_all_models()
115+
116+
assert await Employee.all().count() == 0
117+
118+
119+
@pytest.mark.asyncio
120+
async def test_truncate_raises_when_apps_not_loaded(db_simple):
121+
"""truncate_all_models raises ValueError when apps aren't loaded."""
122+
from tortoise.context import get_current_context
123+
124+
ctx = get_current_context()
125+
saved_apps = ctx._apps
126+
ctx._apps = {}
127+
try:
128+
with pytest.raises(ValueError, match="apps are not loaded"):
129+
await truncate_all_models()
130+
finally:
131+
ctx._apps = saved_apps

tortoise/contrib/test/__init__.py

Lines changed: 80 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,16 @@ async def test_sqlite_only(db):
2828
import typing
2929
from collections.abc import Callable, Coroutine
3030
from functools import partial, wraps
31-
from typing import ParamSpec, TypeVar, cast
31+
from typing import TYPE_CHECKING, ParamSpec, TypeVar, cast
3232
from unittest import SkipTest, expectedFailure, skip, skipIf, skipUnless
3333

3434
from tortoise import Tortoise
3535
from tortoise.connection import get_connection
3636
from tortoise.context import TortoiseContext, tortoise_test_context
3737

38+
if TYPE_CHECKING:
39+
from tortoise.models import Model
40+
3841
T = TypeVar("T")
3942
P = ParamSpec("P")
4043
AsyncFunc = Callable[P, Coroutine[None, None, T]]
@@ -70,19 +73,88 @@ async def truncate_all_models() -> None:
7073
This is a utility function for test cleanup that deletes all rows from
7174
all registered model tables.
7275
73-
Note: This is a naive implementation that may fail with M2M relations
74-
and non-cascade foreign keys.
76+
On PostgreSQL, uses ``TRUNCATE ... CASCADE`` for a single fast statement.
77+
On other databases, deletes in topological (FK dependency) order so that
78+
child rows are removed before parent rows they reference.
7579
7680
Raises:
7781
ValueError: If Tortoise.apps is not loaded.
7882
"""
7983
if not Tortoise.apps:
8084
raise ValueError("apps are not loaded")
81-
for model in Tortoise.apps.get_models_iterable():
82-
quote_char = model._meta.db.query_class.SQL_CONTEXT.quote_char
83-
await model._meta.db.execute_script(
84-
f"DELETE FROM {quote_char}{model._meta.db_table}{quote_char}" # nosec
85-
)
85+
86+
models = list(Tortoise.apps.get_models_iterable())
87+
88+
if not models:
89+
return
90+
91+
db = models[0]._meta.db
92+
dialect = db.capabilities.dialect
93+
94+
if dialect == "postgres":
95+
# PostgreSQL supports TRUNCATE with CASCADE — single statement, fast
96+
tables = ", ".join(f'"{m._meta.db_table}"' for m in models)
97+
await db.execute_script(f"TRUNCATE {tables} CASCADE")
98+
else:
99+
# For other dialects, topologically sort by FK dependencies (children first)
100+
sorted_models = _topological_sort_models(models)
101+
102+
# Disable FK checks to handle self-referential and circular FK constraints
103+
if dialect == "mysql":
104+
await db.execute_script("SET FOREIGN_KEY_CHECKS = 0")
105+
elif dialect == "sqlite":
106+
await db.execute_script("PRAGMA foreign_keys = OFF")
107+
108+
try:
109+
for model in sorted_models:
110+
quote_char = model._meta.db.query_class.SQL_CONTEXT.quote_char
111+
await model._meta.db.execute_script(
112+
f"DELETE FROM {quote_char}{model._meta.db_table}{quote_char}" # nosec
113+
)
114+
finally:
115+
if dialect == "mysql":
116+
await db.execute_script("SET FOREIGN_KEY_CHECKS = 1")
117+
elif dialect == "sqlite":
118+
await db.execute_script("PRAGMA foreign_keys = ON")
119+
120+
121+
def _topological_sort_models(models: list[type[Model]]) -> list[type[Model]]:
122+
"""Sort models so children come before parents (safe delete order).
123+
124+
Uses Kahn's algorithm on FK dependencies. Models that depend on others
125+
via ForeignKey are placed *before* the models they reference, ensuring
126+
child rows are deleted before parent rows.
127+
"""
128+
from tortoise.fields.relational import ForeignKeyFieldInstance
129+
130+
model_set = set(models)
131+
# Build adjacency for delete order: parent -> children that must be deleted first
132+
# If Event has FK to Tournament, then Tournament depends on Event being deleted first
133+
deps: dict[type[Model], set[type[Model]]] = {m: set() for m in models}
134+
for model in models:
135+
for field in model._meta.fields_map.values():
136+
if isinstance(field, ForeignKeyFieldInstance):
137+
related = field.related_model
138+
if related in model_set and related is not model:
139+
deps[related].add(model)
140+
141+
# Kahn's algorithm — emit models whose deps are already emitted
142+
sorted_models: list[type[Model]] = []
143+
no_deps = [m for m in models if not deps[m]]
144+
while no_deps:
145+
m = no_deps.pop()
146+
sorted_models.append(m)
147+
for other in models:
148+
deps[other].discard(m)
149+
if not deps[other] and other not in sorted_models and other not in no_deps:
150+
no_deps.append(other)
151+
152+
# Append any remaining (circular deps — fallback)
153+
for m in models:
154+
if m not in sorted_models:
155+
sorted_models.append(m)
156+
157+
return sorted_models
86158

87159

88160
_FT = TypeVar("_FT", bound=Callable[..., typing.Any])

0 commit comments

Comments
 (0)