Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ require (
github.com/vincent-petithory/dataurl v1.0.0
github.com/xeipuuv/gojsonschema v1.2.0
github.com/xeonx/timeago v1.0.0-rc5
golang.org/x/crypto v0.37.0
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0
golang.org/x/sync v0.13.0
golang.org/x/sys v0.32.0
Expand Down Expand Up @@ -271,6 +270,7 @@ require (
go.uber.org/automaxprocs v1.6.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
go.uber.org/zap v1.27.0 // indirect
golang.org/x/crypto v0.37.0 // indirect
golang.org/x/exp/typeparams v0.0.0-20250210185358-939b2ce775ac // indirect
golang.org/x/mod v0.24.0 // indirect
golang.org/x/net v0.39.0 // indirect
Expand All @@ -280,7 +280,7 @@ require (
google.golang.org/grpc v1.71.0 // indirect
google.golang.org/protobuf v1.36.6 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
gotest.tools/gotestsum v1.12.1 // indirect
gotest.tools/gotestsum v1.12.2 // indirect
honnef.co/go/tools v0.6.1 // indirect
mvdan.cc/gofumpt v0.7.0 // indirect
mvdan.cc/unparam v0.0.0-20240528143540-8a5130ca722f // indirect
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -764,8 +764,8 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gotest.tools/gotestsum v1.12.1 h1:dvcxFBTFR1QsQmrCQa4k/vDXow9altdYz4CjdW+XeBE=
gotest.tools/gotestsum v1.12.1/go.mod h1:mwDmLbx9DIvr09dnAoGgQPLaSXszNpXpWo2bsQge5BE=
gotest.tools/gotestsum v1.12.2 h1:eli4tu9Q2D/ogDsEGSr8XfQfl7mT0JsGOG6DFtUiZ/Q=
gotest.tools/gotestsum v1.12.2/go.mod h1:kjRtCglPZVsSU0hFHX3M5VWBM6Y63emHuB14ER1/sow=
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
honnef.co/go/tools v0.6.1 h1:R094WgE8K4JirYjBaOpz/AvTyUu/3wbmAoskKN/pxTI=
Expand Down
21 changes: 21 additions & 0 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ type Config struct {
Predict string `json:"predict,omitempty" yaml:"predict"`
Train string `json:"train,omitempty" yaml:"train,omitempty"`
Concurrency *Concurrency `json:"concurrency,omitempty" yaml:"concurrency,omitempty"`
Environment []string `json:"environment,omitempty" yaml:"environment,omitempty"`

parsedEnvironment map[string]string
}

func DefaultConfig() *Config {
Expand Down Expand Up @@ -319,6 +322,11 @@ func (c *Config) ValidateAndComplete(projectDir string) error {
}
}

// parse and validate environment variables
if err := c.loadEnvironment(); err != nil {
errs = append(errs, err)
}

if len(errs) > 0 {
return errors.Join(errs...)
}
Expand Down Expand Up @@ -577,3 +585,16 @@ func sliceContains(slice []string, s string) bool {
}
return false
}

func (c *Config) ParsedEnvironment() map[string]string {
return c.parsedEnvironment
}

func (c *Config) loadEnvironment() error {
env, err := parseAndValidateEnvironment(c.Environment)
if err != nil {
return err
}
c.parsedEnvironment = env
return nil
}
14 changes: 14 additions & 0 deletions pkg/config/data/config_schema_v1.0.json
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,20 @@
"type": "object",
"additionalItems": true
}
},
"environment": {
"$id": "#/properties/properties/environment",
"type": [
"array",
"null"
],
"description": "A list of environment variables to make available during builds and at runtime, in the format `NAME=value`",
"additionalItems": true,
"items": {
"$id": "#/properties/properties/environment/items",
"type": "string",
"pattern": "^[A-Za-z_][A-Za-z0-9_]*=[^\\s]+$"
}
}
},
"additionalProperties": false
Expand Down
76 changes: 76 additions & 0 deletions pkg/config/env_variables.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package config

import (
"fmt"
"strings"
)

// EnvironmentVariableDenyList is a list of environment variable patterns that are
// used internally during build or runtime and thus not allowed to be set by the user.
// There are ways around this restriction, but it's likely to cause unexpected behavior
// and hard to debug issues. So on Cog's predict-build-push happy path, we don't allow
// these to be set.
// This list may change at any time. For more context, see:
// https://github.com/replicate/cog/pull/2274/#issuecomment-2831823185
var EnvironmentVariableDenyList = []string{
// paths
"PATH",
"LD_LIBRARY_PATH",
"PYTHONPATH",
"VIRTUAL_ENV",
"PYTHONUNBUFFERED",
// Replicate
"R8_*",
"REPLICATE_*",
// Nvidia
"LIBRARY_PATH",
"CUDA_*",
"NVIDIA_*",
"NV_*",
// pget
"PGET_*",
"HF_ENDPOINT",
"HF_HUB_ENABLE_HF_TRANSFER",
// k8s
"KUBERNETES_*",
}

