11 changes: 8 additions & 3 deletions python/cog/base_predictor.py
@@ -1,10 +1,9 @@
-from abc import ABC, abstractmethod
 from typing import Any, Optional
 
 from .types import Weights
 
 
-class BasePredictor(ABC):
+class BasePredictor:
     def setup(
         self,
         weights: Optional[Weights] = None,  # pylint: disable=unused-argument
@@ -14,8 +13,14 @@ def setup(
         """
         return
 
-    @abstractmethod
     def predict(self, **kwargs: Any) -> Any:
         """
         Run a single prediction on the model
         """
+        raise NotImplementedError("predict has not been implemented by parent class.")
+
+    def train(self, **kwargs: Any) -> Any:
+        """
+        Run a single train on the model
+        """
+        raise NotImplementedError("train has not been implemented by parent class.")
6 changes: 5 additions & 1 deletion python/cog/server/http.py
@@ -167,6 +167,7 @@ async def start_shutdown() -> Any:
     worker = make_worker(
         predictor_ref=cog_config.get_predictor_ref(mode=mode),
         is_async=is_async,
+        is_train=False if mode == Mode.PREDICT else True,
         max_concurrency=cog_config.max_concurrency,
     )
     runner = PredictionRunner(worker=worker, max_concurrency=cog_config.max_concurrency)
@@ -238,6 +239,7 @@ async def train(
             request=request,
             response_type=TrainingResponse,
             respond_async=respond_async,
+            is_train=True,
         )
 
     @app.put(
@@ -286,6 +288,7 @@ async def train_idempotent(
             request=request,
             response_type=TrainingResponse,
             respond_async=respond_async,
+            is_train=True,
         )
 
     @app.post("/trainings/{training_id}/cancel")
@@ -420,6 +423,7 @@ async def _predict(
         request: Optional[PredictionRequest],
         response_type: Type[schema.PredictionResponse],
         respond_async: bool = False,
+        is_train: bool = False,
     ) -> Response:
         # [compat] If no body is supplied, assume that this model can be run
         # with empty input. This will throw a ValidationError if that's not
@@ -439,7 +443,7 @@ async def _predict(
             task_kwargs["upload_url"] = upload_url
 
         try:
-            predict_task = runner.predict(request, task_kwargs=task_kwargs)
+            predict_task = runner.predict(request, is_train, task_kwargs=task_kwargs)
         except RunnerBusyError:
             return JSONResponse(
                 {"detail": "Already running a prediction"}, status_code=409
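Both training endpoints reuse the private _predict helper, so one keyword argument distinguishes the two request families while validation, idempotency, and async responses stay on a single code path. (The expression False if mode == Mode.PREDICT else True is a spelled-out mode != Mode.PREDICT.) A reduced, self-contained sketch of that dispatch shape, with illustrative names:

def _run_request(payload: dict, is_train: bool = False) -> str:
    # Stand-in for _predict(): one shared path, labeled by a boolean.
    return ("train" if is_train else "predict") + " <- " + repr(payload)


def predictions_endpoint(payload: dict) -> str:
    return _run_request(payload)  # default keeps the old predict behavior


def trainings_endpoint(payload: dict) -> str:
    return _run_request(payload, is_train=True)  # mirrors train()/train_idempotent()


print(trainings_endpoint({"n": 1}))  # train <- {'n': 1}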
11 changes: 7 additions & 4 deletions python/cog/server/runner.py
@@ -87,6 +87,7 @@ def setup(self) -> "SetupTask":
     def predict(
         self,
         prediction: schema.PredictionRequest,
+        is_train: bool,
         task_kwargs: Optional[Dict[str, Any]] = None,
     ) -> "PredictTask":
         self._raise_if_busy()
@@ -97,7 +98,7 @@ def predict(
         if tag is None:
             tag = uuid.uuid4().hex
 
-        task = PredictTask(prediction, **task_kwargs)
+        task = PredictTask(prediction, is_train, **task_kwargs)
 
         with self._predict_tasks_lock:
             self._predict_tasks[tag] = task
@@ -281,11 +282,13 @@ class PredictTask(Task[schema.PredictionResponse]):
     def __init__(
         self,
         prediction_request: schema.PredictionRequest,
+        is_train: bool,
         upload_url: Optional[str] = None,
     ) -> None:
+        self._is_train = is_train
         self._log = log.bind(prediction_id=prediction_request.id)
 
-        self._log.info("starting prediction")
+        self._log.info("starting " + ("prediction" if not is_train else "train"))
 
         self._fut: "Optional[Future[Done]]" = None
 
@@ -324,7 +327,7 @@ def result(self) -> schema.PredictionResponse:
         return self._p
 
     def track(self, fut: "Future[Done]") -> None:
-        self._log.info("started prediction")
+        self._log.info("started " + ("prediction" if not self._is_train else "train"))
 
         # HACK: don't send an initial webhook if we're trying to optimize for
         # latency (this guarantees that the first output webhook won't be
@@ -393,7 +396,7 @@ def set_metric(self, key: str, value: Union[float, int]) -> None:
         self._p.metrics[key] = value
 
     def succeeded(self) -> None:
-        self._log.info("prediction succeeded")
+        self._log.info(("prediction" if not self._is_train else "train") + " succeeded")
         self._p.status = schema.Status.SUCCEEDED
         self._set_completed_at()
         # These have been set already: this is to convince the typechecker of
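PredictTask's lifecycle is untouched; is_train exists only to relabel the log lines, and the same ternary is inlined at three call sites. A sketch of binding the noun once instead (a possible deduplication, not what this diff does):

class TaskLogLabel:
    # Illustrative only: the diff repeats the ternary at each log call.
    def __init__(self, is_train: bool) -> None:
        self.noun = "train" if is_train else "prediction"

    def starting(self) -> str:
        return "starting " + self.noun

    def started(self) -> str:
        return "started " + self.noun

    def succeeded(self) -> str:
        return self.noun + " succeeded"


label = TaskLogLabel(is_train=True)
assert label.starting() == "starting train"
assert label.succeeded() == "train succeeded"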
11 changes: 10 additions & 1 deletion python/cog/server/worker.py
@@ -33,6 +33,7 @@
 from ..predictor import (
     extract_setup_weights,
     get_predict,
+    get_train,
     has_setup_weights,
     load_predictor_from_ref,
 )
@@ -399,6 +400,7 @@ def __init__(
         predictor_ref: str,
         *,
         is_async: bool,
+        is_train: bool,
         events: Connection,
         max_concurrency: int = 1,
         tee_output: bool = True,
@@ -415,6 +417,7 @@ def __init__(
         # for synchronous predictors only! async predictors use current_scope()._tag instead
         self._sync_tag: Optional[str] = None
         self._has_async_predictor = is_async
+        self._is_train = is_train
 
         super().__init__()
 
@@ -451,7 +454,11 @@ def run(self) -> None:
         if not self._validate_predictor(redirector):
             return
 
-        predict = get_predict(self._predictor)
+        predict = (
+            get_predict(self._predictor)
+            if not self._is_train
+            else get_train(self._predictor)
+        )
 
         if self._has_async_predictor:
             assert isinstance(redirector, SimpleStreamRedirector)
@@ -853,13 +860,15 @@ def make_worker(
     predictor_ref: str,
     *,
     is_async: bool,
+    is_train: bool,
     tee_output: bool = True,
     max_concurrency: int = 1,
 ) -> Worker:
     parent_conn, child_conn = _spawn.Pipe()
     child = _ChildWorker(
         predictor_ref,
         is_async=is_async,
+        is_train=is_train,
         events=child_conn,
         tee_output=tee_output,
         max_concurrency=max_concurrency,
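The child worker resolves the user hook once, at process startup: a worker is either a predict worker or a train worker for its whole lifetime, which is why make_worker now requires the flag and http.py derives it from the server's mode. A reduced sketch of that one-time selection (PredictorStub and resolve_hook are illustrative stand-ins for the real predictor and _ChildWorker.run()):

from typing import Any, Callable


class PredictorStub:
    def predict(self, **kwargs: Any) -> str:
        return "ran predict"

    def train(self, **kwargs: Any) -> str:
        return "ran train"


def resolve_hook(predictor: PredictorStub, is_train: bool) -> Callable[..., Any]:
    # The choice happens once; every request routed to this worker
    # afterwards uses the same hook.
    return predictor.train if is_train else predictor.predict


hook = resolve_hook(PredictorStub(), is_train=True)
assert hook() == "ran train"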
1 change: 1 addition & 0 deletions python/tests/server/conftest.py
@@ -185,6 +185,7 @@ def worker(request):
     w = make_worker(
         predictor_ref=ref,
         is_async=request.param.is_async,
+        is_train=False,
         tee_output=False,
         max_concurrency=request.param.max_concurrency,
     )