Skip to content

Commit e1bf60f

Browse files
authored
feat: add schedule source for psqlpy and aiopg (#18)
2 parents 618cfcd + fc3c45e commit e1bf60f

File tree

12 files changed

+559
-46
lines changed

12 files changed

+559
-46
lines changed

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ ignore = [
129129
# Conflicted rules
130130
"D203", # with D211
131131
"D212", # with D213
132+
"COM812", # with formatter
132133
]
133134

134135
[tool.ruff.lint.per-file-ignores]
@@ -146,6 +147,8 @@ ignore = [
146147
"S608",
147148

148149
"RUF",
150+
151+
"PLR2004", # magic numbers in tests
149152
]
150153
"tests/test_linting.py" = [
151154
"S603", # subprocess usage
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from taskiq_pg._internal.broker import BasePostgresBroker
2+
from taskiq_pg._internal.result_backend import BasePostgresResultBackend
3+
from taskiq_pg._internal.schedule_source import BasePostgresScheduleSource
4+
5+
6+
__all__ = [
7+
"BasePostgresBroker",
8+
"BasePostgresResultBackend",
9+
"BasePostgresScheduleSource",
10+
]
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from __future__ import annotations
2+
3+
import typing as tp
4+
5+
from taskiq import ScheduleSource
6+
7+
8+
if tp.TYPE_CHECKING:
9+
from taskiq.abc.broker import AsyncBroker
10+
11+
12+
class BasePostgresScheduleSource(ScheduleSource):
13+
def __init__(
14+
self,
15+
broker: AsyncBroker,
16+
dsn: str | tp.Callable[[], str] = "postgresql://postgres:postgres@localhost:5432/postgres",
17+
table_name: str = "taskiq_schedules",
18+
**connect_kwargs: tp.Any,
19+
) -> None:
20+
"""
21+
Initialize the PostgreSQL scheduler source.
22+
23+
Sets up a scheduler source that stores scheduled tasks in a PostgreSQL database.
24+
This scheduler source manages task schedules, allowing for persistent storage and retrieval of scheduled tasks
25+
across application restarts.
26+
27+
Args:
28+
dsn: PostgreSQL connection string
29+
table_name: Name of the table to store scheduled tasks. Will be created automatically if it doesn't exist.
30+
broker: The TaskIQ broker instance to use for finding and managing tasks.
31+
Required if startup_schedule is provided.
32+
**connect_kwargs: Additional keyword arguments passed to the database connection pool.
33+
34+
"""
35+
self._broker: tp.Final = broker
36+
self._dsn: tp.Final = dsn
37+
self._table_name: tp.Final = table_name
38+
self._connect_kwargs: tp.Final = connect_kwargs
39+
40+
@property
41+
def dsn(self) -> str | None:
42+
"""
43+
Get the DSN string.
44+
45+
Returns the DSN string or None if not set.
46+
"""
47+
if callable(self._dsn):
48+
return self._dsn()
49+
return self._dsn

src/taskiq_pg/aiopg/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from taskiq_pg.aiopg.result_backend import AiopgResultBackend
2+
from taskiq_pg.aiopg.schedule_source import AiopgScheduleSource
23

34

45
__all__ = [
56
"AiopgResultBackend",
7+
"AiopgScheduleSource",
68
]

src/taskiq_pg/aiopg/broker.py

Whitespace-only changes.

src/taskiq_pg/aiopg/queries.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,31 @@
2929
DELETE_RESULT_QUERY = """
3030
DELETE FROM {} WHERE task_id = %s
3131
"""
32+
33+
CREATE_SCHEDULES_TABLE_QUERY = """
34+
CREATE TABLE IF NOT EXISTS {} (
35+
id UUID PRIMARY KEY,
36+
task_name VARCHAR(100) NOT NULL,
37+
schedule JSONB NOT NULL,
38+
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
39+
updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
40+
);
41+
"""
42+
43+
INSERT_SCHEDULE_QUERY = """
44+
INSERT INTO {} (id, task_name, schedule)
45+
VALUES (%s, %s, %s)
46+
ON CONFLICT (id) DO UPDATE
47+
SET task_name = EXCLUDED.task_name,
48+
schedule = EXCLUDED.schedule,
49+
updated_at = NOW();
50+
"""
51+
52+
SELECT_SCHEDULES_QUERY = """
53+
SELECT id, task_name, schedule
54+
FROM {};
55+
"""
56+
57+
DELETE_ALL_SCHEDULES_QUERY = """
58+
DELETE FROM {};
59+
"""
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
import uuid
2+
from logging import getLogger
3+
4+
from aiopg import Pool, create_pool
5+
from pydantic import ValidationError
6+
from taskiq import ScheduledTask
7+
8+
from taskiq_pg import exceptions
9+
from taskiq_pg._internal import BasePostgresScheduleSource
10+
from taskiq_pg.aiopg.queries import (
11+
CREATE_SCHEDULES_TABLE_QUERY,
12+
DELETE_ALL_SCHEDULES_QUERY,
13+
INSERT_SCHEDULE_QUERY,
14+
SELECT_SCHEDULES_QUERY,
15+
)
16+
17+
18+
logger = getLogger("taskiq_pg.aiopg_schedule_source")
19+
20+
21+
class AiopgScheduleSource(BasePostgresScheduleSource):
22+
"""Schedule source that uses aiopg to store schedules in PostgreSQL."""
23+
24+
_database_pool: Pool
25+
26+
async def _update_schedules_on_startup(self, schedules: list[ScheduledTask]) -> None:
27+
"""Update schedules in the database on startup: truncate table and insert new ones."""
28+
async with self._database_pool.acquire() as connection, connection.cursor() as cursor:
29+
await cursor.execute(DELETE_ALL_SCHEDULES_QUERY.format(self._table_name))
30+
for schedule in schedules:
31+
await cursor.execute(
32+
INSERT_SCHEDULE_QUERY.format(self._table_name),
33+
[
34+
schedule.schedule_id,
35+
schedule.task_name,
36+
schedule.model_dump_json(
37+
exclude={"schedule_id", "task_name"},
38+
),
39+
],
40+
)
41+
42+
def _get_schedules_from_broker_tasks(self) -> list[ScheduledTask]:
43+
"""Extract schedules from the broker's registered tasks."""
44+
scheduled_tasks_for_creation: list[ScheduledTask] = []
45+
for task_name, task in self._broker.get_all_tasks().items():
46+
if "schedule" not in task.labels:
47+
logger.debug("Task %s has no schedule, skipping", task_name)
48+
continue
49+
if not isinstance(task.labels["schedule"], list):
50+
logger.warning(
51+
"Schedule for task %s is not a list, skipping",
52+
task_name,
53+
)
54+
continue
55+
for schedule in task.labels["schedule"]:
56+
try:
57+
new_schedule = ScheduledTask.model_validate(
58+
{
59+
"task_name": task_name,
60+
"labels": schedule.get("labels", {}),
61+
"args": schedule.get("args", []),
62+
"kwargs": schedule.get("kwargs", {}),
63+
"schedule_id": str(uuid.uuid4()),
64+
"cron": schedule.get("cron", None),
65+
"cron_offset": schedule.get("cron_offset", None),
66+
"time": schedule.get("time", None),
67+
},
68+
)
69+
scheduled_tasks_for_creation.append(new_schedule)
70+
except ValidationError:
71+
logger.exception(
72+
"Schedule for task %s is not valid, skipping",
73+
task_name,
74+
)
75+
continue
76+
return scheduled_tasks_for_creation
77+
78+
async def startup(self) -> None:
79+
"""
80+
Initialize the schedule source.
81+
82+
Construct new connection pool, create new table for schedules if not exists
83+
and fill table with schedules from task labels.
84+
"""
85+
try:
86+
self._database_pool = await create_pool(
87+
dsn=self.dsn,
88+
**self._connect_kwargs,
89+
)
90+
async with self._database_pool.acquire() as connection, connection.cursor() as cursor:
91+
await cursor.execute(CREATE_SCHEDULES_TABLE_QUERY.format(self._table_name))
92+
scheduled_tasks_for_creation = self._get_schedules_from_broker_tasks()
93+
await self._update_schedules_on_startup(scheduled_tasks_for_creation)
94+
except Exception as error:
95+
raise exceptions.DatabaseConnectionError(str(error)) from error
96+
97+
async def shutdown(self) -> None:
98+
"""Close the connection pool."""
99+
if getattr(self, "_database_pool", None) is not None:
100+
self._database_pool.close()
101+
102+
async def get_schedules(self) -> list["ScheduledTask"]:
103+
"""Fetch schedules from the database."""
104+
async with self._database_pool.acquire() as connection, connection.cursor() as cursor:
105+
await cursor.execute(
106+
SELECT_SCHEDULES_QUERY.format(self._table_name),
107+
)
108+
schedules, rows = [], await cursor.fetchall()
109+
for schedule_id, task_name, schedule in rows:
110+
schedules.append(
111+
ScheduledTask.model_validate(
112+
{
113+
"schedule_id": str(schedule_id),
114+
"task_name": task_name,
115+
"labels": schedule["labels"],
116+
"args": schedule["args"],
117+
"kwargs": schedule["kwargs"],
118+
"cron": schedule["cron"],
119+
"cron_offset": schedule["cron_offset"],
120+
"time": schedule["time"],
121+
},
122+
),
123+
)
124+
return schedules

src/taskiq_pg/asyncpg/schedule_source.py

Lines changed: 5 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
import json
2-
import typing as tp
32
import uuid
43
from logging import getLogger
54

65
import asyncpg
76
from pydantic import ValidationError
8-
from taskiq import ScheduledTask, ScheduleSource
9-
from taskiq.abc.broker import AsyncBroker
7+
from taskiq import ScheduledTask
108

9+
from taskiq_pg._internal import BasePostgresScheduleSource
1110
from taskiq_pg.asyncpg.queries import (
1211
CREATE_SCHEDULES_TABLE_QUERY,
1312
DELETE_ALL_SCHEDULES_QUERY,
@@ -19,57 +18,16 @@
1918
logger = getLogger("taskiq_pg.asyncpg_schedule_source")
2019

2120

22-
class AsyncpgScheduleSource(ScheduleSource):
21+
class AsyncpgScheduleSource(BasePostgresScheduleSource):
2322
"""Schedule source that uses asyncpg to store schedules in PostgreSQL."""
2423

2524
_database_pool: "asyncpg.Pool[asyncpg.Record]"
2625

27-
def __init__(
28-
self,
29-
broker: AsyncBroker,
30-
dsn: str | tp.Callable[[], str] = "postgresql://postgres:postgres@localhost:5432/postgres",
31-
table_name: str = "taskiq_schedules",
32-
**connect_kwargs: tp.Any,
33-
) -> None:
34-
"""
35-
Initialize the PostgreSQL scheduler source.
36-
37-
Sets up a scheduler source that stores scheduled tasks in a PostgreSQL database.
38-
This scheduler source manages task schedules, allowing for persistent storage and retrieval of scheduled tasks
39-
across application restarts.
40-
41-
Args:
42-
dsn: PostgreSQL connection string
43-
table_name: Name of the table to store scheduled tasks. Will be created automatically if it doesn't exist.
44-
broker: The TaskIQ broker instance to use for finding and managing tasks.
45-
Required if startup_schedule is provided.
46-
**connect_kwargs: Additional keyword arguments passed to the database connection pool.
47-
48-
"""
49-
self._broker: tp.Final = broker
50-
self._dsn: tp.Final = dsn
51-
self._table_name: tp.Final = table_name
52-
self._connect_kwargs: tp.Final = connect_kwargs
53-
54-
@property
55-
def dsn(self) -> str | None:
56-
"""
57-
Get the DSN string.
58-
59-
Returns the DSN string or None if not set.
60-
"""
61-
if callable(self._dsn):
62-
return self._dsn()
63-
return self._dsn
64-
6526
async def _update_schedules_on_startup(self, schedules: list[ScheduledTask]) -> None:
66-
"""Update schedules in the database on startup: trancate table and insert new ones."""
27+
"""Update schedules in the database on startup: truncate table and insert new ones."""
6728
async with self._database_pool.acquire() as connection, connection.transaction():
6829
await connection.execute(DELETE_ALL_SCHEDULES_QUERY.format(self._table_name))
6930
for schedule in schedules:
70-
schedule.model_dump_json(
71-
exclude={"schedule_id", "task_name"},
72-
)
7331
await self._database_pool.execute(
7432
INSERT_SCHEDULE_QUERY.format(self._table_name),
7533
str(schedule.schedule_id),
@@ -91,6 +49,7 @@ def _get_schedules_from_broker_tasks(self) -> list[ScheduledTask]:
9149
"Schedule for task %s is not a list, skipping",
9250
task_name,
9351
)
52+
continue
9453
for schedule in task.labels["schedule"]:
9554
try:
9655
new_schedule = ScheduledTask.model_validate(

src/taskiq_pg/psqlpy/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from taskiq_pg.psqlpy.broker import PSQLPyBroker
22
from taskiq_pg.psqlpy.result_backend import PSQLPyResultBackend
3+
from taskiq_pg.psqlpy.schedule_source import PSQLPyScheduleSource
34

45

56
__all__ = [
67
"PSQLPyBroker",
78
"PSQLPyResultBackend",
9+
"PSQLPyScheduleSource",
810
]

src/taskiq_pg/psqlpy/queries.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,31 @@
5151
CLAIM_MESSAGE_QUERY = "UPDATE {} SET status = 'processing' WHERE id = $1 AND status = 'pending' RETURNING *"
5252

5353
DELETE_MESSAGE_QUERY = "DELETE FROM {} WHERE id = $1"
54+
55+
CREATE_SCHEDULES_TABLE_QUERY = """
56+
CREATE TABLE IF NOT EXISTS {} (
57+
id UUID PRIMARY KEY,
58+
task_name VARCHAR(100) NOT NULL,
59+
schedule JSONB NOT NULL,
60+
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
61+
updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
62+
);
63+
"""
64+
65+
INSERT_SCHEDULE_QUERY = """
66+
INSERT INTO {} (id, task_name, schedule)
67+
VALUES ($1, $2, $3)
68+
ON CONFLICT (id) DO UPDATE
69+
SET task_name = EXCLUDED.task_name,
70+
schedule = EXCLUDED.schedule,
71+
updated_at = NOW();
72+
"""
73+
74+
SELECT_SCHEDULES_QUERY = """
75+
SELECT id, task_name, schedule
76+
FROM {};
77+
"""
78+
79+
DELETE_ALL_SCHEDULES_QUERY = """
80+
DELETE FROM {};
81+
"""

0 commit comments

Comments
 (0)