Skip to content

Commit ccc1cc7

Browse files
authored
Supports configuring and running a worker from the CLI (#22)
* Supports configuring and running a worker from the CLI This ended up making me confront the fact that we didn't really deal with how the worker knows which tasks to run. From the client side it's really easy, since you can just put something blindly on the docket and we'll register it for you if it's a function. From the worker side, there's nowhere to really start from. I don't think this is the final form of this, but it's a baby step. This version supports a string path pointing to any iterable of functions you set up, so you can have: ``` my_tasks = [task_a, task_b, task_c] ``` and then register them with `--tasks my.module:my_tasks`. We should really think a lot more about this. Closes #8 * Updating the chaos tests * Use the right redis * Making the CLI utility a code utility
1 parent 8f9772f commit ccc1cc7

File tree

15 files changed

+517
-108
lines changed

15 files changed

+517
-108
lines changed

chaos/driver.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,18 +49,15 @@ async def main(
4949
workers: int = 4,
5050
):
5151
with RedisContainer("redis:7.4.2") as redis_server:
52+
redis_url = f"redis://{redis_server.get_container_host_ip()}:{redis_server.get_exposed_port(6379)}"
5253
docket = Docket(
5354
name=f"test-docket-{uuid4()}",
54-
host=redis_server.get_container_host_ip(),
55-
port=redis_server.get_exposed_port(6379),
56-
db=0,
55+
url=redis_url,
5756
)
5857
environment = {
5958
**os.environ,
60-
"CHAOS_DOCKET_NAME": docket.name,
61-
"CHAOS_REDIS_HOST": docket.host,
62-
"CHAOS_REDIS_PORT": str(docket.port),
63-
"CHAOS_REDIS_DB": str(docket.db),
59+
"DOCKET_NAME": docket.name,
60+
"DOCKET_URL": redis_url,
6461
}
6562

6663
if tasks % producers != 0:
@@ -93,8 +90,19 @@ async def spawn_worker() -> Process:
9390
return await asyncio.create_subprocess_exec(
9491
*python_entrypoint(),
9592
"-m",
96-
"chaos.worker",
97-
env=environment | {"OTEL_SERVICE_NAME": "chaos-worker"},
93+
"docket",
94+
"worker",
95+
"--docket",
96+
docket.name,
97+
"--url",
98+
redis_url,
99+
"--tasks",
100+
"chaos.tasks:chaos_tasks",
101+
env=environment
102+
| {
103+
"OTEL_SERVICE_NAME": "chaos-worker",
104+
"DOCKET_WORKER_REDELIVERY_TIMEOUT": "5s",
105+
},
98106
stdout=subprocess.DEVNULL,
99107
stderr=subprocess.DEVNULL,
100108
)

chaos/producer.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,8 @@
1616

1717
async def main(tasks_to_produce: int):
1818
docket = Docket(
19-
name=os.environ["CHAOS_DOCKET_NAME"],
20-
host=os.environ["CHAOS_REDIS_HOST"],
21-
port=int(os.environ["CHAOS_REDIS_PORT"]),
22-
db=int(os.environ["CHAOS_REDIS_DB"]),
19+
name=os.environ["DOCKET_NAME"],
20+
url=os.environ["DOCKET_URL"],
2321
)
2422
tasks_sent = 0
2523
while tasks_sent < tasks_to_produce:

chaos/tasks.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,6 @@ async def hello(
2020

2121
async def toxic():
2222
sys.exit(42)
23+
24+
25+
chaos_tasks = [hello, toxic]

chaos/worker.py

Lines changed: 0 additions & 29 deletions
This file was deleted.

src/docket/cli.py

Lines changed: 173 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
1+
import asyncio
2+
import enum
3+
import logging
4+
import socket
5+
from datetime import timedelta
6+
from typing import Annotated
7+
18
import typer
29

3-
from docket import __version__
10+
from . import __version__, tasks
11+
from .docket import Docket
12+
from .worker import Worker
413

514
app: typer.Typer = typer.Typer(
615
help="Docket - A distributed background task system for Python functions",
@@ -9,11 +18,172 @@
918
)
1019

1120

21+
class LogLevel(enum.StrEnum):
22+
DEBUG = "DEBUG"
23+
INFO = "INFO"
24+
WARNING = "WARNING"
25+
ERROR = "ERROR"
26+
CRITICAL = "CRITICAL"
27+
28+
29+
def duration(duration_str: str | timedelta) -> timedelta:
30+
"""
31+
Parse a duration string into a timedelta.
32+
33+
Supported formats:
34+
- 123 = 123 seconds
35+
- 123s = 123 seconds
36+
- 123m = 123 minutes
37+
- 123h = 123 hours
38+
- 00:00 = mm:ss
39+
- 00:00:00 = hh:mm:ss
40+
"""
41+
if isinstance(duration_str, timedelta):
42+
return duration_str
43+
44+
if ":" in duration_str:
45+
parts = duration_str.split(":")
46+
if len(parts) == 2: # mm:ss
47+
minutes, seconds = map(int, parts)
48+
return timedelta(minutes=minutes, seconds=seconds)
49+
elif len(parts) == 3: # hh:mm:ss
50+
hours, minutes, seconds = map(int, parts)
51+
return timedelta(hours=hours, minutes=minutes, seconds=seconds)
52+
else:
53+
raise ValueError(f"Invalid duration string: {duration_str}")
54+
elif duration_str.endswith("s"):
55+
return timedelta(seconds=int(duration_str[:-1]))
56+
elif duration_str.endswith("m"):
57+
return timedelta(minutes=int(duration_str[:-1]))
58+
elif duration_str.endswith("h"):
59+
return timedelta(hours=int(duration_str[:-1]))
60+
else:
61+
return timedelta(seconds=int(duration_str))
62+
63+
1264
@app.command(
1365
help="Start a worker to process tasks",
1466
)
15-
def worker() -> None:
16-
print("TODO: Configure and start a worker")
67+
def worker(
68+
tasks: Annotated[
69+
list[str],
70+
typer.Option(
71+
"--tasks",
72+
help=(
73+
"The dotted path of a task collection to register with the docket. "
74+
"This can be specified multiple times. A task collection is any "
75+
"iterable of async functions."
76+
),
77+
),
78+
] = ["docket.tasks:standard_tasks"],
79+
docket_: Annotated[
80+
str,
81+
typer.Option(
82+
"--docket",
83+
help="The name of the docket",
84+
envvar="DOCKET_NAME",
85+
),
86+
] = "docket",
87+
url: Annotated[
88+
str,
89+
typer.Option(
90+
help="The URL of the Redis server",
91+
envvar="DOCKET_URL",
92+
),
93+
] = "redis://localhost:6379/0",
94+
name: Annotated[
95+
str | None,
96+
typer.Option(
97+
help="The name of the worker",
98+
envvar="DOCKET_WORKER_NAME",
99+
),
100+
] = socket.gethostname(),
101+
logging_level: Annotated[
102+
LogLevel,
103+
typer.Option(
104+
help="The logging level",
105+
envvar="DOCKET_LOGGING_LEVEL",
106+
),
107+
] = LogLevel.INFO,
108+
prefetch_count: Annotated[
109+
int,
110+
typer.Option(
111+
help="The number of tasks to request from the docket at a time",
112+
envvar="DOCKET_WORKER_PREFETCH_COUNT",
113+
),
114+
] = 10,
115+
redelivery_timeout: Annotated[
116+
timedelta,
117+
typer.Option(
118+
parser=duration,
119+
help="How long to wait before redelivering a task to another worker",
120+
envvar="DOCKET_WORKER_REDELIVERY_TIMEOUT",
121+
),
122+
] = timedelta(minutes=5),
123+
reconnection_delay: Annotated[
124+
timedelta,
125+
typer.Option(
126+
parser=duration,
127+
help=(
128+
"How long to wait before reconnecting to the Redis server after "
129+
"a connection error"
130+
),
131+
envvar="DOCKET_WORKER_RECONNECTION_DELAY",
132+
),
133+
] = timedelta(seconds=5),
134+
until_finished: Annotated[
135+
bool,
136+
typer.Option(
137+
"--until-finished",
138+
help="Exit after the current docket is finished",
139+
),
140+
] = False,
141+
) -> None:
142+
logging.basicConfig(level=logging_level)
143+
asyncio.run(
144+
Worker.run(
145+
docket_name=docket_,
146+
url=url,
147+
name=name,
148+
prefetch_count=prefetch_count,
149+
redelivery_timeout=redelivery_timeout,
150+
reconnection_delay=reconnection_delay,
151+
until_finished=until_finished,
152+
tasks=tasks,
153+
)
154+
)
155+
156+
157+
@app.command(help="Adds a trace task to the Docket")
158+
def trace(
159+
docket_: Annotated[
160+
str,
161+
typer.Option(
162+
"--docket",
163+
help="The name of the docket",
164+
envvar="DOCKET_NAME",
165+
),
166+
] = "docket",
167+
url: Annotated[
168+
str,
169+
typer.Option(
170+
help="The URL of the Redis server",
171+
envvar="DOCKET_URL",
172+
),
173+
] = "redis://localhost:6379/0",
174+
message: Annotated[
175+
str,
176+
typer.Argument(
177+
help="The message to print",
178+
),
179+
] = "Howdy!",
180+
) -> None:
181+
async def run() -> None:
182+
async with Docket(name=docket_, url=url) as docket:
183+
execution = await docket.add(tasks.trace)(message)
184+
print(f"Added trace task {execution.key!r} to the docket {docket.name!r}")
185+
186+
asyncio.run(run())
17187

18188

19189
@app.command(

src/docket/docket.py

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import importlib
12
from contextlib import asynccontextmanager
23
from datetime import datetime, timezone
34
from types import TracebackType
@@ -6,6 +7,7 @@
67
AsyncGenerator,
78
Awaitable,
89
Callable,
10+
Iterable,
911
ParamSpec,
1012
Self,
1113
TypeVar,
@@ -31,26 +33,35 @@
3133
P = ParamSpec("P")
3234
R = TypeVar("R")
3335

36+
TaskCollection = Iterable[Callable[..., Awaitable[Any]]]
37+
3438

3539
class Docket:
3640
tasks: dict[str, Callable[..., Awaitable[Any]]]
3741

3842
def __init__(
3943
self,
4044
name: str = "docket",
41-
host: str = "localhost",
42-
port: int = 6379,
43-
db: int = 0,
44-
password: str | None = None,
45+
url: str = "redis://localhost:6379/0",
4546
) -> None:
47+
"""
48+
Args:
49+
name: The name of the docket.
50+
url: The URL of the Redis server. For example:
51+
- "redis://localhost:6379/0"
52+
- "redis://user:password@localhost:6379/0"
53+
- "redis://user:password@localhost:6379/0?ssl=true"
54+
- "rediss://localhost:6379/0"
55+
- "unix:///path/to/redis.sock"
56+
"""
4657
self.name = name
47-
self.host = host
48-
self.port = port
49-
self.db = db
50-
self.password = password
58+
self.url = url
5159

5260
async def __aenter__(self) -> Self:
53-
self.tasks = {}
61+
from .tasks import standard_tasks
62+
63+
self.tasks = {fn.__name__: fn for fn in standard_tasks}
64+
5465
return self
5566

5667
async def __aexit__(
@@ -63,12 +74,7 @@ async def __aexit__(
6374

6475
@asynccontextmanager
6576
async def redis(self) -> AsyncGenerator[Redis, None]:
66-
async with Redis(
67-
host=self.host,
68-
port=self.port,
69-
db=self.db,
70-
password=self.password,
71-
) as redis:
77+
async with Redis.from_url(self.url) as redis:
7278
yield redis
7379

7480
def register(self, function: Callable[..., Awaitable[Any]]) -> None:
@@ -78,6 +84,19 @@ def register(self, function: Callable[..., Awaitable[Any]]) -> None:
7884

7985
self.tasks[function.__name__] = function
8086

87+
def register_collection(self, collection_path: str) -> None:
88+
"""
89+
Register a collection of tasks.
90+
91+
Args:
92+
collection_path: A path in the format "module:collection".
93+
"""
94+
module_name, _, member_name = collection_path.rpartition(":")
95+
module = importlib.import_module(module_name)
96+
collection = getattr(module, member_name)
97+
for function in collection:
98+
self.register(function)
99+
81100
@overload
82101
def add(
83102
self,

0 commit comments

Comments
 (0)