9 changes: 1 addition & 8 deletions .github/workflows/ci.yml

@@ -17,7 +17,7 @@ jobs:
         backend:
           - name: "Redis 6.2, redis-py <5"
             redis-version: "6.2"
-            redis-py-version: ">=4.6,<5"
+            redis-py-version: ">=5,<6"
           - name: "Redis 7.4, redis-py >=5"
             redis-version: "7.4"
             redis-py-version: ">=5"
@@ -27,13 +27,6 @@ jobs:
           - name: "Memory (in-memory backend)"
            redis-version: "memory"
             redis-py-version: ">=5"
-        exclude:
-          # Python 3.10 + Redis 6.2 + redis-py <5 combination is skipped
-          - python-version: "3.10"
-            backend:
-              name: "Redis 6.2, redis-py <5"
-              redis-version: "6.2"
-              redis-py-version: ">=4.6,<5"
         include:
           - python-version: "3.10"
             cov-threshold: 100
3 changes: 2 additions & 1 deletion pyproject.toml

@@ -28,8 +28,9 @@ dependencies = [
     "opentelemetry-api>=1.30.0",
     "opentelemetry-exporter-prometheus>=0.51b0",
     "prometheus-client>=0.21.1",
+    "py-key-value-aio[memory,redis]>=0.2.8",
     "python-json-logger>=3.2.1",
-    "redis>=4.6",
+    "redis>=5",
     "rich>=13.9.4",
     "typer>=0.15.1",
     "typing_extensions>=4.12.0",
2 changes: 1 addition & 1 deletion src/docket/cli.py

@@ -914,7 +914,7 @@ def create_display_layout() -> Layout:
     )

     # Add worker if available
-    if worker_name:
+    if worker_name:  # pragma: no branch
         info_lines.append(f"[bold]Worker:[/bold] {worker_name}")

     # Add error if failed
20 changes: 20 additions & 0 deletions src/docket/docket.py

@@ -24,6 +24,7 @@
     overload,
 )

+from key_value.aio.stores.base import BaseContextManagerStore
 from typing_extensions import Self

 import redis.exceptions
@@ -41,6 +42,9 @@
     StrikeList,
     TaskFunction,
 )
+from key_value.aio.protocols.key_value import AsyncKeyValue
+from key_value.aio.stores.redis import RedisStore
+from key_value.aio.stores.memory import MemoryStore

 from .instrumentation import (
     REDIS_DISRUPTIONS,
@@ -147,6 +151,7 @@ def __init__(
         heartbeat_interval: timedelta = timedelta(seconds=2),
         missed_heartbeats: int = 5,
         execution_ttl: timedelta = timedelta(hours=1),
+        result_storage: AsyncKeyValue | None = None,
     ) -> None:
         """
         Args:
@@ -171,6 +176,14 @@
         self.execution_ttl = execution_ttl
         self._cancel_task_script = None

+        self.result_storage: AsyncKeyValue
+        if url.startswith("memory://"):
+            self.result_storage = MemoryStore()
+        else:
+            self.result_storage = RedisStore(
+                url=url, default_collection=f"{name}:results"
+            )
+
     @property
     def worker_group_name(self) -> str:
         return "docket-workers"
@@ -220,6 +233,10 @@ async def __aenter__(self) -> Self:
                 if "BUSYGROUP" not in repr(e):
                     raise

+        if isinstance(self.result_storage, BaseContextManagerStore):
+            await self.result_storage.__aenter__()
+        else:
+            await self.result_storage.setup()
         return self

     async def __aexit__(
@@ -228,6 +245,9 @@ async def __aexit__(
         exc_value: BaseException | None,
         traceback: TracebackType | None,
     ) -> None:
+        if isinstance(self.result_storage, BaseContextManagerStore):
+            await self.result_storage.__aexit__(exc_type, exc_value, traceback)
+
         del self.tasks
         del self.strike_list

114 changes: 105 additions & 9 deletions src/docket/execution.py

@@ -1,4 +1,5 @@
 import abc
+import base64
 import enum
 import inspect
 import json
@@ -237,7 +238,7 @@ async def sync(self) -> None:
         self.message = None
         self.updated_at = None

-    async def _delete(self) -> None:
+    async def delete(self) -> None:
         """Delete the progress data from Redis.

         Called internally when task execution completes.
@@ -326,6 +327,7 @@ def __init__(
         self.started_at: datetime | None = None
         self.completed_at: datetime | None = None
         self.error: str | None = None
+        self.result_key: str | None = None
         self.progress: ExecutionProgress = ExecutionProgress(docket, key)
         self._redis_key = f"{docket.name}:runs:{key}"
@@ -632,37 +634,47 @@ async def claim(self, worker: str) -> None:
             }
         )

-    async def mark_as_completed(self) -> None:
+    async def mark_as_completed(self, result_key: str | None = None) -> None:
         """Mark task as completed successfully.

+        Args:
+            result_key: Optional key where the task result is stored
+
         Sets TTL on state data (from docket.execution_ttl) and deletes progress data.
         """
         completed_at = datetime.now(timezone.utc).isoformat()
         async with self.docket.redis() as redis:
