-
Notifications
You must be signed in to change notification settings - Fork 689
Expand file tree
/
Copy path migrator_v1_v1fast_test.go
More file actions
99 lines (86 loc) · 2.56 KB
/
migrator_v1_v1fast_test.go
File metadata and controls
99 lines (86 loc) · 2.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
package migrate
import (
"io"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
"github.com/replicate/cog/pkg/requirements"
)
// TestMigrate runs the v1 -> v1-fast migrator over a temporary project that
// contains a cog.yaml (with build.fast, python_packages and a run step) and a
// predict.py, then checks the rewritten cog.yaml, the rewritten predict.py and
// the generated requirements file.
func TestMigrate(t *testing.T) {
	// Work inside a fresh temp directory. t.Chdir (Go 1.24, the same release
	// that added the t.Context used below) restores the original working
	// directory during test cleanup, replacing a manual Getwd/defer Chdir.
	dir := t.TempDir()
	t.Chdir(dir)

	// writeFile creates or truncates path with content. os.WriteFile closes the
	// file on every path and uses the same 0o666 mode (before umask) as
	// os.Create.
	writeFile := func(path, content string) {
		t.Helper()
		require.NoError(t, os.WriteFile(path, []byte(content), 0o666))
	}

	// readFile returns the contents of path. The handle is closed via defer so
	// that an aborted assertion does not leave it open; an open handle can
	// block t.TempDir's cleanup on Windows.
	readFile := func(path string) string {
		t.Helper()
		f, err := os.Open(path)
		require.NoError(t, err)
		defer func() {
			require.NoError(t, f.Close())
		}()
		content, err := io.ReadAll(f)
		require.NoError(t, err)
		return string(content)
	}

	// Write our test configs/code.
	// NOTE(review): the fixture literals in this copy are flush-left, but
	// cog.yaml nesting and Python method bodies depend on indentation —
	// confirm they match the intended fixtures.
	configFilepath := filepath.Join(dir, "cog.yaml")
	writeFile(configFilepath, `build:
python_version: "3.11"
fast: true
python_packages:
- "pillow"
run:
- command: curl -o /usr/local/bin/pget -L \"https://github.com/replicate/pget/releases/latest/download/pget_$(uname -s)_$(uname -m)\" && chmod +x /usr/local/bin/pget
predict: "predict.py:Predictor"
`)

	pythonFilepath := filepath.Join(dir, "predict.py")
	writeFile(pythonFilepath, `from cog import BasePredictor, Input
class Predictor(BasePredictor):
def predict(self, s: str = Input(description="My Input Description", default=None)) -> str:
return "hello " + s
`)

	// Perform the migration. Migrate is given a relative path, which
	// presumably resolves against the working directory set by t.Chdir.
	migrator := NewMigratorV1ToV1Fast(false)
	require.NoError(t, migrator.Migrate(t.Context(), "cog.yaml"))

	// Check config output: python_packages is replaced by a
	// python_requirements entry and the run section no longer appears.
	require.Equal(t, `build:
python_version: "3.11"
python_requirements: requirements.txt
fast: true
predict: predict.py:Predictor
`, readFile(configFilepath))

	// Check python code output: the default=None input gains an Optional[...]
	// annotation together with the matching typing import.
	require.Equal(t, `from typing import Optional
from cog import BasePredictor, Input
class Predictor(BasePredictor):
def predict(self, s: Optional[str] = Input(description="My Input Description", default=None)) -> str:
return "hello " + s
`, readFile(pythonFilepath))

	// Check requirements.txt: it holds the former python_packages entry.
	require.Equal(t, `pillow`, readFile(filepath.Join(dir, requirements.RequirementsFile)))
}