diff --git a/python/cog/command/ast_openapi_schema.py b/python/cog/command/ast_openapi_schema.py index 018c5bd6af..9afb03e391 100644 --- a/python/cog/command/ast_openapi_schema.py +++ b/python/cog/command/ast_openapi_schema.py @@ -28,6 +28,13 @@ }, "PredictionRequest": { "properties": { + "context": { + "title": "Context", + "type": "object", + "additionalProperties": { + "type": "string" + } + }, "created_at": { "format": "date-time", "title": "Created At", diff --git a/python/cog/schema.py b/python/cog/schema.py index c8624da8bf..e2b64931e3 100644 --- a/python/cog/schema.py +++ b/python/cog/schema.py @@ -68,6 +68,7 @@ class Config: class PredictionRequest(PredictionBaseModel): id: Optional[str] = None created_at: Optional[datetime] = None + context: Optional[Dict[str, str]] = None # TODO: deprecate this output_file_prefix: Optional[str] = None diff --git a/python/cog/server/eventtypes.py b/python/cog/server/eventtypes.py index 93c7384142..da7925ea17 100644 --- a/python/cog/server/eventtypes.py +++ b/python/cog/server/eventtypes.py @@ -13,6 +13,7 @@ class Cancel: @define class PredictionInput: payload: Dict[str, Any] + context: Dict[str, str] = {} @define diff --git a/python/cog/server/runner.py b/python/cog/server/runner.py index a911a705ef..41e59ce86b 100644 --- a/python/cog/server/runner.py +++ b/python/cog/server/runner.py @@ -113,7 +113,7 @@ def predict( payload = prediction.input.copy() sid = self._worker.subscribe(task.handle_event, tag=tag) - task.track(self._worker.predict(payload, tag=tag)) + task.track(self._worker.predict(payload, context=prediction.context, tag=tag)) task.add_done_callback(self._task_done_callback(tag, sid)) return task diff --git a/python/cog/server/scope.py b/python/cog/server/scope.py index 161213f9fe..54a1f8b014 100644 --- a/python/cog/server/scope.py +++ b/python/cog/server/scope.py @@ -1,7 +1,7 @@ import warnings from contextlib import contextmanager from contextvars import ContextVar -from typing import Any, Callable, Generator, Optional, Union +from typing import Any, Callable, Dict, Generator, Optional, Union from attrs import evolve, frozen @@ -11,6 +11,7 @@ @frozen class Scope: record_metric: Callable[[str, Union[float, int]], None] + context: Dict[str, str] = {} _tag: Optional[str] = None diff --git a/python/cog/server/worker.py b/python/cog/server/worker.py index adcd0658da..78de3edec7 100644 --- a/python/cog/server/worker.py +++ b/python/cog/server/worker.py @@ -134,7 +134,11 @@ def setup(self) -> "Future[Done]": return self._setup_result def predict( - self, payload: Dict[str, Any], tag: Optional[str] = None + self, + payload: Dict[str, Any], + tag: Optional[str] = None, + *, + context: Optional[Dict[str, str]] = None, ) -> "Future[Done]": # TODO: tag is Optional, but it's required when in concurrent mode and # basically unnecessary in sequential mode. Should we have a separate @@ -157,11 +161,17 @@ def predict( result = Future() self._predictions_in_flight[tag] = PredictionState(tag, payload, result) - self._prediction_start_pool.submit(self._start_prediction(tag, payload)) + self._prediction_start_pool.submit( + self._start_prediction(tag, payload, context=context) + ) return result def _start_prediction( - self, tag: Optional[str], payload: Dict[str, Any] + self, + tag: Optional[str], + payload: Dict[str, Any], + *, + context: Optional[Dict[str, str]], ) -> Callable[[], None]: def start_prediction() -> None: try: @@ -212,7 +222,7 @@ def start_prediction() -> None: # send the prediction to the child to start self._events.send( Envelope( - event=PredictionInput(payload=payload), + event=PredictionInput(payload=payload, context=context or {}), tag=tag, ) ) @@ -563,7 +573,13 @@ def _loop( elif isinstance(e.event, Shutdown): break elif isinstance(e.event, PredictionInput): - self._predict(e.tag, e.event.payload, predict, redirector) + self._predict( + e.tag, + e.event.payload, + context=e.event.context, + predict=predict, + redirector=redirector, + ) else: print(f"Got unexpected event: {e.event}", file=sys.stderr) @@ -597,7 +613,13 @@ async def _aloop( break elif isinstance(e.event, PredictionInput): tasks[e.tag] = tg.create_task( - self._apredict(e.tag, e.event.payload, predict, redirector) + self._apredict( + e.tag, + e.event.payload, + context=e.event.context, + predict=predict, + redirector=redirector, + ) ) else: print(f"Got unexpected event: {e.event}", file=sys.stderr) @@ -606,10 +628,14 @@ def _predict( self, tag: Optional[str], payload: Dict[str, Any], + *, + context: Dict[str, str], predict: Callable[..., Any], redirector: StreamRedirector, ) -> None: - with self._handle_predict_error(redirector, tag=tag): + with evolve_scope(context=context), self._handle_predict_error( + redirector, tag=tag + ): result = predict(**payload) if result: @@ -657,10 +683,14 @@ async def _apredict( self, tag: Optional[str], payload: Dict[str, Any], + *, + context: Dict[str, str], predict: Callable[..., Any], redirector: SimpleStreamRedirector, ) -> None: - with evolve_scope(tag=tag), self._handle_predict_error(redirector, tag=tag): + with evolve_scope(context=context, tag=tag), self._handle_predict_error( + redirector, tag=tag + ): future_result = predict(**payload) if future_result: diff --git a/python/tests/server/fixtures/with_context.py b/python/tests/server/fixtures/with_context.py new file mode 100644 index 0000000000..dd439a42f4 --- /dev/null +++ b/python/tests/server/fixtures/with_context.py @@ -0,0 +1,7 @@ +from cog import current_scope, Input + +class Predictor: + def predict(self, name: str = Input()): + prefix = current_scope().context["prefix"] + return f"{prefix} {name}!" + diff --git a/python/tests/server/fixtures/with_context_async.py b/python/tests/server/fixtures/with_context_async.py new file mode 100644 index 0000000000..eae74b7359 --- /dev/null +++ b/python/tests/server/fixtures/with_context_async.py @@ -0,0 +1,7 @@ +from cog import current_scope, Input + +class Predictor: + async def predict(self, name: str = Input()): + prefix = current_scope().context["prefix"] + return f"{prefix} {name}!" + diff --git a/python/tests/server/test_runner.py b/python/tests/server/test_runner.py index 5e27357213..a001dd9eb9 100644 --- a/python/tests/server/test_runner.py +++ b/python/tests/server/test_runner.py @@ -2,6 +2,7 @@ import uuid from concurrent.futures import Future from datetime import datetime +from typing import Any, Dict, Optional from unittest import mock import pytest @@ -43,6 +44,7 @@ def __init__(self): self._setup_future = None self._predict_futures = {} self.last_prediction_payload = None + self.last_prediction_context = None def subscribe(self, subscriber, tag=None): sid = uuid.uuid4() @@ -71,9 +73,16 @@ def run_setup(self, events): if isinstance(event, Done): self._setup_future.set_result(event) - def predict(self, payload, tag=None): + def predict( + self, + payload: Dict[str, Any], + tag: Optional[str] = None, + *, + context: Optional[Dict[str, str]] = None, + ): assert tag not in self._predict_futures or self._predict_futures[tag].done() self.last_prediction_payload = payload + self.last_prediction_context = context self._predict_futures[tag] = Future() print(f"setting {tag}, now {self._predict_futures}") return self._predict_futures[tag] diff --git a/python/tests/server/test_worker.py b/python/tests/server/test_worker.py index 2b31d93d2e..ca5ce910b9 100644 --- a/python/tests/server/test_worker.py +++ b/python/tests/server/test_worker.py @@ -692,6 +692,34 @@ def test_async_setup_uses_same_loop_as_predict(worker: Worker): assert result, "Expected worker to return True to assert same event loop" +@uses_worker("with_context") +def test_context(worker: Worker): + result = _process( + worker, + lambda: worker.predict({"name": "context"}, context={"prefix": "hello"}), + tag=None, + ) + assert result.done + assert not result.done.error + assert result.output == "hello context!" + + +@uses_worker("with_context_async", min_python=(3, 11), is_async=True) +def test_context_async(worker: Worker): + result = _process( + worker, + lambda: worker.predict( + {"name": "context"}, tag="t1", context={"prefix": "hello"} + ), + tag=None, + ) + + print("\n".join(result.stderr_lines)) + assert result.done + assert not result.done.error, result.done.error_detail + assert result.output == "hello context!" + + @frozen class SetupState: fut: "Future[Done]"