From 3ce6e506a7141212624176da11aaf8c73be0fc3e Mon Sep 17 00:00:00 2001 From: Will Sackfield Date: Fri, 16 May 2025 16:54:15 -0400 Subject: [PATCH 1/2] Call train on class object * If the user makes a training request call the train function on the predictor class --- python/cog/base_predictor.py | 11 +++-- python/cog/server/http.py | 6 ++- python/cog/server/runner.py | 11 +++-- python/cog/server/worker.py | 11 ++++- python/tests/server/conftest.py | 1 + python/tests/server/test_runner.py | 70 ++++++++++++++++++------------ 6 files changed, 73 insertions(+), 37 deletions(-) diff --git a/python/cog/base_predictor.py b/python/cog/base_predictor.py index 3cc5899603..6d63858e69 100644 --- a/python/cog/base_predictor.py +++ b/python/cog/base_predictor.py @@ -1,10 +1,9 @@ -from abc import ABC, abstractmethod from typing import Any, Optional from .types import Weights -class BasePredictor(ABC): +class BasePredictor: def setup( self, weights: Optional[Weights] = None, # pylint: disable=unused-argument @@ -14,8 +13,14 @@ def setup( """ return - @abstractmethod def predict(self, **kwargs: Any) -> Any: """ Run a single prediction on the model """ + raise NotImplementedError("predict has not been implemented by parent class.") + + def train(self, **kwargs: Any) -> Any: + """ + Run a single train on the model + """ + raise NotImplementedError("train has not been implemented by parent class.") diff --git a/python/cog/server/http.py b/python/cog/server/http.py index ee7fe2c0a5..51f5e972b8 100644 --- a/python/cog/server/http.py +++ b/python/cog/server/http.py @@ -167,6 +167,7 @@ async def start_shutdown() -> Any: worker = make_worker( predictor_ref=cog_config.get_predictor_ref(mode=mode), is_async=is_async, + is_train=False if mode == Mode.PREDICT else True, max_concurrency=cog_config.max_concurrency, ) runner = PredictionRunner(worker=worker, max_concurrency=cog_config.max_concurrency) @@ -238,6 +239,7 @@ async def train( request=request, response_type=TrainingResponse, respond_async=respond_async, + is_train=True, ) @app.put( @@ -286,6 +288,7 @@ async def train_idempotent( request=request, response_type=TrainingResponse, respond_async=respond_async, + is_train=True, ) @app.post("/trainings/{training_id}/cancel") @@ -420,6 +423,7 @@ async def _predict( request: Optional[PredictionRequest], response_type: Type[schema.PredictionResponse], respond_async: bool = False, + is_train: bool = False, ) -> Response: # [compat] If no body is supplied, assume that this model can be run # with empty input. This will throw a ValidationError if that's not @@ -439,7 +443,7 @@ async def _predict( task_kwargs["upload_url"] = upload_url try: - predict_task = runner.predict(request, task_kwargs=task_kwargs) + predict_task = runner.predict(request, is_train, task_kwargs=task_kwargs) except RunnerBusyError: return JSONResponse( {"detail": "Already running a prediction"}, status_code=409 diff --git a/python/cog/server/runner.py b/python/cog/server/runner.py index 41e59ce86b..e7055abb27 100644 --- a/python/cog/server/runner.py +++ b/python/cog/server/runner.py @@ -87,6 +87,7 @@ def setup(self) -> "SetupTask": def predict( self, prediction: schema.PredictionRequest, + is_train: bool, task_kwargs: Optional[Dict[str, Any]] = None, ) -> "PredictTask": self._raise_if_busy() @@ -97,7 +98,7 @@ def predict( if tag is None: tag = uuid.uuid4().hex - task = PredictTask(prediction, **task_kwargs) + task = PredictTask(prediction, is_train, **task_kwargs) with self._predict_tasks_lock: self._predict_tasks[tag] = task @@ -281,11 +282,13 @@ class PredictTask(Task[schema.PredictionResponse]): def __init__( self, prediction_request: schema.PredictionRequest, + is_train: bool, upload_url: Optional[str] = None, ) -> None: + self._is_train = is_train self._log = log.bind(prediction_id=prediction_request.id) - self._log.info("starting prediction") + self._log.info("starting " + ("prediction" if not is_train else "train")) self._fut: "Optional[Future[Done]]" = None @@ -324,7 +327,7 @@ def result(self) -> schema.PredictionResponse: return self._p def track(self, fut: "Future[Done]") -> None: - self._log.info("started prediction") + self._log.info("started " + ("prediction" if not self._is_train else "train")) # HACK: don't send an initial webhook if we're trying to optimize for # latency (this guarantees that the first output webhook won't be @@ -393,7 +396,7 @@ def set_metric(self, key: str, value: Union[float, int]) -> None: self._p.metrics[key] = value def succeeded(self) -> None: - self._log.info("prediction succeeded") + self._log.info(("prediction" if not self._is_train else "train") + " succeeded") self._p.status = schema.Status.SUCCEEDED self._set_completed_at() # These have been set already: this is to convince the typechecker of diff --git a/python/cog/server/worker.py b/python/cog/server/worker.py index 78de3edec7..963dd83a9d 100644 --- a/python/cog/server/worker.py +++ b/python/cog/server/worker.py @@ -33,6 +33,7 @@ from ..predictor import ( extract_setup_weights, get_predict, + get_train, has_setup_weights, load_predictor_from_ref, ) @@ -399,6 +400,7 @@ def __init__( predictor_ref: str, *, is_async: bool, + is_train: bool, events: Connection, max_concurrency: int = 1, tee_output: bool = True, @@ -415,6 +417,7 @@ def __init__( # for synchronous predictors only! async predictors use current_scope()._tag instead self._sync_tag: Optional[str] = None self._has_async_predictor = is_async + self._is_train = is_train super().__init__() @@ -451,7 +454,11 @@ def run(self) -> None: if not self._validate_predictor(redirector): return - predict = get_predict(self._predictor) + predict = ( + get_predict(self._predictor) + if not self._is_train + else get_train(self._predictor) + ) if self._has_async_predictor: assert isinstance(redirector, SimpleStreamRedirector) @@ -853,6 +860,7 @@ def make_worker( predictor_ref: str, *, is_async: bool, + is_train: bool, tee_output: bool = True, max_concurrency: int = 1, ) -> Worker: @@ -860,6 +868,7 @@ def make_worker( child = _ChildWorker( predictor_ref, is_async=is_async, + is_train=is_train, events=child_conn, tee_output=tee_output, max_concurrency=max_concurrency, diff --git a/python/tests/server/conftest.py b/python/tests/server/conftest.py index 4d5dfd8bb7..a1972034c2 100644 --- a/python/tests/server/conftest.py +++ b/python/tests/server/conftest.py @@ -185,6 +185,7 @@ def worker(request): w = make_worker( predictor_ref=ref, is_async=request.param.is_async, + is_train=False, tee_output=False, max_concurrency=request.param.max_concurrency, ) diff --git a/python/tests/server/test_runner.py b/python/tests/server/test_runner.py index a001dd9eb9..b95a59d32e 100644 --- a/python/tests/server/test_runner.py +++ b/python/tests/server/test_runner.py @@ -155,7 +155,7 @@ def test_prediction_runner_predict_success(): r.setup() w.run_setup([Done()]) - task = r.predict(PredictionRequest(input={"text": "giraffes"})) + task = r.predict(PredictionRequest(input={"text": "giraffes"}), is_train=False) assert w.last_prediction_payload == {"text": "giraffes"} assert task.result.input == {"text": "giraffes"} assert task.result.status == Status.PROCESSING @@ -174,7 +174,7 @@ def test_prediction_runner_predict_failure(): r.setup() w.run_setup([Done()]) - task = r.predict(PredictionRequest(input={"text": "giraffes"})) + task = r.predict(PredictionRequest(input={"text": "giraffes"}), is_train=False) assert w.last_prediction_payload == {"text": "giraffes"} assert task.result.input == {"text": "giraffes"} assert task.result.status == Status.PROCESSING @@ -191,7 +191,7 @@ def test_prediction_runner_predict_exception(): r.setup() w.run_setup([Done()]) - task = r.predict(PredictionRequest(input={"text": "giraffes"})) + task = r.predict(PredictionRequest(input={"text": "giraffes"}), is_train=False) assert w.last_prediction_payload == {"text": "giraffes"} assert task.result.input == {"text": "giraffes"} assert task.result.status == Status.PROCESSING @@ -216,7 +216,7 @@ def test_prediction_runner_predict_before_setup(): r = PredictionRunner(worker=w) with pytest.raises(RunnerBusyError): - r.predict(PredictionRequest(input={"text": "giraffes"})) + r.predict(PredictionRequest(input={"text": "giraffes"}), is_train=False) def test_prediction_runner_predict_before_setup_completes(): @@ -226,7 +226,7 @@ def test_prediction_runner_predict_before_setup_completes(): r.setup() with pytest.raises(RunnerBusyError): - r.predict(PredictionRequest(input={"text": "giraffes"})) + r.predict(PredictionRequest(input={"text": "giraffes"}), is_train=False) def test_prediction_runner_predict_before_predict_completes(): @@ -236,10 +236,10 @@ def test_prediction_runner_predict_before_predict_completes(): r.setup() w.run_setup([Done()]) - r.predict(PredictionRequest(input={"text": "giraffes"})) + r.predict(PredictionRequest(input={"text": "giraffes"}), is_train=False) with pytest.raises(RunnerBusyError): - r.predict(PredictionRequest(input={"text": "giraffes"})) + r.predict(PredictionRequest(input={"text": "giraffes"}), is_train=False) def test_prediction_runner_predict_after_predict_completes(): @@ -249,10 +249,10 @@ def test_prediction_runner_predict_after_predict_completes(): r.setup() w.run_setup([Done()]) - r.predict(PredictionRequest(id="p-1", input={"text": "giraffes"})) + r.predict(PredictionRequest(id="p-1", input={"text": "giraffes"}), is_train=False) w.run_predict([Done()], id="p-1") - r.predict(PredictionRequest(id="p-2", input={"text": "elephants"})) + r.predict(PredictionRequest(id="p-2", input={"text": "elephants"}), is_train=False) w.run_predict([Done()], id="p-2") assert w.last_prediction_payload == {"text": "elephants"} @@ -270,7 +270,7 @@ def test_prediction_runner_is_busy(): w.run_setup([Done()]) assert not r.is_busy() - r.predict(PredictionRequest(input={"text": "elephants"})) + r.predict(PredictionRequest(input={"text": "elephants"}), is_train=False) assert r.is_busy() w.run_predict([Done()]) @@ -289,13 +289,13 @@ def test_prediction_runner_is_busy_concurrency(): w.run_setup([Done()]) assert not r.is_busy() - r.predict(PredictionRequest(id="1", input={"text": "elephants"})) + r.predict(PredictionRequest(id="1", input={"text": "elephants"}), is_train=False) assert not r.is_busy() - r.predict(PredictionRequest(id="2", input={"text": "elephants"})) + r.predict(PredictionRequest(id="2", input={"text": "elephants"}), is_train=False) assert not r.is_busy() - r.predict(PredictionRequest(id="3", input={"text": "elephants"})) + r.predict(PredictionRequest(id="3", input={"text": "elephants"}), is_train=False) assert r.is_busy() w.run_predict([Done()], id="1") @@ -309,7 +309,9 @@ def test_prediction_runner_predict_cancelation(): r.setup() w.run_setup([Done()]) - task = r.predict(PredictionRequest(id="abcd1234", input={"text": "giraffes"})) + task = r.predict( + PredictionRequest(id="abcd1234", input={"text": "giraffes"}), is_train=False + ) with pytest.raises(ValueError): r.cancel(None) @@ -332,10 +334,14 @@ def test_prediction_runner_predict_cancelation_multiple_predictions(): r.setup() w.run_setup([Done()]) - task1 = r.predict(PredictionRequest(id="abcd1234", input={"text": "giraffes"})) + task1 = r.predict( + PredictionRequest(id="abcd1234", input={"text": "giraffes"}), is_train=False + ) w.run_predict([Done()]) - task2 = r.predict(PredictionRequest(id="defg6789", input={"text": "elephants"})) + task2 = r.predict( + PredictionRequest(id="defg6789", input={"text": "elephants"}), is_train=False + ) with pytest.raises(UnknownPredictionError): r.cancel("abcd1234") @@ -351,9 +357,13 @@ def test_prediction_runner_predict_cancelation_concurrent_predictions(): r.setup() w.run_setup([Done()]) - task1 = r.predict(PredictionRequest(id="abcd1234", input={"text": "giraffes"})) + task1 = r.predict( + PredictionRequest(id="abcd1234", input={"text": "giraffes"}), is_train=False + ) - task2 = r.predict(PredictionRequest(id="defg6789", input={"text": "elephants"})) + task2 = r.predict( + PredictionRequest(id="defg6789", input={"text": "elephants"}), is_train=False + ) r.cancel("abcd1234") w.run_predict([Done()], id="defg6789") @@ -362,7 +372,9 @@ def test_prediction_runner_predict_cancelation_concurrent_predictions(): def test_prediction_runner_setup_e2e(): - w = make_worker(predictor_ref=_fixture_path("sleep"), is_async=False) + w = make_worker( + predictor_ref=_fixture_path("sleep"), is_async=False, is_train=False + ) r = PredictionRunner(worker=w) try: @@ -378,12 +390,14 @@ def test_prediction_runner_setup_e2e(): def test_prediction_runner_predict_e2e(): - w = make_worker(predictor_ref=_fixture_path("sleep"), is_async=False) + w = make_worker( + predictor_ref=_fixture_path("sleep"), is_async=False, is_train=False + ) r = PredictionRunner(worker=w) try: r.setup().wait(timeout=5) - task = r.predict(PredictionRequest(input={"sleep": 0.1})) + task = r.predict(PredictionRequest(input={"sleep": 0.1}), is_train=False) task.wait(timeout=1) finally: w.shutdown() @@ -456,7 +470,7 @@ def test_predict_task(): output_file_prefix=None, webhook=None, ) - t = PredictTask(p) + t = PredictTask(p, False) assert t.result.status == Status.PROCESSING assert t.result.output is None @@ -476,7 +490,7 @@ def test_predict_task_multi(): output_file_prefix=None, webhook=None, ) - t = PredictTask(p) + t = PredictTask(p, False) assert t.result.status == Status.PROCESSING assert t.result.output is None @@ -514,7 +528,7 @@ def test_predict_task_webhook_sender(): output_file_prefix=None, webhook="https://a.url.honest", ) - t = PredictTask(p) + t = PredictTask(p, False) t._webhook_sender = mock.Mock() t.track(Future()) @@ -552,7 +566,7 @@ def test_predict_task_webhook_sender_intermediate(): output_file_prefix=None, webhook="https://a.url.honest", ) - t = PredictTask(p) + t = PredictTask(p, False) t._webhook_sender = mock.Mock() t.track(Future()) @@ -574,7 +588,7 @@ def test_predict_task_webhook_sender_intermediate_multi(): output_file_prefix=None, webhook="https://a.url.honest", ) - t = PredictTask(p) + t = PredictTask(p, False) t._webhook_sender = mock.Mock() t.track(Future()) @@ -643,7 +657,7 @@ def test_predict_task_file_uploads(): output_file_prefix=None, webhook=None, ) - t = PredictTask(p, upload_url="https://a.url.honest") + t = PredictTask(p, False, upload_url="https://a.url.honest") t._file_uploader = mock.Mock() # in reality this would be a Path object, but in this test we just care it @@ -665,7 +679,7 @@ def test_predict_task_file_uploads_multi(): output_file_prefix=None, webhook=None, ) - t = PredictTask(p, upload_url="https://a.url.honest") + t = PredictTask(p, False, upload_url="https://a.url.honest") t._file_uploader = mock.Mock() t._file_uploader.return_value = [] From f938a42c148247e3560ed10a9e1301bc7aa3d3cf Mon Sep 17 00:00:00 2001 From: Will Sackfield Date: Mon, 19 May 2025 14:02:39 -0400 Subject: [PATCH 2/2] Fix integration test --- test-integration/test_integration/test_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test-integration/test_integration/test_train.py b/test-integration/test_integration/test_train.py index c7d69968da..6b87d7a7c1 100644 --- a/test-integration/test_integration/test_train.py +++ b/test-integration/test_integration/test_train.py @@ -48,4 +48,4 @@ def test_training_setup_project(tmpdir_factory, cog_binary): assert result.returncode == 0 assert "Trainer is setting up." in str(result.stderr) with open(out_dir / "weights", "r", encoding="utf8") as f: - assert f.read() == "hello predict world" + assert f.read() == "hello train world"