Skip to content

Commit 7ef7014

Browse files
Use correct CUDA base image
Signed-off-by: andreasjansson <[email protected]>
1 parent 19c1314 commit 7ef7014

File tree

3 files changed

+15
-15
lines changed

3 files changed

+15
-15
lines changed

pkg/docker/generate.go

Lines changed: 1 addition & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -79,7 +79,7 @@ func (g *DockerfileGenerator) baseImage() (string, error) {
7979
case "cpu":
8080
return "ubuntu:20.04", nil
8181
case "gpu":
82-
return g.gpuBaseImage()
82+
return g.Config.CUDABaseImageTag()
8383
}
8484
return "", fmt.Errorf("Invalid architecture: %s", g.Arch)
8585
}
@@ -89,11 +89,6 @@ func (g *DockerfileGenerator) preamble() string {
8989
return "ENV DEBIAN_FRONTEND=noninteractive"
9090
}
9191

92-
func (g *DockerfileGenerator) gpuBaseImage() (string, error) {
93-
// TODO: return correct ubuntu version for tf / torch
94-
return "nvidia/cuda:11.0-devel-ubuntu20.04", nil
95-
}
96-
9792
func (g *DockerfileGenerator) aptInstalls() (string, error) {
9893
packages := g.Config.Environment.SystemPackages
9994
if len(packages) == 0 {

pkg/docker/generate_test.go

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -55,6 +55,7 @@ func TestGenerateEmpty(t *testing.T) {
5555
model: infer.py:Model
5656
`))
5757
require.NoError(t, err)
58+
require.NoError(t, conf.ValidateAndCompleteConfig())
5859

5960
expectedCPU := `FROM ubuntu:20.04
6061
ENV DEBIAN_FRONTEND=noninteractive
@@ -64,7 +65,7 @@ COPY . /code
6465
WORKDIR /code
6566
CMD ["python", "-c", "from infer import Model; Model().start_server()"]`
6667

67-
expectedGPU := `FROM nvidia/cuda:11.0-devel-ubuntu20.04
68+
expectedGPU := `FROM nvidia/cuda:11.0-cudnn8-devel-ubuntu16.04
6869
ENV DEBIAN_FRONTEND=noninteractive
6970
` + installPython("3.8") + installCog() + `
7071
RUN ### --> Copying code
@@ -96,9 +97,7 @@ environment:
9697
model: infer.py:Model
9798
`))
9899
require.NoError(t, err)
99-
100-
err = conf.ValidateAndCompleteConfig()
101-
require.NoError(t, err)
100+
require.NoError(t, conf.ValidateAndCompleteConfig())
102101

103102
expectedCPU := `FROM ubuntu:20.04
104103
ENV DEBIAN_FRONTEND=noninteractive
@@ -115,7 +114,7 @@ COPY . /code
115114
WORKDIR /code
116115
CMD ["python", "-c", "from infer import Model; Model().start_server()"]`
117116

118-
expectedGPU := `FROM nvidia/cuda:11.0-devel-ubuntu20.04
117+
expectedGPU := `FROM nvidia/cuda:10.2-cudnn8-devel-ubuntu18.04
119118
ENV DEBIAN_FRONTEND=noninteractive
120119
` + installPython("3.8") + `RUN ### --> Installing system packages
121120
RUN apt-get update -qq && apt-get install -qy ffmpeg cowsay && rm -rf /var/lib/apt/lists/*

pkg/model/compatibility.go

Lines changed: 10 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,11 @@ import (
1212
"github.com/replicate/cog/pkg/version"
1313
)
1414

15+
// TODO(andreas): check tf/py versions. tf 1.5.0 didn't install on py 3.8
16+
// TODO(andreas): support more tf versions. No matching tensorflow CPU package for version 1.15.4, etc.
17+
// TODO(andreas): allow user to install versions that aren't compatible
18+
// TODO(andreas): allow user to install tf cpu package on gpu
19+
1520
type TFCompatibility struct {
1621
TF string
1722
TFCPUPackage string
@@ -105,13 +110,13 @@ var CUDABaseImages []CUDABaseImage
105110

106111
func init() {
107112
if err := json.Unmarshal(tfCompatibilityMatrixData, &TFCompatibilityMatrix); err != nil {
108-
console.Fatal("Failed to load embedded Tensorflow compatibility matrix: %s", err)
113+
console.Fatalf("Failed to load embedded Tensorflow compatibility matrix: %s", err)
109114
}
110115
if err := json.Unmarshal(torchCompatibilityMatrixData, &TorchCompatibilityMatrix); err != nil {
111-
console.Fatal("Failed to load embedded PyTorch compatibility matrix: %s", err)
116+
console.Fatalf("Failed to load embedded PyTorch compatibility matrix: %s", err)
112117
}
113118
if err := json.Unmarshal(cudaBaseImageTagsData, &CUDABaseImages); err != nil {
114-
console.Fatal("Failed to load embedded CUDA base images: %s", err)
119+
console.Fatalf("Failed to load embedded CUDA base images: %s", err)
115120
}
116121
}
117122

@@ -186,10 +191,11 @@ func latestCuDNNForCUDA(cuda string) string {
186191
func latestTF() TFCompatibility {
187192
var latest *TFCompatibility
188193
for _, compat := range TFCompatibilityMatrix {
194+
compat := compat
189195
if latest == nil {
190196
latest = &compat
191197
} else {
192-
greater, err := versionGreater(latest.TF, compat.TF)
198+
greater, err := versionGreater(compat.TF, latest.TF)
193199
if err != nil {
194200
// should never happen
195201
panic(fmt.Sprintf("Invalid tensorflow version: %s", err))

0 commit comments

Comments
 (0)