diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 14dde2c4bd..d79590069e 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -121,7 +121,7 @@ jobs: name: "Test integration" needs: build-python runs-on: ubuntu-latest-16-cores - timeout-minutes: 20 + timeout-minutes: 25 steps: - uses: actions/checkout@v4 with: diff --git a/pkg/cli/predict.go b/pkg/cli/predict.go index 8523bc01a5..9af2ebdc5f 100644 --- a/pkg/cli/predict.go +++ b/pkg/cli/predict.go @@ -206,7 +206,12 @@ func isURI(ref *openapi3.Schema) bool { } func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, outputPath string, isTrain bool) error { - console.Info("Running prediction...") + if isTrain { + console.Info("Running training...") + } else { + console.Info("Running prediction...") + } + schema, err := predictor.GetSchema() if err != nil { return err diff --git a/test-integration/test_integration/fixtures/training-setup-project/cog.yaml b/test-integration/test_integration/fixtures/training-setup-project/cog.yaml new file mode 100644 index 0000000000..53db34fbad --- /dev/null +++ b/test-integration/test_integration/fixtures/training-setup-project/cog.yaml @@ -0,0 +1,4 @@ +build: + python_version: "3.12" +train: "train.py:Trainer" +predict: "predict.py:Predictor" diff --git a/test-integration/test_integration/fixtures/training-setup-project/predict.py b/test-integration/test_integration/fixtures/training-setup-project/predict.py new file mode 100644 index 0000000000..44f6992b01 --- /dev/null +++ b/test-integration/test_integration/fixtures/training-setup-project/predict.py @@ -0,0 +1,6 @@ +from cog import BasePredictor + + +class Predictor(BasePredictor): + def predict(self, s: str) -> str: + return "hello " + s diff --git a/test-integration/test_integration/fixtures/training-setup-project/train.py b/test-integration/test_integration/fixtures/training-setup-project/train.py new file mode 100644 index 0000000000..a3712269fb --- /dev/null +++ b/test-integration/test_integration/fixtures/training-setup-project/train.py @@ -0,0 +1,14 @@ +from cog import BasePredictor + + +class Trainer(BasePredictor): + def setup(self) -> None: + print("Trainer is setting up.") + + def train(self, s: str) -> str: + print("Trainer.train called.") + return "hello train " + s + + def predict(self, s: str) -> str: + print("Trainer.predict called.") + return "hello predict " + s diff --git a/test-integration/test_integration/test_train.py b/test-integration/test_integration/test_train.py index 22b61399b1..c7d69968da 100644 --- a/test-integration/test_integration/test_train.py +++ b/test-integration/test_integration/test_train.py @@ -32,3 +32,20 @@ def test_train_pydantic2(tmpdir_factory, cog_binary): capture_output=True, ) assert result.returncode == 0 + + +def test_training_setup_project(tmpdir_factory, cog_binary): + project_dir = Path(__file__).parent / "fixtures/training-setup-project" + out_dir = pathlib.Path(tmpdir_factory.mktemp("project")) + shutil.copytree(project_dir, out_dir, dirs_exist_ok=True) + result = subprocess.run( + [cog_binary, "train", "--debug", "-i", "s=world"], + cwd=out_dir, + check=False, + capture_output=True, + text=True, + ) + assert result.returncode == 0 + assert "Trainer is setting up." in str(result.stderr) + with open(out_dir / "weights", "r", encoding="utf8") as f: + assert f.read() == "hello predict world"