diff --git a/pkg/config/config.go b/pkg/config/config.go index ce660f049e..e2e1342a82 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -55,6 +55,7 @@ type Build struct { CUDA string `json:"cuda,omitempty" yaml:"cuda"` CuDNN string `json:"cudnn,omitempty" yaml:"cudnn"` Fast bool `json:"fast,omitempty" yaml:"fast"` + PythonOverrides string `json:"python_overrides,omitempty" yaml:"python_overrides"` pythonRequirementsContent []string } diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 9b481b3fec..75f4c46a71 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -708,3 +708,14 @@ build: _, err := FromYAML([]byte(yamlString)) require.NoError(t, err) } + +func TestPythonOverridesConfig(t *testing.T) { + yamlString := ` +build: + python_version: "3.12" + fast: true + python_overrides: "overrides.txt" +` + _, err := FromYAML([]byte(yamlString)) + require.NoError(t, err) +} diff --git a/pkg/config/data/config_schema_v1.0.json b/pkg/config/data/config_schema_v1.0.json index d172cb0ca0..a9a8b4e68c 100644 --- a/pkg/config/data/config_schema_v1.0.json +++ b/pkg/config/data/config_schema_v1.0.json @@ -150,6 +150,11 @@ "$id": "#/properties/build/properties/fast", "type": "boolean", "description": "A flag to enable the experimental fast-push feature from a config level." + }, + "python_overrides": { + "$id": "#/properties/build/properties/python_overrides", + "type": "string", + "description": "A file in the format of pip requirements that specifies python overrides." } }, "additionalProperties": false diff --git a/pkg/dockerfile/fast_generator.go b/pkg/dockerfile/fast_generator.go index 476bd56ecf..4d161e3789 100644 --- a/pkg/dockerfile/fast_generator.go +++ b/pkg/dockerfile/fast_generator.go @@ -349,16 +349,26 @@ func (g *FastGenerator) installPython(lines []string, tmpDir string) ([]string, return lines, nil } - requirementsFile, err := requirements.GenerateRequirements(tmpDir, g.Config.Build.PythonRequirements) + requirementsFile, err := requirements.GenerateRequirements(tmpDir, g.Config.Build.PythonRequirements, requirements.RequirementsFile) if err != nil { return nil, err } + + overridesFlag := "" + if g.Config.Build.PythonOverrides != "" { + _, err := requirements.GenerateRequirements(tmpDir, g.Config.Build.PythonOverrides, requirements.OverridesFile) + if err != nil { + return nil, err + } + overridesFlag = " --override=/buildtmp/" + requirements.OverridesFile + } + if requirementsFile != "" { lines = append(lines, "RUN "+strings.Join([]string{ "--mount=from=" + dockercontext.RequirementsBuildContextName + ",target=/buildtmp", "--mount=type=bind,src=\".\",target=/src", UV_CACHE_MOUNT, - }, " ")+" cd /src && UV_CACHE_DIR=\""+UV_CACHE_DIR+"\" UV_LINK_MODE=copy UV_COMPILE_BYTECODE=0 /opt/r8/monobase/run.sh monobase.user --requirements=/buildtmp/requirements.txt") + }, " ")+" cd /src && UV_CACHE_DIR=\""+UV_CACHE_DIR+"\" UV_LINK_MODE=copy UV_COMPILE_BYTECODE=0 /opt/r8/monobase/run.sh monobase.user --requirements=/buildtmp/"+requirements.RequirementsFile+overridesFlag) } return lines, nil } diff --git a/pkg/requirements/requirements.go b/pkg/requirements/requirements.go index 458ff17560..1a1d05f861 100644 --- a/pkg/requirements/requirements.go +++ b/pkg/requirements/requirements.go @@ -10,9 +10,10 @@ import ( "github.com/replicate/cog/pkg/util/files" ) -const REQUIREMENTS_FILE = "requirements.txt" +const RequirementsFile = "requirements.txt" +const OverridesFile = "overrides.txt" -func GenerateRequirements(tmpDir string, path string) (string, error) { +func GenerateRequirements(tmpDir string, path string, fileName string) (string, error) { bs, err := os.ReadFile(path) if err != nil { return "", err @@ -20,7 +21,7 @@ func GenerateRequirements(tmpDir string, path string) (string, error) { requirements := string(bs) // Check against the old requirements - requirementsFile := filepath.Join(tmpDir, REQUIREMENTS_FILE) + requirementsFile := filepath.Join(tmpDir, fileName) if err := files.WriteIfDifferent(requirementsFile, requirements); err != nil { return "", err } @@ -28,7 +29,7 @@ func GenerateRequirements(tmpDir string, path string) (string, error) { } func CurrentRequirements(tmpDir string) (string, error) { - requirementsFile := filepath.Join(tmpDir, REQUIREMENTS_FILE) + requirementsFile := filepath.Join(tmpDir, RequirementsFile) _, err := os.Stat(requirementsFile) if err != nil { if errors.Is(err, os.ErrNotExist) { diff --git a/pkg/requirements/requirements_test.go b/pkg/requirements/requirements_test.go index 9918dee4bc..9ab75d2e21 100644 --- a/pkg/requirements/requirements_test.go +++ b/pkg/requirements/requirements_test.go @@ -16,7 +16,7 @@ func TestPythonRequirements(t *testing.T) { require.NoError(t, err) tmpDir := t.TempDir() - requirementsFile, err := GenerateRequirements(tmpDir, reqFile) + requirementsFile, err := GenerateRequirements(tmpDir, reqFile, RequirementsFile) require.NoError(t, err) require.Equal(t, filepath.Join(tmpDir, "requirements.txt"), requirementsFile) } diff --git a/test-integration/test_integration/fixtures/overrides-project/cog.yaml b/test-integration/test_integration/fixtures/overrides-project/cog.yaml new file mode 100644 index 0000000000..73c8a267f0 --- /dev/null +++ b/test-integration/test_integration/fixtures/overrides-project/cog.yaml @@ -0,0 +1,6 @@ +build: + python_version: "3.12" + python_requirements: requirements.txt + fast: true + python_overrides: overrides.txt +predict: "predict.py:Predictor" diff --git a/test-integration/test_integration/fixtures/overrides-project/overrides.txt b/test-integration/test_integration/fixtures/overrides-project/overrides.txt new file mode 100644 index 0000000000..2c04913adb --- /dev/null +++ b/test-integration/test_integration/fixtures/overrides-project/overrides.txt @@ -0,0 +1 @@ +numpy==1.26.4 diff --git a/test-integration/test_integration/fixtures/overrides-project/predict.py b/test-integration/test_integration/fixtures/overrides-project/predict.py new file mode 100644 index 0000000000..c72a3743eb --- /dev/null +++ b/test-integration/test_integration/fixtures/overrides-project/predict.py @@ -0,0 +1,9 @@ +from cog import BasePredictor + +import numpy as np + + +class Predictor(BasePredictor): + + def predict(self) -> str: + return "hello " + np.__version__ diff --git a/test-integration/test_integration/fixtures/overrides-project/requirements.txt b/test-integration/test_integration/fixtures/overrides-project/requirements.txt new file mode 100644 index 0000000000..af471636a2 --- /dev/null +++ b/test-integration/test_integration/fixtures/overrides-project/requirements.txt @@ -0,0 +1 @@ +numpy==1.26.3 diff --git a/test-integration/test_integration/test_build.py b/test-integration/test_integration/test_build.py index 96e8a1e88b..fa4de15a98 100644 --- a/test-integration/test_integration/test_build.py +++ b/test-integration/test_integration/test_build.py @@ -510,3 +510,15 @@ def test_local_whl_install(docker_image): ) assert build_process.returncode == 0 + + +def test_overrides(docker_image): + project_dir = Path(__file__).parent / "fixtures/overrides-project" + + build_process = subprocess.run( + ["cog", "build", "-t", docker_image], + cwd=project_dir, + capture_output=True, + ) + + assert build_process.returncode == 0 diff --git a/test-integration/test_integration/test_predict.py b/test-integration/test_integration/test_predict.py index 2e26b4753f..15ebdb3fe8 100644 --- a/test-integration/test_integration/test_predict.py +++ b/test-integration/test_integration/test_predict.py @@ -434,6 +434,26 @@ def test_predict_optional_project(tmpdir_factory): assert result.stdout == "hello No One\n" +def test_predict_overrides_project(docker_image): + project_dir = Path(__file__).parent / "fixtures/overrides-project" + build_process = subprocess.run( + ["cog", "build", "-t", docker_image], + cwd=project_dir, + capture_output=True, + ) + assert build_process.returncode == 0 + result = subprocess.run( + ["cog", "predict", "--debug", docker_image], + cwd=project_dir, + check=True, + capture_output=True, + text=True, + timeout=DEFAULT_TIMEOUT, + ) + assert result.returncode == 0 + assert result.stdout == "hello 1.26.4\n" + + def test_predict_zsh_package(docker_image): project_dir = Path(__file__).parent / "fixtures/zsh-package" build_process = subprocess.run( @@ -451,7 +471,6 @@ def test_predict_zsh_package(docker_image): text=True, timeout=DEFAULT_TIMEOUT, ) - # stdout should be clean without any log messages so it can be piped to other commands assert result.returncode == 0 assert ",sh," in result.stdout assert ",zsh," in result.stdout