+            mapping: dict[str, str] = {
+                "state": ExecutionState.COMPLETED.value,
+                "completed_at": completed_at,
+            }
+            if result_key is not None:
+                mapping["result_key"] = result_key
             await redis.hset(
                 self._redis_key,
-                mapping={
-                    "state": ExecutionState.COMPLETED.value,
-                    "completed_at": completed_at,
-                },
+                mapping=mapping,
             )
             # Set TTL from docket configuration
             await redis.expire(
                 self._redis_key, int(self.docket.execution_ttl.total_seconds())
             )
         self.state = ExecutionState.COMPLETED
+        self.result_key = result_key
         # Delete progress data
-        await self.progress._delete()
+        await self.progress.delete()
         # Publish state change event
         await self._publish_state(
             {"state": ExecutionState.COMPLETED.value, "completed_at": completed_at}
         )

-    async def mark_as_failed(self, error: str | None = None) -> None:
+    async def mark_as_failed(
+        self, error: str | None = None, result_key: str | None = None
+    ) -> None:
         """Mark task as failed.

         Args:
             error: Optional error message describing the failure
+            result_key: Optional key where the exception is stored

         Sets TTL on state data (from docket.execution_ttl) and deletes progress data.
         """
@@ -674,14 +686,17 @@ async def mark_as_failed(self, error: str | None = None) -> None:
             }
             if error:
                 mapping["error"] = error
+            if result_key is not None:
+                mapping["result_key"] = result_key
             await redis.hset(self._redis_key, mapping=mapping)
             # Set TTL from docket configuration
             await redis.expire(
                 self._redis_key, int(self.docket.execution_ttl.total_seconds())
             )
         self.state = ExecutionState.FAILED
+        self.result_key = result_key
         # Delete progress data
-        await self.progress._delete()
+        await self.progress.delete()
         # Publish state change event
         state_data = {
             "state": ExecutionState.FAILED.value,
@@ -691,6 +706,83 @@
             state_data["error"] = error
         await self._publish_state(state_data)

+    async def get_result(self, *, timeout: datetime | None = None) -> Any:
+        """Retrieve the result of this task execution.
+
+        If the execution is not yet complete, this method will wait using
+        pub/sub for state updates until completion.
+
+        Args:
+            timeout: Optional absolute datetime when to stop waiting.
+                If None, waits indefinitely.
+
+        Returns:
+            The result of the task execution, or None if the task returned None.
+
+        Raises:
+            Exception: If the task failed, raises the stored exception
+            TimeoutError: If timeout is reached before execution completes
+        """
+        import asyncio
+        from datetime import datetime, timezone
+
+        # Wait for execution to complete if not already done
+        if self.state not in (ExecutionState.COMPLETED, ExecutionState.FAILED):
+            # Calculate timeout duration if absolute timeout provided
+            timeout_seconds = None
+            if timeout is not None:
+                timeout_seconds = (timeout - datetime.now(timezone.utc)).total_seconds()
+                if timeout_seconds <= 0:
+                    raise TimeoutError(
+                        f"Timeout waiting for execution {self.key} to complete"
+                    )
+
+            try:
+
+                async def wait_for_completion():
+                    async for event in self.subscribe():  # pragma: no cover
+                        if event["type"] == "state":
+                            state = ExecutionState(event["state"])
+                            if state in (
+                                ExecutionState.COMPLETED,
+                                ExecutionState.FAILED,
+                            ):
+                                # Sync to get latest data including result key
+                                await self.sync()
+                                break
+
+                # Use asyncio.wait_for to enforce timeout
+                await asyncio.wait_for(wait_for_completion(), timeout=timeout_seconds)
+            except asyncio.TimeoutError:
+                raise TimeoutError(
+                    f"Timeout waiting for execution {self.key} to complete"
+                )
+
+        # If failed, retrieve and raise the exception
+        if self.state == ExecutionState.FAILED:
+            if self.result_key:
+                # Retrieve serialized exception from result_storage
+                result_data = await self.docket.result_storage.get(self.result_key)
+                if result_data and "data" in result_data:
+                    # Base64-decode and unpickle
+                    pickled_exception = base64.b64decode(result_data["data"])
+                    exception = cloudpickle.loads(pickled_exception)  # type: ignore[arg-type]
+                    raise exception
+            # If no stored exception, raise a generic error with the error message
+            error_msg = self.error or "Task execution failed"
+            raise Exception(error_msg)
+
+        # If completed successfully, retrieve result if available
+        if self.result_key:
+            result_data = await self.docket.result_storage.get(self.result_key)
+            if result_data is not None and "data" in result_data:
+                # Base64-decode and unpickle
+                pickled_result = base64.b64decode(result_data["data"])
+                return cloudpickle.loads(pickled_result)  # type: ignore[arg-type]
+
+        # No result stored - task returned None
+        return None
+
     async def sync(self) -> None:
         """Synchronize instance attributes with current execution data from Redis.

Expand Down Expand Up @@ -720,13 +812,17 @@ async def sync(self) -> None:
else None
)
self.error = data[b"error"].decode() if b"error" in data else None
self.result_key = (
data[b"result_key"].decode() if b"result_key" in data else None
)
else:
# No data exists - reset to defaults
self.state = ExecutionState.SCHEDULED
self.worker = None
self.started_at = None
self.completed_at = None
self.error = None
self.result_key = None

# Sync progress data
await self.progress.sync()
Expand Down