Skip to content

Commit f325711

Browse files
8W9aG and aron authored
Create scanner for removing comments and continuations (#2230)
* Create scanner for removing comments and continuations
* When reading the requirements txt, handle line continuations and remove comments
* Fix fast_generator test
* Update pkg/requirements/requirements_test.go

  Co-authored-by: Aron Carroll <[email protected]>
  Signed-off-by: Will Sackfield <[email protected]>
* Fix config test
* Include beginning of line comment testing
* Handle CRLF line breaks
* Move no whitespace string logic to scanner
* Restore removing whitespace strings above scanner
* Scanner exits when the token is nil on return

---------

Signed-off-by: Will Sackfield <[email protected]>
Co-authored-by: Aron Carroll <[email protected]>
1 parent a8aa011 commit f325711

File tree

6 files changed

+132
-51
lines changed

6 files changed

+132
-51
lines changed

pkg/config/config.go

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,17 @@
11
package config
22

33
import (
4-
"bufio"
54
"encoding/json"
65
"errors"
76
"fmt"
8-
"os"
97
"path"
108
"regexp"
119
"strconv"
1210
"strings"
1311

1412
"gopkg.in/yaml.v2"
1513

14+
"github.com/replicate/cog/pkg/requirements"
1615
"github.com/replicate/cog/pkg/util/console"
1716
"github.com/replicate/cog/pkg/util/slices"
1817
"github.com/replicate/cog/pkg/util/version"
@@ -302,15 +301,10 @@ func (c *Config) ValidateAndComplete(projectDir string) error {
302301

303302
// Load python_requirements into memory to simplify reading it multiple times
304303
if c.Build.PythonRequirements != "" {
305-
fh, err := os.Open(path.Join(projectDir, c.Build.PythonRequirements))
304+
c.Build.pythonRequirementsContent, err = requirements.ReadRequirements(path.Join(projectDir, c.Build.PythonRequirements))
306305
if err != nil {
307306
errs = append(errs, fmt.Errorf("Failed to open python_requirements file: %w", err))
308307
}
309-
// Use scanner to handle CRLF endings
310-
scanner := bufio.NewScanner(fh)
311-
for scanner.Scan() {
312-
c.Build.pythonRequirementsContent = append(c.Build.pythonRequirementsContent, scanner.Text())
313-
}
314308
}
315309

316310
// Backwards compatibility

pkg/config/config_test.go

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -251,13 +251,8 @@ flask>0.4
251251
requirements, err := config.PythonRequirementsForArch("", "", []string{})
252252
require.NoError(t, err)
253253
expected := `foo==1.0.0
254-
# complex requirements
255254
fastapi>=0.6,<1
256255
flask>0.4
257-
# comments!
258-
# blank lines!
259-
260-
# arguments
261256
-f http://example.com`
262257
require.Equal(t, expected, requirements)
263258

pkg/dockerfile/fast_generator.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,15 @@ func (g *FastGenerator) installPython(lines []string, tmpDir string) ([]string,
289289
if err != nil {
290290
return nil, err
291291
}
292-
requirementsFile, err := requirements.GenerateRequirements(tmpDir, g.Config)
292+
if len(g.Config.Build.PythonPackages) > 0 {
293+
return nil, fmt.Errorf("python_packages is no longer supported, use python_requirements instead")
294+
}
295+
// No Python requirements
296+
if g.Config.Build.PythonRequirements == "" {
297+
return lines, nil
298+
}
299+
300+
requirementsFile, err := requirements.GenerateRequirements(tmpDir, g.Config.Build.PythonRequirements)
293301
if err != nil {
294302
return nil, err
295303
}

pkg/dockerfile/fast_generator_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,15 +104,15 @@ func TestGenerateCUDA(t *testing.T) {
104104
// Create matrix
105105
matrix := MonobaseMatrix{
106106
Id: 1,
107-
CudaVersions: []string{"2.4"},
107+
CudaVersions: []string{"12.4"},
108108
CudnnVersions: []string{"1"},
109109
PythonVersions: []string{"3.9"},
110110
TorchVersions: []string{"2.5.1"},
111111
Venvs: []MonobaseVenv{
112112
{
113113
Python: "3.9",
114114
Torch: "2.5.1",
115-
Cuda: "2.4",
115+
Cuda: "12.4",
116116
},
117117
},
118118
}

pkg/requirements/requirements.go

Lines changed: 66 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,19 @@
11
package requirements
22

33
import (
4+
"bufio"
45
"errors"
5-
"fmt"
66
"os"
77
"path/filepath"
8+
"strings"
89

910
"github.com/replicate/cog/pkg/util/files"
10-
11-
"github.com/replicate/cog/pkg/config"
1211
)
1312

1413
const REQUIREMENTS_FILE = "requirements.txt"
1514

16-
func GenerateRequirements(tmpDir string, cfg *config.Config) (string, error) {
17-
if len(cfg.Build.PythonPackages) > 0 {
18-
return "", fmt.Errorf("python_packages is no longer supported, use python_requirements instead")
19-
}
20-
21-
// No Python requirements
22-
if cfg.Build.PythonRequirements == "" {
23-
return "", nil
24-
}
25-
26-
bs, err := os.ReadFile(cfg.Build.PythonRequirements)
15+
func GenerateRequirements(tmpDir string, path string) (string, error) {
16+
bs, err := os.ReadFile(path)
2717
if err != nil {
2818
return "", err
2919
}
@@ -48,3 +38,65 @@ func CurrentRequirements(tmpDir string) (string, error) {
4838
}
4939
return requirementsFile, nil
5040
}
41+
42+
func ReadRequirements(path string) ([]string, error) {
43+
fh, err := os.Open(path)
44+
if err != nil {
45+
return nil, err
46+
}
47+
// Use scanner to handle CRLF endings
48+
scanner := bufio.NewScanner(fh)
49+
scanner.Split(scanLinesWithContinuations)
50+
requirements := []string{}
51+
for scanner.Scan() {
52+
requirementsText := strings.TrimSpace(scanner.Text())
53+
if len(requirementsText) == 0 {
54+
continue
55+
}
56+
requirements = append(requirements, requirementsText)
57+
}
58+
return requirements, nil
59+
}
60+
61+
// scanLinesWithContinuations is a bufio.SplitFunc that yields one logical
// requirements line per token.
//
// Rules applied while building a token:
//   - A '#' starts a comment: it and everything up to the next newline is
//     dropped. The comment always ends at that newline.
//   - Outside a comment, a line whose last character (ignoring a trailing
//     '\r') is a backslash is continued on the next physical line; the
//     backslash and the line break are removed.
//   - A trailing '\r' before the newline is removed, so CRLF files work.
//
// The returned token is a freshly built, non-nil slice (possibly empty for
// blank or comment-only lines); callers are expected to skip empty tokens.
//
// If data does not yet contain the newline that terminates the logical line
// and atEOF is false, the function asks the scanner for more input rather
// than emitting a partial line. At EOF whatever remains is emitted as the
// final token. A logical line larger than the scanner's buffer therefore
// surfaces as bufio.ErrTooLong from Scanner.Err.
func scanLinesWithContinuations(data []byte, atEOF bool) (advance int, token []byte, err error) {
	line := []byte{}
	inComment := false
	for i, b := range data {
		switch {
		case b == '\n':
			if !inComment {
				// Length guards matter: a blank CRLF line leaves line empty
				// once the '\r' is trimmed.
				if n := len(line); n > 0 && line[n-1] == '\r' {
					line = line[:n-1]
				}
				if n := len(line); n > 0 && line[n-1] == '\\' {
					// Continuation: drop the backslash and keep accumulating
					// from the next physical line.
					line = line[:n-1]
					continue
				}
			}
			return i + 1, line, nil
		case b == '#':
			inComment = true
		case !inComment:
			line = append(line, b)
		}
	}

	// No terminating newline in data.
	if atEOF && len(data) > 0 {
		// Final, unterminated line: emit what was accumulated.
		return len(data), line, nil
	}
	// Either nothing is left, or the line may continue in data not yet read.
	return 0, nil, nil
}

pkg/requirements/requirements_test.go

Lines changed: 53 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,36 +7,68 @@ import (
77
"testing"
88

99
"github.com/stretchr/testify/require"
10-
11-
"github.com/replicate/cog/pkg/config"
1210
)
1311

14-
func TestPythonPackages(t *testing.T) {
15-
tmpDir := t.TempDir()
16-
build := config.Build{
17-
PythonPackages: []string{"torch==2.5.1"},
18-
}
19-
config := config.Config{
20-
Build: &build,
21-
}
22-
_, err := GenerateRequirements(tmpDir, &config)
23-
require.ErrorContains(t, err, "python_packages is no longer supported, use python_requirements instead")
24-
}
25-
2612
func TestPythonRequirements(t *testing.T) {
2713
srcDir := t.TempDir()
2814
reqFile := path.Join(srcDir, "requirements.txt")
2915
err := os.WriteFile(reqFile, []byte("torch==2.5.1"), 0o644)
3016
require.NoError(t, err)
3117

32-
build := config.Build{
33-
PythonRequirements: reqFile,
34-
}
35-
config := config.Config{
36-
Build: &build,
37-
}
3818
tmpDir := t.TempDir()
39-
requirementsFile, err := GenerateRequirements(tmpDir, &config)
19+
requirementsFile, err := GenerateRequirements(tmpDir, reqFile)
4020
require.NoError(t, err)
4121
require.Equal(t, filepath.Join(tmpDir, "requirements.txt"), requirementsFile)
4222
}
23+
24+
func TestReadRequirements(t *testing.T) {
25+
srcDir := t.TempDir()
26+
reqFile := path.Join(srcDir, "requirements.txt")
27+
err := os.WriteFile(reqFile, []byte("torch==2.5.1"), 0o644)
28+
require.NoError(t, err)
29+
30+
requirements, err := ReadRequirements(reqFile)
31+
require.NoError(t, err)
32+
require.Equal(t, []string{"torch==2.5.1"}, requirements)
33+
}
34+
35+
func TestReadRequirementsLineContinuations(t *testing.T) {
36+
srcDir := t.TempDir()
37+
reqFile := path.Join(srcDir, "requirements.txt")
38+
err := os.WriteFile(reqFile, []byte("torch==\\\n2.5.1\ntorchvision==\\\r\n2.5.1"), 0o644)
39+
require.NoError(t, err)
40+
41+
requirements, err := ReadRequirements(reqFile)
42+
require.NoError(t, err)
43+
require.Equal(t, []string{"torch==2.5.1", "torchvision==2.5.1"}, requirements)
44+
}
45+
46+
func TestReadRequirementsStripComments(t *testing.T) {
47+
srcDir := t.TempDir()
48+
reqFile := path.Join(srcDir, "requirements.txt")
49+
err := os.WriteFile(reqFile, []byte("torch==\\\n2.5.1# Heres my comment\ntorchvision==2.5.1\n# Heres a beginning of line comment"), 0o644)
50+
require.NoError(t, err)
51+
52+
requirements, err := ReadRequirements(reqFile)
53+
require.NoError(t, err)
54+
require.Equal(t, []string{"torch==2.5.1", "torchvision==2.5.1"}, requirements)
55+
}
56+
57+
func TestReadRequirementsComplex(t *testing.T) {
58+
srcDir := t.TempDir()
59+
reqFile := path.Join(srcDir, "requirements.txt")
60+
err := os.WriteFile(reqFile, []byte(`foo==1.0.0
61+
# complex requirements
62+
fastapi>=0.6,<1
63+
flask>0.4
64+
# comments!
65+
# blank lines!
66+
67+
# arguments
68+
-f http://example.com`), 0o644)
69+
require.NoError(t, err)
70+
71+
requirements, err := ReadRequirements(reqFile)
72+
require.NoError(t, err)
73+
require.Equal(t, []string{"foo==1.0.0", "fastapi>=0.6,<1", "flask>0.4", "-f http://example.com"}, requirements)
74+
}

0 commit comments

Comments
 (0)