// validateEnvName checks if the given environment variable name is allowed.
// Returns an error if the name matches any of the restricted patterns.
func validateEnvName(name string) error {
for _, pattern := range EnvironmentVariableDenyList {
// Check for exact match
if pattern == name {
return fmt.Errorf("environment variable %q is not allowed", name)
}

// Check for wildcard pattern
if strings.HasSuffix(pattern, "*") {
if strings.HasPrefix(name, pattern[:len(pattern)-1]) {
return fmt.Errorf("environment variable %q is not allowed", name)
}
}
}
return nil
}

// parseAndValidateEnvironment converts a slice of strings in the format of KEY=VALUE
// to a map[string]string. An error is returned if the format is incorrect or if either
// the variable name or value are invalid.
func parseAndValidateEnvironment(input []string) (map[string]string, error) {
env := map[string]string{}
for _, input := range input {
parts := strings.SplitN(input, "=", 2)
if len(parts) != 2 || parts[0] == "" {
return nil, fmt.Errorf("environment variable %q is not in the KEY=VALUE format", input)
}
if err := validateEnvName(parts[0]); err != nil {
return nil, err
}
if _, ok := env[parts[0]]; ok {
return nil, fmt.Errorf("environment variable %q is already defined", parts[0])
}
env[parts[0]] = parts[1]
}
return env, nil
}
145 changes: 145 additions & 0 deletions pkg/config/env_variables_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
package config

import (
"fmt"
"strings"
"testing"

"github.com/stretchr/testify/require"
)

func TestEnvironmentConfig(t *testing.T) {
t.Run("ParsingValidInput", func(t *testing.T) {
cases := []struct {
Name string
Input []string
Expected map[string]string
}{
{
Name: "ValidInput",
Input: []string{"NAME=VALUE"},
Expected: map[string]string{"NAME": "VALUE"},
},
{
Name: "ValidInputWithSpaces",
Input: []string{"NAME=VALUE WITH SPACES"},
Expected: map[string]string{"NAME": "VALUE WITH SPACES"},
},
{
Name: "ValidInputWithQuotes",
Input: []string{"NAME=\"VALUE WITH QUOTES\""},
Expected: map[string]string{"NAME": `"VALUE WITH QUOTES"`},
},
{
Name: "DelimitedValue",
Input: []string{"NAME=VALUE1,VALUE2"},
Expected: map[string]string{"NAME": "VALUE1,VALUE2"},
},
{
Name: "EmptyValue",
Input: []string{"NAME="},
Expected: map[string]string{"NAME": ""},
},
{
Name: "EmptyValueWithSpaces",
Input: []string{"NAME= "},
Expected: map[string]string{"NAME": " "},
},
{
Name: "LowerCaseName",
Input: []string{"name=VALUE"},
Expected: map[string]string{"name": "VALUE"},
},
{
Name: "MixedCaseName",
Input: []string{"MiXeD_Case=VALUE"},
Expected: map[string]string{"MiXeD_Case": "VALUE"},
},
{
Name: "EqualSignInValue",
Input: []string{"NAME=VALUE=EQUAL"},
Expected: map[string]string{"NAME": "VALUE=EQUAL"},
},
{
Name: "EqualSignInValueWithSpaces",
Input: []string{"NAME=VALUE=EQUAL WITH SPACES"},
Expected: map[string]string{"NAME": "VALUE=EQUAL WITH SPACES"},
},
{
Name: "MultiLineValue",
Input: []string{"NAME=VALUE1\nVALUE2"},
Expected: map[string]string{"NAME": "VALUE1\nVALUE2"},
},
{
Name: "MultiplePairs",
Input: []string{"NAME1=VALUE1", "NAME2=VALUE2"},
Expected: map[string]string{"NAME1": "VALUE1", "NAME2": "VALUE2"},
},
}

for _, c := range cases {
t.Run(c.Name, func(t *testing.T) {
parsed, err := parseAndValidateEnvironment(c.Input)
require.NoError(t, err)
require.Equal(t, c.Expected, parsed)
})
}
})

t.Run("ParsingInvalidInput", func(t *testing.T) {
cases := []struct {
Name string
Input []string
ExpectedErrorMessage string
}{
{
Name: "NameWithoutValue",
Input: []string{"NAME"},
ExpectedErrorMessage: `environment variable "NAME" is not in the KEY=VALUE format`,
},
{
Name: "EmptyName",
Input: []string{"=VALUE"},
ExpectedErrorMessage: `environment variable "=VALUE" is not in the KEY=VALUE format`,
},
}

for _, c := range cases {
t.Run(c.Name, func(t *testing.T) {
_, err := parseAndValidateEnvironment(c.Input)
require.Error(t, err)
require.ErrorContains(t, err, c.ExpectedErrorMessage)
})
}
})

t.Run("EnforceDenyList", func(t *testing.T) {
for _, pattern := range EnvironmentVariableDenyList {
// test that exact matches are rejected
t.Run(fmt.Sprintf("Rejects %q", pattern), func(t *testing.T) {
input := fmt.Sprintf("%s=VALUE", pattern)
_, err := parseAndValidateEnvironment([]string{input})
require.Error(t, err)
require.ErrorContains(t, err, fmt.Sprintf("environment variable %q is not allowed", pattern))
})

// test that prefix matches are rejected
if strings.HasSuffix(pattern, "*") {
t.Run(fmt.Sprintf("Rejects %q prefix", pattern), func(t *testing.T) {
name := strings.TrimSuffix(pattern, "*") + "SUFFIX"
input := fmt.Sprintf("%s=VALUE", name)
_, err := parseAndValidateEnvironment([]string{input})
require.Error(t, err)
require.ErrorContains(t, err, fmt.Sprintf("environment variable %q is not allowed", name))
})
}
}
})

t.Run("DuplicateNamesAreRejected", func(t *testing.T) {
input := []string{"NAME=VALUE", "NAME=VALUE2"}
_, err := parseAndValidateEnvironment(input)
require.Error(t, err)
require.ErrorContains(t, err, "environment variable \"NAME\" is already defined")
})
}
23 changes: 23 additions & 0 deletions pkg/dockerfile/env.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package dockerfile

