Skip to content
Merged
10 changes: 2 additions & 8 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
package config

import (
"bufio"
"encoding/json"
"errors"
"fmt"
"os"
"path"
"regexp"
"strconv"
"strings"

"gopkg.in/yaml.v2"

"github.com/replicate/cog/pkg/requirements"
"github.com/replicate/cog/pkg/util/console"
"github.com/replicate/cog/pkg/util/slices"
"github.com/replicate/cog/pkg/util/version"
Expand Down Expand Up @@ -302,15 +301,10 @@ func (c *Config) ValidateAndComplete(projectDir string) error {

// Load python_requirements into memory to simplify reading it multiple times
if c.Build.PythonRequirements != "" {
fh, err := os.Open(path.Join(projectDir, c.Build.PythonRequirements))
c.Build.pythonRequirementsContent, err = requirements.ReadRequirements(path.Join(projectDir, c.Build.PythonRequirements))
if err != nil {
errs = append(errs, fmt.Errorf("Failed to open python_requirements file: %w", err))
}
// Use scanner to handle CRLF endings
scanner := bufio.NewScanner(fh)
for scanner.Scan() {
c.Build.pythonRequirementsContent = append(c.Build.pythonRequirementsContent, scanner.Text())
}
}

// Backwards compatibility
Expand Down
5 changes: 0 additions & 5 deletions pkg/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,13 +251,8 @@ flask>0.4
requirements, err := config.PythonRequirementsForArch("", "", []string{})
require.NoError(t, err)
expected := `foo==1.0.0
# complex requirements
fastapi>=0.6,<1
flask>0.4
# comments!
# blank lines!

# arguments
-f http://example.com`
require.Equal(t, expected, requirements)

Expand Down
10 changes: 9 additions & 1 deletion pkg/dockerfile/fast_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,15 @@ func (g *FastGenerator) installPython(lines []string, tmpDir string) ([]string,
if err != nil {
return nil, err
}
requirementsFile, err := requirements.GenerateRequirements(tmpDir, g.Config)
if len(g.Config.Build.PythonPackages) > 0 {
return nil, fmt.Errorf("python_packages is no longer supported, use python_requirements instead")
}
// No Python requirements
if g.Config.Build.PythonRequirements == "" {
return lines, nil
}

requirementsFile, err := requirements.GenerateRequirements(tmpDir, g.Config.Build.PythonRequirements)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/dockerfile/fast_generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,15 @@ func TestGenerateCUDA(t *testing.T) {
// Create matrix
matrix := MonobaseMatrix{
Id: 1,
CudaVersions: []string{"2.4"},
CudaVersions: []string{"12.4"},
CudnnVersions: []string{"1"},
PythonVersions: []string{"3.9"},
TorchVersions: []string{"2.5.1"},
Venvs: []MonobaseVenv{
{
Python: "3.9",
Torch: "2.5.1",
Cuda: "2.4",
Cuda: "12.4",
},
},
}
Expand Down
80 changes: 66 additions & 14 deletions pkg/requirements/requirements.go
Original file line number Diff line number Diff line change
@@ -1,29 +1,19 @@
package requirements

import (
"bufio"
"errors"
"fmt"
"os"
"path/filepath"
"strings"

"github.com/replicate/cog/pkg/util/files"

"github.com/replicate/cog/pkg/config"
)

const REQUIREMENTS_FILE = "requirements.txt"

func GenerateRequirements(tmpDir string, cfg *config.Config) (string, error) {
if len(cfg.Build.PythonPackages) > 0 {
return "", fmt.Errorf("python_packages is no longer supported, use python_requirements instead")
}

// No Python requirements
if cfg.Build.PythonRequirements == "" {
return "", nil
}

bs, err := os.ReadFile(cfg.Build.PythonRequirements)
func GenerateRequirements(tmpDir string, path string) (string, error) {
bs, err := os.ReadFile(path)
if err != nil {
return "", err
}
Expand All @@ -48,3 +38,65 @@ func CurrentRequirements(tmpDir string) (string, error) {
}
return requirementsFile, nil
}

func ReadRequirements(path string) ([]string, error) {
fh, err := os.Open(path)
if err != nil {
return nil, err
}
// Use scanner to handle CRLF endings
scanner := bufio.NewScanner(fh)
scanner.Split(scanLinesWithContinuations)
requirements := []string{}
for scanner.Scan() {
requirementsText := strings.TrimSpace(scanner.Text())
if len(requirementsText) == 0 {
continue
}
requirements = append(requirements, requirementsText)
}
return requirements, nil
}

func scanLinesWithContinuations(data []byte, atEOF bool) (advance int, token []byte, err error) {
advance = 0
token = nil
err = nil
inHash := false
for {
if atEOF || len(data) == 0 {
break
}
if token == nil {
token = []byte{}
}
if data[advance] == '#' {
inHash = true
}
if data[advance] == '\n' {
shouldAdvance := true
if len(token) > 0 {
if token[len(token)-1] == '\r' && !inHash {
token = token[:len(token)-1]
}
if token[len(token)-1] == '\\' {
if !inHash {
token = token[:len(token)-1]
}
shouldAdvance = false
}
}
if shouldAdvance {
advance++
break
}
} else if !inHash {
token = append(token, data[advance])
}
advance++
if advance == len(data) {
break
}
}
return advance, token, err
}
74 changes: 53 additions & 21 deletions pkg/requirements/requirements_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,36 +7,68 @@ import (
"testing"

"github.com/stretchr/testify/require"

"github.com/replicate/cog/pkg/config"
)

func TestPythonPackages(t *testing.T) {
tmpDir := t.TempDir()
build := config.Build{
PythonPackages: []string{"torch==2.5.1"},
}
config := config.Config{
Build: &build,
}
_, err := GenerateRequirements(tmpDir, &config)
require.ErrorContains(t, err, "python_packages is no longer supported, use python_requirements instead")
}

func TestPythonRequirements(t *testing.T) {
srcDir := t.TempDir()
reqFile := path.Join(srcDir, "requirements.txt")
err := os.WriteFile(reqFile, []byte("torch==2.5.1"), 0o644)
require.NoError(t, err)

build := config.Build{
PythonRequirements: reqFile,
}
config := config.Config{
Build: &build,
}
tmpDir := t.TempDir()
requirementsFile, err := GenerateRequirements(tmpDir, &config)
requirementsFile, err := GenerateRequirements(tmpDir, reqFile)
require.NoError(t, err)
require.Equal(t, filepath.Join(tmpDir, "requirements.txt"), requirementsFile)
}

func TestReadRequirements(t *testing.T) {
srcDir := t.TempDir()
reqFile := path.Join(srcDir, "requirements.txt")
err := os.WriteFile(reqFile, []byte("torch==2.5.1"), 0o644)
require.NoError(t, err)

requirements, err := ReadRequirements(reqFile)
require.NoError(t, err)
require.Equal(t, []string{"torch==2.5.1"}, requirements)
}

func TestReadRequirementsLineContinuations(t *testing.T) {
srcDir := t.TempDir()
reqFile := path.Join(srcDir, "requirements.txt")
err := os.WriteFile(reqFile, []byte("torch==\\\n2.5.1\ntorchvision==\\\r\n2.5.1"), 0o644)
require.NoError(t, err)

requirements, err := ReadRequirements(reqFile)
require.NoError(t, err)
require.Equal(t, []string{"torch==2.5.1", "torchvision==2.5.1"}, requirements)
}

func TestReadRequirementsStripComments(t *testing.T) {
srcDir := t.TempDir()
reqFile := path.Join(srcDir, "requirements.txt")
err := os.WriteFile(reqFile, []byte("torch==\\\n2.5.1# Heres my comment\ntorchvision==2.5.1\n# Heres a beginning of line comment"), 0o644)
require.NoError(t, err)

requirements, err := ReadRequirements(reqFile)
require.NoError(t, err)
require.Equal(t, []string{"torch==2.5.1", "torchvision==2.5.1"}, requirements)
}

func TestReadRequirementsComplex(t *testing.T) {
srcDir := t.TempDir()
reqFile := path.Join(srcDir, "requirements.txt")
err := os.WriteFile(reqFile, []byte(`foo==1.0.0
# complex requirements
fastapi>=0.6,<1
flask>0.4
# comments!
# blank lines!

# arguments
-f http://example.com`), 0o644)
require.NoError(t, err)

requirements, err := ReadRequirements(reqFile)
require.NoError(t, err)
require.Equal(t, []string{"foo==1.0.0", "fastapi>=0.6,<1", "flask>0.4", "-f http://example.com"}, requirements)
}
Loading