-
Notifications
You must be signed in to change notification settings - Fork 2
fix: ensure that only one worker processing task #12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| --- | ||
| title: Contributing | ||
| --- |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,69 @@ | ||
| --- | ||
| title: Broker | ||
| --- | ||
|
|
||
| To use broker with PostgreSQL you need to import broker and result backend from this library and provide a address for connection. For example, lets create a file `broker.py` with the following content: | ||
|
|
||
| === "asyncpg" | ||
|
|
||
| ```python | ||
| import asyncio | ||
| from taskiq_pg.asyncpg import AsyncpgResultBackend, AsyncpgBroker | ||
|
|
||
|
|
||
| dsn = "postgres://postgres:postgres@localhost:5432/postgres" | ||
| broker = AsyncpgBroker(dsn).with_result_backend(AsyncpgResultBackend(dsn)) | ||
|
|
||
|
|
||
| @broker.task | ||
| async def best_task_ever() -> None: | ||
| """Solve all problems in the world.""" | ||
| await asyncio.sleep(5.5) | ||
| print("All problems are solved!") | ||
|
|
||
|
|
||
| async def main(): | ||
| await broker.startup() | ||
| task = await best_task_ever.kiq() | ||
| print(await task.wait_result()) | ||
| await broker.shutdown() | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| asyncio.run(main()) | ||
| ``` | ||
|
|
||
| === "psqlpy" | ||
|
|
||
| ```python | ||
| import asyncio | ||
| from taskiq_pg.psqlpy import PSQLPyResultBackend, PSQLPyBroker | ||
|
|
||
|
|
||
| dsn = "postgres://postgres:postgres@localhost:5432/postgres" | ||
| broker = PSQLPyBroker(dsn).with_result_backend(PSQLPyResultBackend(dsn)) | ||
|
|
||
|
|
||
| @broker.task | ||
| async def best_task_ever() -> None: | ||
| """Solve all problems in the world.""" | ||
| await asyncio.sleep(5.5) | ||
| print("All problems are solved!") | ||
|
|
||
|
|
||
| async def main(): | ||
| await broker.startup() | ||
| task = await best_task_ever.kiq() | ||
| print(await task.wait_result()) | ||
| await broker.shutdown() | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| asyncio.run(main()) | ||
| ``` | ||
|
|
||
| Then you can run this file with: | ||
|
|
||
| ```bash | ||
| python broker.py | ||
| ``` |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| --- | ||
| title: Getting Started | ||
| --- |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| --- | ||
| title: Schedule Source | ||
| --- |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,35 @@ | ||
| """ | ||
| How to run: | ||
|
|
||
| 1) Run worker in one terminal: | ||
| uv run taskiq worker examples.example_with_broker:broker | ||
|
|
||
| 2) Run this script in another terminal: | ||
| uv run python -m examples.example_with_broker | ||
| """ | ||
|
|
||
| import asyncio | ||
|
|
||
| from taskiq_pg.asyncpg import AsyncpgBroker, AsyncpgResultBackend | ||
|
|
||
|
|
||
| dsn = "postgres://taskiq_postgres:look_in_vault@localhost:5432/taskiq_postgres" | ||
| broker = AsyncpgBroker(dsn).with_result_backend(AsyncpgResultBackend(dsn)) | ||
|
|
||
|
|
||
| @broker.task("solve_all_problems") | ||
| async def best_task_ever() -> None: | ||
| """Solve all problems in the world.""" | ||
| await asyncio.sleep(2) | ||
| print("All problems are solved!") | ||
|
|
||
|
|
||
| async def main(): | ||
| await broker.startup() | ||
| task = await best_task_ever.kiq() | ||
| print(await task.wait_result()) | ||
| await broker.shutdown() | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| asyncio.run(main()) |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -6,15 +6,16 @@ | |||||||||||||||||
| from dataclasses import dataclass | ||||||||||||||||||
|
|
||||||||||||||||||
| import psqlpy | ||||||||||||||||||
| from psqlpy.exceptions import ConnectionExecuteError | ||||||||||||||||||
| from psqlpy.extra_types import JSONB | ||||||||||||||||||
| from taskiq import AckableMessage, BrokerMessage | ||||||||||||||||||
|
|
||||||||||||||||||
| from taskiq_pg._internal.broker import BasePostgresBroker | ||||||||||||||||||
| from taskiq_pg.psqlpy.queries import ( | ||||||||||||||||||
| CLAIM_MESSAGE_QUERY, | ||||||||||||||||||
| CREATE_MESSAGE_TABLE_QUERY, | ||||||||||||||||||
| DELETE_MESSAGE_QUERY, | ||||||||||||||||||
| INSERT_MESSAGE_QUERY, | ||||||||||||||||||
| SELECT_MESSAGE_QUERY, | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -35,6 +36,7 @@ class MessageRow: | |||||||||||||||||
| task_name: str | ||||||||||||||||||
| message: str | ||||||||||||||||||
| labels: JSONB | ||||||||||||||||||
| status: str | ||||||||||||||||||
| created_at: datetime | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -165,14 +167,17 @@ async def listen(self) -> AsyncGenerator[AckableMessage, None]: | |||||||||||||||||
| try: | ||||||||||||||||||
| payload = await self._queue.get() | ||||||||||||||||||
| message_id = int(payload) # payload is the message id | ||||||||||||||||||
| message_row = await self.read_conn.fetch_row( | ||||||||||||||||||
| SELECT_MESSAGE_QUERY.format(self.table_name), | ||||||||||||||||||
| [message_id], | ||||||||||||||||||
| ) | ||||||||||||||||||
| # ugly type hacks b/c SingleQueryResult.as_class return type is wrong | ||||||||||||||||||
| try: | ||||||||||||||||||
| async with self.write_pool.acquire() as conn: | ||||||||||||||||||
| claimed_message = await conn.fetch_row( | ||||||||||||||||||
| CLAIM_MESSAGE_QUERY.format(self.table_name), | ||||||||||||||||||
| [message_id], | ||||||||||||||||||
| ) | ||||||||||||||||||
| except ConnectionExecuteError: # message was claimed by another worker | ||||||||||||||||||
|
||||||||||||||||||
| except ConnectionExecuteError: # message was claimed by another worker | |
| except ConnectionExecuteError as exc: # message was claimed by another worker or other connection issue | |
| # Check if the error is due to a claim conflict (e.g., unique violation) | |
| # Adjust the condition below to match your DB's claim conflict error code/message | |
| if hasattr(exc, "pgcode") and exc.pgcode == "23505": # unique_violation | |
| # Message was claimed by another worker | |
| continue | |
| logger.exception("Database connection error while claiming message") |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,101 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import asyncio | ||
| import typing as tp | ||
| import uuid | ||
| from contextlib import suppress | ||
|
|
||
| import asyncpg | ||
| import pytest | ||
| from taskiq import BrokerMessage | ||
|
|
||
| from taskiq_pg.asyncpg import AsyncpgBroker | ||
| from taskiq_pg.psqlpy import PSQLPyBroker | ||
|
|
||
|
|
||
| @pytest.mark.integration | ||
| @pytest.mark.parametrize( | ||
| "broker_class", | ||
| [ | ||
| AsyncpgBroker, | ||
| PSQLPyBroker, | ||
| ], | ||
| ) | ||
| async def test_when_two_workers_listen__then_single_message_processed_once( | ||
| pg_dsn: str, | ||
| broker_class: type[AsyncpgBroker | PSQLPyBroker], | ||
| ) -> None: | ||
| # Given: уникальные имена таблицы и канала, два брокера, одна задача | ||
|
||
| table_name: str = f"taskiq_messages_{uuid.uuid4().hex}" | ||
| channel_name: str = f"taskiq_channel_{uuid.uuid4().hex}" | ||
| task_id: str = uuid.uuid4().hex | ||
|
|
||
| broker1 = broker_class(dsn=pg_dsn, table_name=table_name, channel_name=channel_name) | ||
| broker2 = broker_class(dsn=pg_dsn, table_name=table_name, channel_name=channel_name) | ||
|
|
||
| # Подключение для проверок состояния в таблице | ||
| conn: asyncpg.Connection = await asyncpg.connect(dsn=pg_dsn) | ||
|
|
||
| # Сообщение для публикации | ||
|
Comment on lines
+36
to
+39
|
||
| message: BrokerMessage = BrokerMessage( | ||
| task_id=task_id, | ||
| task_name="example:best_task_ever", | ||
| message=b'{"hello":"world"}', | ||
| labels={}, | ||
| ) | ||
|
|
||
| # When: стартуем брокеры и два слушателя, публикуем одно сообщение | ||
|
||
| await broker1.startup() | ||
| await broker2.startup() | ||
|
|
||
| agen1 = broker1.listen() | ||
| agen2 = broker2.listen() | ||
|
|
||
| # Запускаем ожидание первого сообщения у обоих слушателей до публикации, | ||
| # чтобы оба гарантированно получили NOTIFY. | ||
|
Comment on lines
+54
to
+55
|
||
| t1: asyncio.Task = asyncio.create_task(agen1.__anext__()) | ||
| t2: asyncio.Task = asyncio.create_task(agen2.__anext__()) | ||
|
|
||
| try: | ||
| await broker1.kick(message) | ||
|
|
||
| done, _ = await asyncio.wait( | ||
| {t1, t2}, | ||
| timeout=5.0, | ||
| return_when=asyncio.FIRST_COMPLETED, | ||
| ) | ||
|
|
||
| # Then: только один слушатель получает сообщение | ||
|
||
| assert len(done) == 1, "Ровно один воркер должен получить сообщение" | ||
| winner_task: asyncio.Task = next(iter(done)) | ||
| ack_message = tp.cast("tp.Any", winner_task.result()) | ||
|
|
||
| # До подтверждения проверяем, что статус в таблице = 'processing' | ||
|
||
| row = await conn.fetchrow( | ||
| f"SELECT id, status FROM {table_name} WHERE task_id = $1", | ||
| task_id, | ||
| ) | ||
| assert row is not None, "Сообщение должно существовать в таблице" | ||
| assert row["status"] == "processing", "Сообщение должно быть помечено как processing после claim" | ||
|
|
||
| # Подтверждаем обработку победившим воркером | ||
| await ack_message.ack() | ||
|
|
||
| # И проверяем, что запись удалена | ||
| cnt: int = tp.cast( | ||
| "int", | ||
| await conn.fetchval( | ||
| f"SELECT COUNT(*) FROM {table_name} WHERE task_id = $1", | ||
| task_id, | ||
| ), | ||
| ) | ||
| assert cnt == 0, "Запись должна быть удалена после ack" | ||
| finally: | ||
| with suppress(Exception): | ||
| await broker1.shutdown() | ||
| await broker2.shutdown() | ||
|
|
||
| try: | ||
| await conn.execute(f"DROP TABLE IF EXISTS {table_name}") | ||
| finally: | ||
| await conn.close() | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The CLAIM_MESSAGE_QUERY only returns 'id' and 'message' fields, but the psqlpy version returns all fields (*). This inconsistency could lead to maintenance issues if additional fields need to be accessed in the future."