Skip to content

Commit b92a10d

Browse files
authored
Add default CUDA versions to torch in fast boots (#2324)
* If the user doesn’t select a CUDA version, select the latest one from the compatibility matrix
1 parent a5fe8c6 commit b92a10d

File tree

9 files changed

+1442
-18
lines changed

9 files changed

+1442
-18
lines changed

pkg/config/config.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,11 @@ func (c *Config) ValidateAndComplete(projectDir string) error {
305305

306306
// Load python_requirements into memory to simplify reading it multiple times
307307
if c.Build.PythonRequirements != "" {
308-
c.Build.pythonRequirementsContent, err = requirements.ReadRequirements(path.Join(projectDir, c.Build.PythonRequirements))
308+
requirementsFilePath := c.Build.PythonRequirements
309+
if !strings.HasPrefix(requirementsFilePath, "/") {
310+
requirementsFilePath = path.Join(projectDir, c.Build.PythonRequirements)
311+
}
312+
c.Build.pythonRequirementsContent, err = requirements.ReadRequirements(requirementsFilePath)
309313
if err != nil {
310314
errs = append(errs, fmt.Errorf("Failed to open python_requirements file: %w", err))
311315
}

pkg/config/config_test.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"encoding/json"
55
"os"
66
"path"
7+
"path/filepath"
78
"testing"
89

910
"github.com/hashicorp/go-version"
@@ -730,3 +731,22 @@ func TestConfigMarshal(t *testing.T) {
730731
predict: ""
731732
`, string(data))
732733
}
734+
735+
func TestAbsolutePathInPythonRequirements(t *testing.T) {
736+
dir := t.TempDir()
737+
requirementsFilePath := filepath.Join(dir, "requirements.txt")
738+
err := os.WriteFile(requirementsFilePath, []byte("torch==2.5.0"), 0o644)
739+
require.NoError(t, err)
740+
config := &Config{
741+
Build: &Build{
742+
GPU: true,
743+
PythonVersion: "3.8",
744+
PythonRequirements: requirementsFilePath,
745+
},
746+
}
747+
err = config.ValidateAndComplete(dir)
748+
require.NoError(t, err)
749+
torchVersion, ok := config.TorchVersion()
750+
require.Equal(t, torchVersion, "2.5.0")
751+
require.True(t, ok)
752+
}

pkg/dockerfile/env.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ import (
99
)
1010

1111
const CogletVersionEnvVarName = "R8_COGLET_VERSION"
12+
const MonobaseMatrixHostVarName = "R8_MONOBASE_MATRIX_HOST"
13+
const MonobaseMatrixSchemeVarName = "R8_MONOBASE_MATRIX_SCHEME"
1214

1315
func envLineFromConfig(c *config.Config) (string, error) {
1416
vars := c.ParsedEnvironment()
@@ -32,3 +34,19 @@ func CogletVersionFromEnvironment() string {
3234
}
3335
return host
3436
}
37+
38+
func MonobaseMatrixHostFromEnvironment() string {
39+
host := os.Getenv(MonobaseMatrixHostVarName)
40+
if host == "" {
41+
host = "raw.githubusercontent.com"
42+
}
43+
return host
44+
}
45+
46+
func MonobaseMatrixSchemeFromEnvironment() string {
47+
scheme := os.Getenv(MonobaseMatrixSchemeVarName)
48+
if scheme == "" {
49+
scheme = "https"
50+
}
51+
return scheme
52+
}

pkg/dockerfile/env_test.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,15 @@ func TestCogletVersionFromEnvironment(t *testing.T) {
1111
t.Setenv(CogletVersionEnvVarName, cogletVersion)
1212
require.Equal(t, CogletVersionFromEnvironment(), cogletVersion)
1313
}
14+
15+
func TestMonobaseMatrixHostFromEnvironment(t *testing.T) {
16+
const monobaseMatrixHost = "localhost"
17+
t.Setenv(MonobaseMatrixHostVarName, monobaseMatrixHost)
18+
require.Equal(t, MonobaseMatrixHostFromEnvironment(), monobaseMatrixHost)
19+
}
20+
21+
func TestMonobaseMatrixSchemeFromEnvironment(t *testing.T) {
22+
const monobaseMatrixScheme = "http"
23+
t.Setenv(MonobaseMatrixSchemeVarName, monobaseMatrixScheme)
24+
require.Equal(t, MonobaseMatrixSchemeFromEnvironment(), monobaseMatrixScheme)
25+
}

pkg/dockerfile/fast_generator.go

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,24 @@ func (g *FastGenerator) generateMonobase(lines []string, tmpDir string) ([]strin
232232
"ENV R8_COG_VERSION=" + CogletVersionFromEnvironment(),
233233
}...)
234234

235+
torchVersion, ok := g.Config.TorchVersion()
236+
if ok {
237+
if !CheckMajorMinorPatch(torchVersion) {
238+
return nil, fmt.Errorf("Torch version must be <major>.<minor>.<patch>: %s", strings.Join(g.matrix.TorchVersions, ", "))
239+
}
240+
envs = append(envs, []string{
241+
"ENV R8_TORCH_VERSION=" + torchVersion,
242+
}...)
243+
}
244+
245+
console.Infof("OK: %v", ok)
246+
console.Infof("Torch Version: %s", torchVersion)
247+
235248
if g.Config.Build.GPU {
236249
cudaVersion := g.Config.Build.CUDA
250+
if cudaVersion == "" && ok {
251+
cudaVersion = g.matrix.DefaultCUDAVersion(torchVersion)
252+
}
237253
cudnnVersion := g.Config.Build.CuDNN
238254
if cudnnVersion == "" {
239255
cudnnVersion = g.matrix.DefaultCudnnVersion()
@@ -261,16 +277,6 @@ func (g *FastGenerator) generateMonobase(lines []string, tmpDir string) ([]strin
261277
"ENV R8_PYTHON_VERSION=" + g.Config.Build.PythonVersion,
262278
}...)
263279

264-
torchVersion, ok := g.Config.TorchVersion()
265-
if ok {
266-
if !CheckMajorMinorPatch(torchVersion) {
267-
return nil, fmt.Errorf("Torch version must be <major>.<minor>.<patch>: %s", strings.Join(g.matrix.TorchVersions, ", "))
268-
}
269-
envs = append(envs, []string{
270-
"ENV R8_TORCH_VERSION=" + torchVersion,
271-
}...)
272-
}
273-
274280
if !g.matrix.IsSupported(g.Config.Build.PythonVersion, torchVersion, g.Config.Build.CUDA) {
275281
return nil, fmt.Errorf(
276282
"Unsupported version combination: Python=%s, Torch=%s, CUDA=%s",

pkg/dockerfile/fast_generator_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010

1111
"github.com/replicate/cog/pkg/config"
1212
"github.com/replicate/cog/pkg/docker/dockertest"
13+
"github.com/replicate/cog/pkg/util/console"
1314
)
1415

1516
func writeRequirements(t *testing.T, req string) string {
@@ -265,3 +266,45 @@ func TestValidateConfigWithBuildRunItems(t *testing.T) {
265266
err = generator.validateConfig()
266267
require.Error(t, err)
267268
}
269+
270+
func TestTorchVersionDefaultCUDA(t *testing.T) {
271+
dir := t.TempDir()
272+
build := config.Build{
273+
PythonVersion: "3.10",
274+
PythonRequirements: writeRequirements(t, "torch==2.5.0"),
275+
GPU: true,
276+
}
277+
config := config.Config{
278+
Build: &build,
279+
}
280+
err := config.ValidateAndComplete(dir)
281+
require.NoError(t, err)
282+
command := dockertest.NewMockCommand()
283+
284+
// Create matrix
285+
matrix := MonobaseMatrix{
286+
Id: 1,
287+
CudaVersions: []string{"12.4"},
288+
CudnnVersions: []string{"1.0"},
289+
PythonVersions: []string{"3.10"},
290+
TorchVersions: []string{"2.5.0"},
291+
Venvs: []MonobaseVenv{
292+
{
293+
Python: "3.10",
294+
Torch: "2.5.0",
295+
Cuda: "12.4",
296+
},
297+
},
298+
TorchCUDAs: map[string][]string{
299+
"2.5.0": {"12.4"},
300+
},
301+
}
302+
303+
generator, err := NewFastGenerator(&config, dir, command, &matrix, true)
304+
require.NoError(t, err)
305+
dockerfile, err := generator.GenerateDockerfileWithoutSeparateWeights(t.Context())
306+
require.NoError(t, err)
307+
console.Info(dockerfile)
308+
dockerfileLines := strings.Split(dockerfile, "\n")
309+
require.Equal(t, "ENV R8_CUDA_VERSION=12.4", dockerfileLines[4])
310+
}

pkg/dockerfile/monobase_matrix.go

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,17 @@ import (
1111
)
1212

1313
type MonobaseMatrix struct {
14-
Id int `json:"id"`
15-
CudaVersions []string `json:"cuda_versions"`
16-
CudnnVersions []string `json:"cudnn_versions"`
17-
PythonVersions []string `json:"python_versions"`
18-
TorchVersions []string `json:"torch_versions"`
19-
Venvs []MonobaseVenv `json:"venvs"`
14+
Id int `json:"id"`
15+
CudaVersions []string `json:"cuda_versions"`
16+
CudnnVersions []string `json:"cudnn_versions"`
17+
PythonVersions []string `json:"python_versions"`
18+
TorchVersions []string `json:"torch_versions"`
19+
Venvs []MonobaseVenv `json:"venvs"`
20+
TorchCUDAs map[string][]string `json:"torch_cudas"`
2021
}
2122

2223
func NewMonobaseMatrix(client *http.Client) (*MonobaseMatrix, error) {
23-
resp, err := client.Get("https://raw.githubusercontent.com/replicate/monobase/refs/heads/main/matrix.json")
24+
resp, err := client.Get(MonobaseMatrixSchemeFromEnvironment() + "://" + MonobaseMatrixHostFromEnvironment() + "/replicate/monobase/refs/heads/main/matrix.json")
2425
if err != nil {
2526
return nil, err
2627
}
@@ -60,3 +61,8 @@ func (m MonobaseMatrix) IsSupported(python string, torch string, cuda string) bo
6061
}
6162
return slices.Contains(m.Venvs, MonobaseVenv{Python: python, Torch: torch, Cuda: cuda})
6263
}
64+
65+
func (m MonobaseMatrix) DefaultCUDAVersion(torch string) string {
66+
cudas := m.TorchCUDAs[torch]
67+
return cudas[len(cudas)-1]
68+
}

0 commit comments

Comments
 (0)