Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions pkg/config/load.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,15 @@ func GetProjectDir(configFilename string) (string, error) {
// Loads and instantiates a Config object
// customDir can be specified to override the default - current working directory
func GetConfig(configFilename string) (*Config, string, error) {
config, rootDir, err := GetRawConfig(configFilename)
if err != nil {
return nil, "", err
}
err = config.ValidateAndComplete(rootDir)
return config, rootDir, err
}

func GetRawConfig(configFilename string) (*Config, string, error) {
// Find the root project directory
rootDir, err := GetProjectDir(configFilename)

Expand All @@ -38,8 +47,6 @@ func GetConfig(configFilename string) (*Config, string, error) {
return nil, "", err
}

err = config.ValidateAndComplete(rootDir)

return config, rootDir, err
}

Expand Down
8 changes: 2 additions & 6 deletions pkg/migrate/migrator_v1_v1fast.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func NewMigratorV1ToV1Fast(interactive bool) *MigratorV1ToV1Fast {
}

func (g *MigratorV1ToV1Fast) Migrate(ctx context.Context, configFilename string) error {
cfg, projectDir, err := config.GetConfig(configFilename)
cfg, projectDir, err := config.GetRawConfig(configFilename)
if err != nil {
return err
}
Expand Down Expand Up @@ -175,10 +175,6 @@ func (g *MigratorV1ToV1Fast) flushConfig(cfg *config.Config, dir string, configF
cfg.Build = config.DefaultConfig().Build
}
cfg.Build.Fast = true
err := cfg.ValidateAndComplete("")
if err != nil {
return err
}
data, err := yaml.Marshal(cfg)
if err != nil {
return err
Expand Down Expand Up @@ -280,7 +276,7 @@ func (g *MigratorV1ToV1Fast) runPythonScript(ctx context.Context, file *zip.File
return err
}
newContent := out.String()
if newContent == "" {
if strings.TrimSpace(newContent) == "" {
return nil
}
accept := true
Expand Down
64 changes: 64 additions & 0 deletions pkg/migrate/migrator_v1_v1fast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,67 @@ class Predictor(BasePredictor):
err = file.Close()
require.NoError(t, err)
}

func TestMigrateGPU(t *testing.T) {
// Set our new working directory to a temp directory
originalDir, err := os.Getwd()
require.NoError(t, err)
defer func() {
err := os.Chdir(originalDir)
require.NoError(t, err)
}()
dir := t.TempDir()
err = os.Chdir(dir)
require.NoError(t, err)

// Write our test configs/code
configFilepath := filepath.Join(dir, "cog.yaml")
file, err := os.Create(configFilepath)
require.NoError(t, err)
_, err = file.WriteString(`build:
python_version: "3.11"
fast: true
python_packages:
- "pillow"
run:
- command: curl -o /usr/local/bin/pget -L \"https://github.com/replicate/pget/releases/latest/download/pget_$(uname -s)_$(uname -m)\" && chmod +x /usr/local/bin/pget
gpu: true
predict: "predict.py:Predictor"
`)
require.NoError(t, err)
err = file.Close()
require.NoError(t, err)
pythonFilepath := filepath.Join(dir, "predict.py")
file, err = os.Create(pythonFilepath)
require.NoError(t, err)
_, err = file.WriteString(`from cog import BasePredictor, Input


class Predictor(BasePredictor):
def predict(self, s: str = Input(description="My Input Description", default=None)) -> str:
return "hello " + s
`)
require.NoError(t, err)
err = file.Close()
require.NoError(t, err)

// Perform the migration
migrator := NewMigratorV1ToV1Fast(false)
err = migrator.Migrate(t.Context(), "cog.yaml")
require.NoError(t, err)

// Check config output
file, err = os.Open(configFilepath)
require.NoError(t, err)
content, err := io.ReadAll(file)
require.NoError(t, err)
require.Equal(t, `build:
gpu: true
python_version: "3.11"
python_requirements: requirements.txt
fast: true
predict: predict.py:Predictor
`, string(content))
err = file.Close()
require.NoError(t, err)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
build:
python_version: "3.11"
fast: true
python_packages:
- "pillow"
run:
- command: curl -o /usr/local/bin/pget -L \"https://github.com/replicate/pget/releases/latest/download/pget_$(uname -s)_$(uname -m)\" && chmod +x /usr/local/bin/pget
gpu: true
predict: "predict.py:Predictor"
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from cog import BasePredictor, Input


class Predictor(BasePredictor):
def predict(self, s: str = Input(description="My Input Description", default=None)) -> str:
return "hello " + s
27 changes: 27 additions & 0 deletions test-integration/test_integration/test_migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,30 @@ def predict(self, s: Optional[str] = Input(description="My Input Description", d
"""
with open(os.path.join(out_dir, "requirements.txt"), encoding="utf8") as handle:
assert handle.read(), "pillow"


def test_migrate_gpu(tmpdir_factory):
project_dir = Path(__file__).parent / "fixtures/migration-gpu-project"
out_dir = pathlib.Path(tmpdir_factory.mktemp("project"))
shutil.copytree(project_dir, out_dir, dirs_exist_ok=True)
result = subprocess.run(
[
"cog",
"migrate",
"--y",
],
cwd=out_dir,
check=True,
capture_output=True,
text=True,
timeout=DEFAULT_TIMEOUT,
)
assert result.returncode == 0
with open(os.path.join(out_dir, "cog.yaml"), encoding="utf8") as handle:
assert handle.read(), """build:
gpu: true
python_version: "3.11"
python_requirements: requirements.txt
fast: true
predict: predict.py:Predictor
"""