Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions python/cog/command/ast_openapi_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@
},
"PredictionRequest": {
"properties": {
"context": {
"title": "Context",
"type": "object",
"additionalProperties": {
"type": "string"
}
},
"created_at": {
"format": "date-time",
"title": "Created At",
Expand Down
1 change: 1 addition & 0 deletions python/cog/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class Config:
class PredictionRequest(PredictionBaseModel):
id: Optional[str] = None
created_at: Optional[datetime] = None
context: Optional[Dict[str, str]] = None

# TODO: deprecate this
output_file_prefix: Optional[str] = None
Expand Down
1 change: 1 addition & 0 deletions python/cog/server/eventtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class Cancel:
@define
class PredictionInput:
payload: Dict[str, Any]
context: Dict[str, str] = {}


@define
Expand Down
2 changes: 1 addition & 1 deletion python/cog/server/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def predict(
payload = prediction.input.copy()

sid = self._worker.subscribe(task.handle_event, tag=tag)
task.track(self._worker.predict(payload, tag=tag))
task.track(self._worker.predict(payload, context=prediction.context, tag=tag))
task.add_done_callback(self._task_done_callback(tag, sid))

return task
Expand Down
3 changes: 2 additions & 1 deletion python/cog/server/scope.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import warnings
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Callable, Generator, Optional, Union
from typing import Any, Callable, Dict, Generator, Optional, Union

from attrs import evolve, frozen

Expand All @@ -11,6 +11,7 @@
@frozen
class Scope:
record_metric: Callable[[str, Union[float, int]], None]
context: Dict[str, str] = {}
_tag: Optional[str] = None


Expand Down
46 changes: 38 additions & 8 deletions python/cog/server/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,11 @@ def setup(self) -> "Future[Done]":
return self._setup_result

def predict(
self, payload: Dict[str, Any], tag: Optional[str] = None
self,
payload: Dict[str, Any],
tag: Optional[str] = None,
*,
context: Optional[Dict[str, str]] = None,
) -> "Future[Done]":
# TODO: tag is Optional, but it's required when in concurrent mode and
# basically unnecessary in sequential mode. Should we have a separate
Expand All @@ -157,11 +161,17 @@ def predict(
result = Future()
self._predictions_in_flight[tag] = PredictionState(tag, payload, result)

self._prediction_start_pool.submit(self._start_prediction(tag, payload))
self._prediction_start_pool.submit(
self._start_prediction(tag, payload, context=context)
)
return result

def _start_prediction(
self, tag: Optional[str], payload: Dict[str, Any]
self,
tag: Optional[str],
payload: Dict[str, Any],
*,
context: Optional[Dict[str, str]],
) -> Callable[[], None]:
def start_prediction() -> None:
try:
Expand Down Expand Up @@ -212,7 +222,7 @@ def start_prediction() -> None:
# send the prediction to the child to start
self._events.send(
Envelope(
event=PredictionInput(payload=payload),
event=PredictionInput(payload=payload, context=context or {}),
tag=tag,
)
)
Expand Down Expand Up @@ -563,7 +573,13 @@ def _loop(
elif isinstance(e.event, Shutdown):
break
elif isinstance(e.event, PredictionInput):
self._predict(e.tag, e.event.payload, predict, redirector)
self._predict(
e.tag,
e.event.payload,
context=e.event.context,
predict=predict,
redirector=redirector,
)
else:
print(f"Got unexpected event: {e.event}", file=sys.stderr)

Expand Down Expand Up @@ -597,7 +613,13 @@ async def _aloop(
break
elif isinstance(e.event, PredictionInput):
tasks[e.tag] = tg.create_task(
self._apredict(e.tag, e.event.payload, predict, redirector)
self._apredict(
e.tag,
e.event.payload,
context=e.event.context,
predict=predict,
redirector=redirector,
)
)
else:
print(f"Got unexpected event: {e.event}", file=sys.stderr)
Expand All @@ -606,10 +628,14 @@ def _predict(
self,
tag: Optional[str],
payload: Dict[str, Any],
*,
context: Dict[str, str],
predict: Callable[..., Any],
redirector: StreamRedirector,
) -> None:
with self._handle_predict_error(redirector, tag=tag):
with evolve_scope(context=context), self._handle_predict_error(
redirector, tag=tag
):
result = predict(**payload)

if result:
Expand Down Expand Up @@ -657,10 +683,14 @@ async def _apredict(
self,
tag: Optional[str],
payload: Dict[str, Any],
*,
context: Dict[str, str],
predict: Callable[..., Any],
redirector: SimpleStreamRedirector,
) -> None:
with evolve_scope(tag=tag), self._handle_predict_error(redirector, tag=tag):
with evolve_scope(context=context, tag=tag), self._handle_predict_error(
redirector, tag=tag
):
future_result = predict(**payload)

if future_result:
Expand Down
7 changes: 7 additions & 0 deletions python/tests/server/fixtures/with_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from cog import current_scope, Input

class Predictor:
def predict(self, name: str = Input()):
prefix = current_scope().context["prefix"]
return f"{prefix} {name}!"

7 changes: 7 additions & 0 deletions python/tests/server/fixtures/with_context_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from cog import current_scope, Input

class Predictor:
async def predict(self, name: str = Input()):
prefix = current_scope().context["prefix"]
return f"{prefix} {name}!"

11 changes: 10 additions & 1 deletion python/tests/server/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import uuid
from concurrent.futures import Future
from datetime import datetime
from typing import Any, Dict, Optional
from unittest import mock

import pytest
Expand Down Expand Up @@ -43,6 +44,7 @@ def __init__(self):
self._setup_future = None
self._predict_futures = {}
self.last_prediction_payload = None
self.last_prediction_context = None

def subscribe(self, subscriber, tag=None):
sid = uuid.uuid4()
Expand Down Expand Up @@ -71,9 +73,16 @@ def run_setup(self, events):
if isinstance(event, Done):
self._setup_future.set_result(event)

def predict(self, payload, tag=None):
def predict(
self,
payload: Dict[str, Any],
tag: Optional[str] = None,
*,
context: Optional[Dict[str, str]] = None,
):
assert tag not in self._predict_futures or self._predict_futures[tag].done()
self.last_prediction_payload = payload
self.last_prediction_context = context
self._predict_futures[tag] = Future()
print(f"setting {tag}, now {self._predict_futures}")
return self._predict_futures[tag]
Expand Down
28 changes: 28 additions & 0 deletions python/tests/server/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,34 @@ def test_async_setup_uses_same_loop_as_predict(worker: Worker):
assert result, "Expected worker to return True to assert same event loop"


@uses_worker("with_context")
def test_context(worker: Worker):
result = _process(
worker,
lambda: worker.predict({"name": "context"}, context={"prefix": "hello"}),
tag=None,
)
assert result.done
assert not result.done.error
assert result.output == "hello context!"


@uses_worker("with_context_async", min_python=(3, 11), is_async=True)
def test_context_async(worker: Worker):
result = _process(
worker,
lambda: worker.predict(
{"name": "context"}, tag="t1", context={"prefix": "hello"}
),
tag=None,
)

print("\n".join(result.stderr_lines))
assert result.done
assert not result.done.error, result.done.error_detail
assert result.output == "hello context!"


@frozen
class SetupState:
fut: "Future[Done]"
Expand Down