diff --git a/go.mod b/go.mod index 229e5a5720..5901a89987 100644 --- a/go.mod +++ b/go.mod @@ -28,7 +28,6 @@ require ( github.com/vincent-petithory/dataurl v1.0.0 github.com/xeipuuv/gojsonschema v1.2.0 github.com/xeonx/timeago v1.0.0-rc5 - golang.org/x/crypto v0.37.0 golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 golang.org/x/sync v0.13.0 golang.org/x/sys v0.32.0 @@ -271,6 +270,7 @@ require ( go.uber.org/automaxprocs v1.6.0 // indirect go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.27.0 // indirect + golang.org/x/crypto v0.37.0 // indirect golang.org/x/exp/typeparams v0.0.0-20250210185358-939b2ce775ac // indirect golang.org/x/mod v0.24.0 // indirect golang.org/x/net v0.39.0 // indirect @@ -280,7 +280,7 @@ require ( google.golang.org/grpc v1.71.0 // indirect google.golang.org/protobuf v1.36.6 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect - gotest.tools/gotestsum v1.12.1 // indirect + gotest.tools/gotestsum v1.12.2 // indirect honnef.co/go/tools v0.6.1 // indirect mvdan.cc/gofumpt v0.7.0 // indirect mvdan.cc/unparam v0.0.0-20240528143540-8a5130ca722f // indirect diff --git a/go.sum b/go.sum index bb581f7fbd..1270b5536a 100644 --- a/go.sum +++ b/go.sum @@ -764,8 +764,8 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gotest.tools/gotestsum v1.12.1 h1:dvcxFBTFR1QsQmrCQa4k/vDXow9altdYz4CjdW+XeBE= -gotest.tools/gotestsum v1.12.1/go.mod h1:mwDmLbx9DIvr09dnAoGgQPLaSXszNpXpWo2bsQge5BE= +gotest.tools/gotestsum v1.12.2 h1:eli4tu9Q2D/ogDsEGSr8XfQfl7mT0JsGOG6DFtUiZ/Q= +gotest.tools/gotestsum v1.12.2/go.mod h1:kjRtCglPZVsSU0hFHX3M5VWBM6Y63emHuB14ER1/sow= gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q= gotest.tools/v3 v3.5.2/go.mod 
h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA= honnef.co/go/tools v0.6.1 h1:R094WgE8K4JirYjBaOpz/AvTyUu/3wbmAoskKN/pxTI= diff --git a/pkg/config/config.go b/pkg/config/config.go index 70c624a1fe..6ccd271419 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -75,6 +75,9 @@ type Config struct { Predict string `json:"predict,omitempty" yaml:"predict"` Train string `json:"train,omitempty" yaml:"train,omitempty"` Concurrency *Concurrency `json:"concurrency,omitempty" yaml:"concurrency,omitempty"` + Environment []string `json:"environment,omitempty" yaml:"environment,omitempty"` + + parsedEnvironment map[string]string } func DefaultConfig() *Config { @@ -319,6 +322,11 @@ func (c *Config) ValidateAndComplete(projectDir string) error { } } + // parse and validate environment variables + if err := c.loadEnvironment(); err != nil { + errs = append(errs, err) + } + if len(errs) > 0 { return errors.Join(errs...) } @@ -577,3 +585,16 @@ func sliceContains(slice []string, s string) bool { } return false } + +func (c *Config) ParsedEnvironment() map[string]string { + return c.parsedEnvironment +} + +func (c *Config) loadEnvironment() error { + env, err := parseAndValidateEnvironment(c.Environment) + if err != nil { + return err + } + c.parsedEnvironment = env + return nil +} diff --git a/pkg/config/data/config_schema_v1.0.json b/pkg/config/data/config_schema_v1.0.json index a9a8b4e68c..4fbec34f6f 100644 --- a/pkg/config/data/config_schema_v1.0.json +++ b/pkg/config/data/config_schema_v1.0.json @@ -207,6 +207,20 @@ "type": "object", "additionalItems": true } + }, + "environment": { + "$id": "#/properties/properties/environment", + "type": [ + "array", + "null" + ], + "description": "A list of environment variables to make available during builds and at runtime, in the format `NAME=value`", + "additionalItems": true, + "items": { + "$id": "#/properties/properties/environment/items", + "type": "string", + "pattern": "^[A-Za-z_][A-Za-z0-9_]*=[^\\s]+$" + } } }, 
"additionalProperties": false diff --git a/pkg/config/env_variables.go b/pkg/config/env_variables.go new file mode 100644 index 0000000000..3cbcf44d60 --- /dev/null +++ b/pkg/config/env_variables.go @@ -0,0 +1,76 @@ +package config + +import ( + "fmt" + "strings" +) + +// EnvironmentVariableDenyList is a list of environment variable patterns that are +// used internally during build or runtime and thus not allowed to be set by the user. +// There are ways around this restriction, but it's likely to cause unexpected behavior +// and hard to debug issues. So on Cog's predict-build-push happy path, we don't allow +// these to be set. +// This list may change at any time. For more context, see: +// https://github.com/replicate/cog/pull/2274/#issuecomment-2831823185 +var EnvironmentVariableDenyList = []string{ + // paths + "PATH", + "LD_LIBRARY_PATH", + "PYTHONPATH", + "VIRTUAL_ENV", + "PYTHONUNBUFFERED", + // Replicate + "R8_*", + "REPLICATE_*", + // Nvidia + "LIBRARY_PATH", + "CUDA_*", + "NVIDIA_*", + "NV_*", + // pget + "PGET_*", + "HF_ENDPOINT", + "HF_HUB_ENABLE_HF_TRANSFER", + // k8s + "KUBERNETES_*", +} + +// validateEnvName checks if the given environment variable name is allowed. +// Returns an error if the name matches any of the restricted patterns. +func validateEnvName(name string) error { + for _, pattern := range EnvironmentVariableDenyList { + // Check for exact match + if pattern == name { + return fmt.Errorf("environment variable %q is not allowed", name) + } + + // Check for wildcard pattern + if strings.HasSuffix(pattern, "*") { + if strings.HasPrefix(name, pattern[:len(pattern)-1]) { + return fmt.Errorf("environment variable %q is not allowed", name) + } + } + } + return nil +} + +// parseAndValidateEnvironment converts a slice of strings in the format of KEY=VALUE +// to a map[string]string. An error is returned if the format is incorrect or if either +// the variable name or value are invalid. 
+func parseAndValidateEnvironment(input []string) (map[string]string, error) { + env := map[string]string{} + for _, input := range input { + parts := strings.SplitN(input, "=", 2) + if len(parts) != 2 || parts[0] == "" { + return nil, fmt.Errorf("environment variable %q is not in the KEY=VALUE format", input) + } + if err := validateEnvName(parts[0]); err != nil { + return nil, err + } + if _, ok := env[parts[0]]; ok { + return nil, fmt.Errorf("environment variable %q is already defined", parts[0]) + } + env[parts[0]] = parts[1] + } + return env, nil +} diff --git a/pkg/config/env_variables_test.go b/pkg/config/env_variables_test.go new file mode 100644 index 0000000000..8450891072 --- /dev/null +++ b/pkg/config/env_variables_test.go @@ -0,0 +1,145 @@ +package config + +import ( + "fmt" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestEnvironmentConfig(t *testing.T) { + t.Run("ParsingValidInput", func(t *testing.T) { + cases := []struct { + Name string + Input []string + Expected map[string]string + }{ + { + Name: "ValidInput", + Input: []string{"NAME=VALUE"}, + Expected: map[string]string{"NAME": "VALUE"}, + }, + { + Name: "ValidInputWithSpaces", + Input: []string{"NAME=VALUE WITH SPACES"}, + Expected: map[string]string{"NAME": "VALUE WITH SPACES"}, + }, + { + Name: "ValidInputWithQuotes", + Input: []string{"NAME=\"VALUE WITH QUOTES\""}, + Expected: map[string]string{"NAME": `"VALUE WITH QUOTES"`}, + }, + { + Name: "DelimitedValue", + Input: []string{"NAME=VALUE1,VALUE2"}, + Expected: map[string]string{"NAME": "VALUE1,VALUE2"}, + }, + { + Name: "EmptyValue", + Input: []string{"NAME="}, + Expected: map[string]string{"NAME": ""}, + }, + { + Name: "EmptyValueWithSpaces", + Input: []string{"NAME= "}, + Expected: map[string]string{"NAME": " "}, + }, + { + Name: "LowerCaseName", + Input: []string{"name=VALUE"}, + Expected: map[string]string{"name": "VALUE"}, + }, + { + Name: "MixedCaseName", + Input: []string{"MiXeD_Case=VALUE"}, + 
Expected: map[string]string{"MiXeD_Case": "VALUE"}, + }, + { + Name: "EqualSignInValue", + Input: []string{"NAME=VALUE=EQUAL"}, + Expected: map[string]string{"NAME": "VALUE=EQUAL"}, + }, + { + Name: "EqualSignInValueWithSpaces", + Input: []string{"NAME=VALUE=EQUAL WITH SPACES"}, + Expected: map[string]string{"NAME": "VALUE=EQUAL WITH SPACES"}, + }, + { + Name: "MultiLineValue", + Input: []string{"NAME=VALUE1\nVALUE2"}, + Expected: map[string]string{"NAME": "VALUE1\nVALUE2"}, + }, + { + Name: "MultiplePairs", + Input: []string{"NAME1=VALUE1", "NAME2=VALUE2"}, + Expected: map[string]string{"NAME1": "VALUE1", "NAME2": "VALUE2"}, + }, + } + + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + parsed, err := parseAndValidateEnvironment(c.Input) + require.NoError(t, err) + require.Equal(t, c.Expected, parsed) + }) + } + }) + + t.Run("ParsingInvalidInput", func(t *testing.T) { + cases := []struct { + Name string + Input []string + ExpectedErrorMessage string + }{ + { + Name: "NameWithoutValue", + Input: []string{"NAME"}, + ExpectedErrorMessage: `environment variable "NAME" is not in the KEY=VALUE format`, + }, + { + Name: "EmptyName", + Input: []string{"=VALUE"}, + ExpectedErrorMessage: `environment variable "=VALUE" is not in the KEY=VALUE format`, + }, + } + + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + _, err := parseAndValidateEnvironment(c.Input) + require.Error(t, err) + require.ErrorContains(t, err, c.ExpectedErrorMessage) + }) + } + }) + + t.Run("EnforceDenyList", func(t *testing.T) { + for _, pattern := range EnvironmentVariableDenyList { + // test that exact matches are rejected + t.Run(fmt.Sprintf("Rejects %q", pattern), func(t *testing.T) { + input := fmt.Sprintf("%s=VALUE", pattern) + _, err := parseAndValidateEnvironment([]string{input}) + require.Error(t, err) + require.ErrorContains(t, err, fmt.Sprintf("environment variable %q is not allowed", pattern)) + }) + + // test that prefix matches are rejected + if 
strings.HasSuffix(pattern, "*") { + t.Run(fmt.Sprintf("Rejects %q prefix", pattern), func(t *testing.T) { + name := strings.TrimSuffix(pattern, "*") + "SUFFIX" + input := fmt.Sprintf("%s=VALUE", name) + _, err := parseAndValidateEnvironment([]string{input}) + require.Error(t, err) + require.ErrorContains(t, err, fmt.Sprintf("environment variable %q is not allowed", name)) + }) + } + } + }) + + t.Run("DuplicateNamesAreRejected", func(t *testing.T) { + input := []string{"NAME=VALUE", "NAME=VALUE2"} + _, err := parseAndValidateEnvironment(input) + require.Error(t, err) + require.ErrorContains(t, err, "environment variable \"NAME\" is already defined") + }) +} diff --git a/pkg/dockerfile/env.go b/pkg/dockerfile/env.go new file mode 100644 index 0000000000..bb00a07ef2 --- /dev/null +++ b/pkg/dockerfile/env.go @@ -0,0 +1,23 @@ +package dockerfile + +import ( + "maps" + "slices" + + "github.com/replicate/cog/pkg/config" +) + +func envLineFromConfig(c *config.Config) (string, error) { + vars := c.ParsedEnvironment() + if len(vars) == 0 { + return "", nil + } + + out := "ENV" + for _, name := range slices.Sorted(maps.Keys(vars)) { + out = out + " " + name + "=" + vars[name] + } + out += "\n" + + return out, nil +} diff --git a/pkg/dockerfile/fast_generator.go b/pkg/dockerfile/fast_generator.go index 4ad0d5607d..6ad7939276 100644 --- a/pkg/dockerfile/fast_generator.go +++ b/pkg/dockerfile/fast_generator.go @@ -431,6 +431,14 @@ func (g *FastGenerator) installSrc(lines []string, weights []weights.Weight) ([] } func (g *FastGenerator) entrypoint(lines []string) ([]string, error) { + line, err := envLineFromConfig(g.Config) + if err != nil { + return nil, err + } + if line != "" { + lines = append(lines, line) + } + return append(lines, []string{ "WORKDIR /src", "ENV VERBOSE=0", diff --git a/pkg/dockerfile/standard_generator.go b/pkg/dockerfile/standard_generator.go index 22b6b7c33a..fdd5d2c598 100644 --- a/pkg/dockerfile/standard_generator.go +++ 
@pytest.fixture
def cog_binary(pytestconfig: Config) -> Path:
    """Locate the cog binary the integration tests should run.

    Lookup order:
      1. The COG_BINARY environment variable, if set and non-empty. A relative
         value is interpreted relative to the pytest rootdir; the result is
         resolved to an absolute path.
      2. A file named ``cog`` in the pytest rootdir (where the integration
         tests dump the test build).
      3. ``cog`` found on PATH.

    Raises:
        FileNotFoundError: if none of the above yields a binary.
    """
    root = Path(pytestconfig.rootdir)

    override = os.environ.get("COG_BINARY")
    if override:
        candidate = Path(override)
        # Absolute overrides are used as given; only relative ones are
        # anchored at the rootdir.
        if not candidate.is_absolute():
            candidate = root / candidate
        return candidate.resolve()

    local_build = root / "cog"
    if local_build.exists():
        return local_build

    on_path = shutil.which("cog")
    if on_path:
        return Path(on_path)

    raise FileNotFoundError("Could not find cog binary")
def test_predict_env_vars(docker_image, cog_binary):
    """Build the env-project fixture and check a configured variable at predict time.

    The fixture's cog.yaml declares TEST_VAR=test_value under ``environment``,
    and its predictor returns ``ENV[<name>]=<os.getenv(name)>`` for the
    requested name, so the prediction output shows whether the variable was
    visible to the predictor.
    """
    project_dir = Path(__file__).parent / "fixtures/env-project"

    build_cmd = [cog_binary, "build", "-t", docker_image]
    build_process = subprocess.run(build_cmd, cwd=project_dir, capture_output=True)
    assert build_process.returncode == 0

    predict_cmd = [cog_binary, "predict", "--debug", docker_image, "-i", "name=TEST_VAR"]
    result = subprocess.run(
        predict_cmd,
        cwd=project_dir,
        check=True,
        capture_output=True,
        text=True,
        timeout=DEFAULT_TIMEOUT,
    )
    assert result.returncode == 0
    assert result.stdout == "ENV[TEST_VAR]=test_value\n"