Skip to content

Commit 895a28f

Browse files
authored
Revert "Add torch 2.5.0 CUDA 12.4.1 to matrix (#2350)" (#2352)
This reverts commit 92cf6a7.
1 parent 0f4cf3d commit 895a28f

File tree

2 files changed

+4
-20
lines changed

2 files changed

+4
-20
lines changed

pkg/config/torch_compatibility_matrix.json

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -115,21 +115,6 @@
115115
"3.13"
116116
]
117117
},
118-
{
119-
"Torch": "2.5.0+cu124",
120-
"Torchvision": "0.20.0",
121-
"Torchaudio": "2.5.0",
122-
"FindLinks": "",
123-
"ExtraIndexURL": "https://download.pytorch.org/whl/cu124",
124-
"CUDA": "12.4.1",
125-
"Pythons": [
126-
"3.9",
127-
"3.10",
128-
"3.11",
129-
"3.12",
130-
"3.13"
131-
]
132-
},
133118
{
134119
"Torch": "2.4.1",
135120
"Torchvision": "0.19.1",

pkg/dockerfile/fast_generator.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ import (
1919
"github.com/replicate/cog/pkg/docker"
2020
"github.com/replicate/cog/pkg/docker/command"
2121
"github.com/replicate/cog/pkg/requirements"
22-
"github.com/replicate/cog/pkg/util/version"
2322
"github.com/replicate/cog/pkg/weights"
2423
)
2524

@@ -246,8 +245,8 @@ func (g *FastGenerator) generateMonobase(lines []string, tmpDir string) ([]strin
246245
console.Infof("OK: %v", ok)
247246
console.Infof("Torch Version: %s", torchVersion)
248247

249-
cudaVersion := g.Config.Build.CUDA
250248
if g.Config.Build.GPU {
249+
cudaVersion := g.Config.Build.CUDA
251250
if cudaVersion == "" && ok {
252251
cudaVersion = g.matrix.DefaultCUDAVersion(torchVersion)
253252
}
@@ -256,7 +255,7 @@ func (g *FastGenerator) generateMonobase(lines []string, tmpDir string) ([]strin
256255
cudnnVersion = g.matrix.DefaultCudnnVersion()
257256
}
258257
if !CheckMajorMinorOnly(cudaVersion) {
259-
cudaVersion = version.StripPatch(cudaVersion)
258+
return nil, fmt.Errorf("CUDA version must be <major>.<minor>, supported versions: %s", strings.Join(g.matrix.CudaVersions, ", "))
260259
}
261260
if !CheckMajorOnly(cudnnVersion) {
262261
return nil, fmt.Errorf("CUDNN version must be <major> only, supported versions: %s", strings.Join(g.matrix.CudnnVersions, ", "))
@@ -278,10 +277,10 @@ func (g *FastGenerator) generateMonobase(lines []string, tmpDir string) ([]strin
278277
"ENV R8_PYTHON_VERSION=" + g.Config.Build.PythonVersion,
279278
}...)
280279

281-
if !g.matrix.IsSupported(g.Config.Build.PythonVersion, torchVersion, cudaVersion) {
280+
if !g.matrix.IsSupported(g.Config.Build.PythonVersion, torchVersion, g.Config.Build.CUDA) {
282281
return nil, fmt.Errorf(
283282
"Unsupported version combination: Python=%s, Torch=%s, CUDA=%s",
284-
g.Config.Build.PythonVersion, torchVersion, cudaVersion)
283+
g.Config.Build.PythonVersion, torchVersion, g.Config.Build.CUDA)
285284
}
286285

287286
// The only input to monobase.build are these ENV vars

0 commit comments

Comments
 (0)