import (
"maps"
"slices"

"github.com/replicate/cog/pkg/config"
)

func envLineFromConfig(c *config.Config) (string, error) {
vars := c.ParsedEnvironment()
if len(vars) == 0 {
return "", nil
}

out := "ENV"
for _, name := range slices.Sorted(maps.Keys(vars)) {
out = out + " " + name + "=" + vars[name]
}
out += "\n"
Comment on lines +16 to +20
Copy link

Copilot AI May 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Consider using a strings.Builder to construct the ENV line more efficiently for better readability and potential performance improvements.

Suggested change
out := "ENV"
for _, name := range slices.Sorted(maps.Keys(vars)) {
out = out + " " + name + "=" + vars[name]
}
out += "\n"
var out strings.Builder
out.WriteString("ENV")
for _, name := range slices.Sorted(maps.Keys(vars)) {
out.WriteString(" ")
out.WriteString(name)
out.WriteString("=")
out.WriteString(vars[name])
}
out.WriteString("\n")

Copilot uses AI. Check for mistakes.

return out, nil
}
8 changes: 8 additions & 0 deletions pkg/dockerfile/fast_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,14 @@ func (g *FastGenerator) installSrc(lines []string, weights []weights.Weight) ([]
}

func (g *FastGenerator) entrypoint(lines []string) ([]string, error) {
line, err := envLineFromConfig(g.Config)
if err != nil {
return nil, err
}
if line != "" {
lines = append(lines, line)
}

return append(lines, []string{
"WORKDIR /src",
"ENV VERBOSE=0",
Expand Down
10 changes: 10 additions & 0 deletions pkg/dockerfile/standard_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,10 @@ func (g *StandardGenerator) GenerateInitialSteps(ctx context.Context) (string, e
if err != nil {
return "", err
}
envs, err := g.envVars()
if err != nil {
return "", err
}
runCommands, err := g.runCommands()
if err != nil {
return "", err
Expand All @@ -160,6 +164,7 @@ func (g *StandardGenerator) GenerateInitialSteps(ctx context.Context) (string, e
steps := []string{
"#syntax=docker/dockerfile:1.4",
"FROM " + baseImage,
envs,
aptInstalls,
installCog,
pipInstalls,
Expand All @@ -177,6 +182,7 @@ func (g *StandardGenerator) GenerateInitialSteps(ctx context.Context) (string, e
"FROM " + baseImage,
g.preamble(),
g.installTini(),
envs,
aptInstalls,
installPython,
pipInstalls,
Expand Down Expand Up @@ -505,6 +511,10 @@ This is the offending line: %s`, command)
return strings.Join(lines, "\n"), nil
}

func (g *StandardGenerator) envVars() (string, error) {
return envLineFromConfig(g.Config)
}

// writeTemp writes a temporary file that can be used as part of the build process
// It returns the lines to add to Dockerfile to make it available and the filename it ends up as inside the container
func (g *StandardGenerator) writeTemp(filename string, contents []byte) ([]string, string, error) {
Expand Down
Loading