Skip to content

Commit 3d5b4aa

Browse files
8W9aGmichaeldwan
andauthored
Fix tensorflow not appearing for standard builds (#2326)
* Add integration tests for tensorflow * Move to .requirements.txt * return the container id without the trailing newline * Fix requirements test * Add support for tensorflow as a special package --------- Co-authored-by: Michael Dwan <m@dwan.io>
1 parent 4c8b448 commit 3d5b4aa

8 files changed

Lines changed: 142 additions & 3 deletions

File tree

pkg/docker/docker_command.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ func (c *DockerCommand) ContainerStart(ctx context.Context, options command.RunO
347347
return "", err
348348
}
349349

350-
return out.String(), nil
350+
return strings.TrimSpace(out.String()), nil
351351
}
352352

353353
func (c *DockerCommand) Run(ctx context.Context, options command.RunOptions) error {

pkg/dockerfile/standard_generator.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,9 @@ func (g *StandardGenerator) pipInstalls() (string, error) {
451451
if torchaudioVersion, ok := g.Config.TorchaudioVersion(); ok {
452452
includePackages = append(includePackages, "torchaudio=="+torchaudioVersion)
453453
}
454+
if tensorflowVersion, ok := g.Config.TensorFlowVersion(); ok {
455+
includePackages = append(includePackages, "tensorflow=="+tensorflowVersion)
456+
}
454457
g.pythonRequirementsContents, err = g.Config.PythonRequirementsForArch(g.GOOS, g.GOARCH, includePackages)
455458
if err != nil {
456459
return "", err

pkg/requirements/requirements.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77
"path/filepath"
88
"strings"
99

10-
"github.com/replicate/cog/pkg/util/console"
1110
"github.com/replicate/cog/pkg/util/files"
1211
)
1312

@@ -42,7 +41,6 @@ func CurrentRequirements(tmpDir string) (string, error) {
4241
}
4342

4443
func ReadRequirements(path string) ([]string, error) {
45-
console.Infof("path %s", path)
4644
fh, err := os.Open(path)
4745
if err != nil {
4846
return nil, err

pkg/requirements/requirements_test.go

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,76 @@ iopath`), 0o644)
318318
}, requirements)
319319
}
320320

321+
func TestTensorflowRequirements(t *testing.T) {
322+
srcDir := t.TempDir()
323+
reqFile := path.Join(srcDir, ".requirements.txt")
324+
err := os.WriteFile(reqFile, []byte(`compel==2.0.3
325+
diffusers>=0.27.1
326+
gputil==1.4.0
327+
loguru==0.7.2
328+
opencv-python>=4.9.0.80
329+
pillow>=10.2.0
330+
psutil==6.1.1
331+
replicate>=1.0.4
332+
sentry-sdk[fastapi,loguru]>=2.16.0
333+
antialiased_cnns==0.3
334+
beautifulsoup4==4.13.4
335+
imageio==2.37.0
336+
ipdb==0.13.13
337+
kornia==0.8.1
338+
matplotlib==3.10.3
339+
numpy==1.23.5
340+
opencv_python==4.11.0.86
341+
Pillow==11.2.1
342+
pytorch_lightning==2.3.3
343+
PyYAML==6.0.2
344+
Requests==2.32.3
345+
scipy==1.15.3
346+
scikit-image==0.24.0
347+
tensorflow==2.10.0
348+
tensorlayer==2.2.5
349+
tf_slim==1.1.0
350+
timm==1.0.15
351+
torch==2.0.1
352+
torchvision==0.15.2
353+
tqdm==4.67.1`), 0o644)
354+
require.NoError(t, err)
355+
requirements, err := ReadRequirements(reqFile)
356+
require.NoError(t, err)
357+
require.Equal(t, []string{
358+
"compel==2.0.3",
359+
"diffusers>=0.27.1",
360+
"gputil==1.4.0",
361+
"loguru==0.7.2",
362+
"opencv-python>=4.9.0.80",
363+
"pillow>=10.2.0",
364+
"psutil==6.1.1",
365+
"replicate>=1.0.4",
366+
"sentry-sdk[fastapi,loguru]>=2.16.0",
367+
"antialiased_cnns==0.3",
368+
"beautifulsoup4==4.13.4",
369+
"imageio==2.37.0",
370+
"ipdb==0.13.13",
371+
"kornia==0.8.1",
372+
"matplotlib==3.10.3",
373+
"numpy==1.23.5",
374+
"opencv_python==4.11.0.86",
375+
"Pillow==11.2.1",
376+
"pytorch_lightning==2.3.3",
377+
"PyYAML==6.0.2",
378+
"Requests==2.32.3",
379+
"scipy==1.15.3",
380+
"scikit-image==0.24.0",
381+
"tensorflow==2.10.0",
382+
"tensorlayer==2.2.5",
383+
"tf_slim==1.1.0",
384+
"timm==1.0.15",
385+
"torch==2.0.1",
386+
"torchvision==0.15.2",
387+
"tqdm==4.67.1",
388+
}, requirements)
389+
}
390+
321391
func checkRequirements(t *testing.T, expected []string, actual []string) {
322392
t.Helper()
323393
for n, expectLine := range expected {
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
compel==2.0.3
2+
diffusers>=0.27.1
3+
gputil==1.4.0
4+
loguru==0.7.2
5+
opencv-python>=4.9.0.80
6+
pillow>=10.2.0
7+
psutil==6.1.1
8+
replicate>=1.0.4
9+
sentry-sdk[fastapi,loguru]>=2.16.0
10+
antialiased_cnns==0.3
11+
beautifulsoup4==4.13.4
12+
imageio==2.37.0
13+
ipdb==0.13.13
14+
kornia==0.8.1
15+
matplotlib==3.10.3
16+
numpy==1.23.5
17+
opencv_python==4.11.0.86
18+
Pillow==11.2.1
19+
pytorch_lightning==2.3.3
20+
PyYAML==6.0.2
21+
Requests==2.32.3
22+
scipy==1.15.3
23+
scikit-image==0.24.0
24+
tensorflow==2.10.0
25+
tensorlayer==2.2.5
26+
tf_slim==1.1.0
27+
timm==1.0.15
28+
torch==2.0.1
29+
torchvision==0.15.2
30+
tqdm==4.67.1
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
build:
2+
gpu: true
3+
cuda: "11.8"
4+
python_version: "3.10"
5+
system_packages:
6+
- "libgl1-mesa-glx"
7+
- "libglib2.0-0"
8+
- "xvfb"
9+
python_requirements: .requirements.txt
10+
11+
predict: "predict.py:Predictor"
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from cog import BasePredictor
2+
3+
import tensorflow
4+
5+
6+
class Predictor(BasePredictor):
7+
def predict(self) -> str:
8+
return tensorflow.__version__

test-integration/test_integration/test_predict.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -598,3 +598,22 @@ def test_predict_complex_types_list(docker_image, cog_binary):
598598
)
599599
assert result.returncode == 0
600600
assert result.stdout == "Content: Hi There-I am a test\n"
601+
602+
603+
def test_predict_tensorflow_project(docker_image, cog_binary):
604+
project_dir = Path(__file__).parent / "fixtures/tensorflow-project"
605+
606+
result = subprocess.run(
607+
[
608+
cog_binary,
609+
"predict",
610+
"--debug",
611+
],
612+
cwd=project_dir,
613+
check=True,
614+
capture_output=True,
615+
text=True,
616+
timeout=DEFAULT_TIMEOUT,
617+
)
618+
assert result.returncode == 0
619+
assert result.stdout == "2.10.0\n"

0 commit comments

Comments
 (0)