diff --git a/pkg/config/load.go b/pkg/config/load.go index 3673e051bd..a308ac10cc 100644 --- a/pkg/config/load.go +++ b/pkg/config/load.go @@ -24,6 +24,15 @@ func GetProjectDir(configFilename string) (string, error) { // Loads and instantiates a Config object // customDir can be specified to override the default - current working directory func GetConfig(configFilename string) (*Config, string, error) { + config, rootDir, err := GetRawConfig(configFilename) + if err != nil { + return nil, "", err + } + err = config.ValidateAndComplete(rootDir) + return config, rootDir, err +} + +func GetRawConfig(configFilename string) (*Config, string, error) { // Find the root project directory rootDir, err := GetProjectDir(configFilename) @@ -38,8 +47,6 @@ func GetConfig(configFilename string) (*Config, string, error) { return nil, "", err } - err = config.ValidateAndComplete(rootDir) - return config, rootDir, err } diff --git a/pkg/migrate/migrator_v1_v1fast.go b/pkg/migrate/migrator_v1_v1fast.go index 6bae0b7976..ce01519c1d 100644 --- a/pkg/migrate/migrator_v1_v1fast.go +++ b/pkg/migrate/migrator_v1_v1fast.go @@ -44,7 +44,7 @@ func NewMigratorV1ToV1Fast(interactive bool) *MigratorV1ToV1Fast { } func (g *MigratorV1ToV1Fast) Migrate(ctx context.Context, configFilename string) error { - cfg, projectDir, err := config.GetConfig(configFilename) + cfg, projectDir, err := config.GetRawConfig(configFilename) if err != nil { return err } @@ -175,10 +175,6 @@ func (g *MigratorV1ToV1Fast) flushConfig(cfg *config.Config, dir string, configF cfg.Build = config.DefaultConfig().Build } cfg.Build.Fast = true - err := cfg.ValidateAndComplete("") - if err != nil { - return err - } data, err := yaml.Marshal(cfg) if err != nil { return err @@ -280,7 +276,7 @@ func (g *MigratorV1ToV1Fast) runPythonScript(ctx context.Context, file *zip.File return err } newContent := out.String() - if newContent == "" { + if strings.TrimSpace(newContent) == "" { return nil } accept := true diff --git a/pkg/migrate/migrator_v1_v1fast_test.go b/pkg/migrate/migrator_v1_v1fast_test.go index cd85623724..a45e17a887 100644 --- a/pkg/migrate/migrator_v1_v1fast_test.go +++ b/pkg/migrate/migrator_v1_v1fast_test.go @@ -97,3 +97,67 @@ class Predictor(BasePredictor): err = file.Close() require.NoError(t, err) } + +func TestMigrateGPU(t *testing.T) { + // Set our new working directory to a temp directory + originalDir, err := os.Getwd() + require.NoError(t, err) + defer func() { + err := os.Chdir(originalDir) + require.NoError(t, err) + }() + dir := t.TempDir() + err = os.Chdir(dir) + require.NoError(t, err) + + // Write our test configs/code + configFilepath := filepath.Join(dir, "cog.yaml") + file, err := os.Create(configFilepath) + require.NoError(t, err) + _, err = file.WriteString(`build: + python_version: "3.11" + fast: true + python_packages: + - "pillow" + run: + - command: curl -o /usr/local/bin/pget -L \"https://github.com/replicate/pget/releases/latest/download/pget_$(uname -s)_$(uname -m)\" && chmod +x /usr/local/bin/pget + gpu: true +predict: "predict.py:Predictor" +`) + require.NoError(t, err) + err = file.Close() + require.NoError(t, err) + pythonFilepath := filepath.Join(dir, "predict.py") + file, err = os.Create(pythonFilepath) + require.NoError(t, err) + _, err = file.WriteString(`from cog import BasePredictor, Input + + +class Predictor(BasePredictor): + def predict(self, s: str = Input(description="My Input Description", default=None)) -> str: + return "hello " + s +`) + require.NoError(t, err) + err = file.Close() + require.NoError(t, err) + + // Perform the migration + migrator := NewMigratorV1ToV1Fast(false) + err = migrator.Migrate(t.Context(), "cog.yaml") + require.NoError(t, err) + + // Check config output + file, err = os.Open(configFilepath) + require.NoError(t, err) + content, err := io.ReadAll(file) + require.NoError(t, err) + require.Equal(t, `build: + gpu: true + python_version: "3.11" + python_requirements: requirements.txt + fast: true +predict: predict.py:Predictor +`, string(content)) + err = file.Close() + require.NoError(t, err) +} diff --git a/test-integration/test_integration/fixtures/migration-gpu-project/cog.yaml b/test-integration/test_integration/fixtures/migration-gpu-project/cog.yaml new file mode 100644 index 0000000000..d1c89659c4 --- /dev/null +++ b/test-integration/test_integration/fixtures/migration-gpu-project/cog.yaml @@ -0,0 +1,9 @@ +build: + python_version: "3.11" + fast: true + python_packages: + - "pillow" + run: + - command: curl -o /usr/local/bin/pget -L \"https://github.com/replicate/pget/releases/latest/download/pget_$(uname -s)_$(uname -m)\" && chmod +x /usr/local/bin/pget + gpu: true +predict: "predict.py:Predictor" diff --git a/test-integration/test_integration/fixtures/migration-gpu-project/predict.py b/test-integration/test_integration/fixtures/migration-gpu-project/predict.py new file mode 100644 index 0000000000..b50e415c63 --- /dev/null +++ b/test-integration/test_integration/fixtures/migration-gpu-project/predict.py @@ -0,0 +1,6 @@ +from cog import BasePredictor, Input + + +class Predictor(BasePredictor): + def predict(self, s: str = Input(description="My Input Description", default=None)) -> str: + return "hello " + s diff --git a/test-integration/test_integration/test_migrate.py b/test-integration/test_integration/test_migrate.py index 236cb69773..d2e6b69527 100644 --- a/test-integration/test_integration/test_migrate.py +++ b/test-integration/test_integration/test_migrate.py @@ -42,3 +42,30 @@ def predict(self, s: Optional[str] = Input(description="My Input Description", d """ with open(os.path.join(out_dir, "requirements.txt"), encoding="utf8") as handle: assert handle.read(), "pillow" + + +def test_migrate_gpu(tmpdir_factory): + project_dir = Path(__file__).parent / "fixtures/migration-gpu-project" + out_dir = pathlib.Path(tmpdir_factory.mktemp("project")) + shutil.copytree(project_dir, out_dir, dirs_exist_ok=True) + result = subprocess.run( + [ + "cog", + "migrate", + "--y", + ], + cwd=out_dir, + check=True, + capture_output=True, + text=True, + timeout=DEFAULT_TIMEOUT, + ) + assert result.returncode == 0 + with open(os.path.join(out_dir, "cog.yaml"), encoding="utf8") as handle: + assert handle.read(), """build: + gpu: true + python_version: "3.11" + python_requirements: requirements.txt + fast: true +predict: predict.py:Predictor +"""