diff --git a/pkg/cli/build.go b/pkg/cli/build.go index 2748ffd512..96b98cc6c3 100644 --- a/pkg/cli/build.go +++ b/pkg/cli/build.go @@ -29,6 +29,7 @@ var buildStrip bool var buildPrecompile bool var buildFast bool var buildLocalImage bool +var configFilename string const useCogBaseImageFlagKey = "use-cog-base-image" @@ -53,6 +54,7 @@ func newBuildCommand() *cobra.Command { addPrecompileFlag(cmd) addFastFlag(cmd) addLocalImage(cmd) + addConfigFlag(cmd) cmd.Flags().StringVarP(&buildTag, "tag", "t", "", "A name for the built image in the form 'repository:tag'") return cmd } @@ -68,7 +70,7 @@ func buildCommand(cmd *cobra.Command, args []string) error { logClient := coglog.NewClient(client) logCtx := logClient.StartBuild(buildFast, buildLocalImage) - cfg, projectDir, err := config.GetConfig() + cfg, projectDir, err := config.GetConfig(configFilename) if err != nil { logClient.EndBuild(ctx, err, logCtx) return err @@ -172,6 +174,11 @@ func addLocalImage(cmd *cobra.Command) { _ = cmd.Flags().MarkHidden(localImage) } +func addConfigFlag(cmd *cobra.Command) { + const configFlag = "f" + cmd.Flags().StringVar(&configFilename, configFlag, "cog.yaml", "The name of the config file.") +} + func checkMutuallyExclusiveFlags(cmd *cobra.Command, args []string) error { flags := []string{useCogBaseImageFlagKey, "use-cuda-base-image", "dockerfile"} var flagsSet []string diff --git a/pkg/cli/debug.go b/pkg/cli/debug.go index 3b5eccc1e4..0be4d36f7b 100644 --- a/pkg/cli/debug.go +++ b/pkg/cli/debug.go @@ -8,7 +8,6 @@ import ( "github.com/replicate/cog/pkg/config" "github.com/replicate/cog/pkg/docker" "github.com/replicate/cog/pkg/dockerfile" - "github.com/replicate/cog/pkg/global" "github.com/replicate/cog/pkg/util/console" ) @@ -18,7 +17,7 @@ func newDebugCommand() *cobra.Command { cmd := &cobra.Command{ Use: "debug", Hidden: true, - Short: "Generate a Dockerfile from " + global.ConfigFilename, + Short: "Generate a Dockerfile from cog", RunE: cmdDockerfile, } @@ -29,6 +28,7 @@ func newDebugCommand() *cobra.Command { addBuildTimestampFlag(cmd) addFastFlag(cmd) addLocalImage(cmd) + addConfigFlag(cmd) cmd.Flags().StringVarP(&imageName, "image-name", "", "", "The image name to use for the generated Dockerfile") return cmd @@ -37,7 +37,7 @@ func newDebugCommand() *cobra.Command { func cmdDockerfile(cmd *cobra.Command, args []string) error { ctx := cmd.Context() - cfg, projectDir, err := config.GetConfig() + cfg, projectDir, err := config.GetConfig(configFilename) if err != nil { return err } diff --git a/pkg/cli/migrate.go b/pkg/cli/migrate.go index ac794e3abf..b517ae7e81 100644 --- a/pkg/cli/migrate.go +++ b/pkg/cli/migrate.go @@ -21,6 +21,7 @@ This will attempt to migrate your cog project to be compatible with fast boots.` } addYesFlag(cmd) + addConfigFlag(cmd) return cmd } @@ -31,7 +32,7 @@ func cmdMigrate(cmd *cobra.Command, args []string) error { if err != nil { return err } - err = migrator.Migrate(ctx) + err = migrator.Migrate(ctx, configFilename) if err != nil { return err } diff --git a/pkg/cli/predict.go b/pkg/cli/predict.go index 2018c3bc4d..54a5dc1d63 100644 --- a/pkg/cli/predict.go +++ b/pkg/cli/predict.go @@ -58,6 +58,7 @@ the prediction on that.`, addSetupTimeoutFlag(cmd) addFastFlag(cmd) addLocalImage(cmd) + addConfigFlag(cmd) cmd.Flags().StringArrayVarP(&inputFlags, "input", "i", []string{}, "Inputs, in the form name=value. if value is prefixed with @, then it is read from a file on disk. E.g. -i path=@image.jpg") cmd.Flags().StringVarP(&outPath, "output", "o", "", "Output path") @@ -78,7 +79,7 @@ func cmdPredict(cmd *cobra.Command, args []string) error { if len(args) == 0 { // Build image - cfg, projectDir, err := config.GetConfig() + cfg, projectDir, err := config.GetConfig(configFilename) if err != nil { return err } diff --git a/pkg/cli/push.go b/pkg/cli/push.go index 651879e3da..bfb619abb6 100644 --- a/pkg/cli/push.go +++ b/pkg/cli/push.go @@ -39,6 +39,7 @@ func newPushCommand() *cobra.Command { addPrecompileFlag(cmd) addFastFlag(cmd) addLocalImage(cmd) + addConfigFlag(cmd) return cmd } @@ -54,7 +55,7 @@ func push(cmd *cobra.Command, args []string) error { logClient := coglog.NewClient(client) logCtx := logClient.StartPush(buildFast, buildLocalImage) - cfg, projectDir, err := config.GetConfig() + cfg, projectDir, err := config.GetConfig(configFilename) if err != nil { logClient.EndPush(ctx, err, logCtx) return err diff --git a/pkg/cli/run.go b/pkg/cli/run.go index 9cce45f0b3..9d638f8413 100644 --- a/pkg/cli/run.go +++ b/pkg/cli/run.go @@ -38,6 +38,7 @@ func newRunCommand() *cobra.Command { addGpusFlag(cmd) addFastFlag(cmd) addLocalImage(cmd) + addConfigFlag(cmd) flags := cmd.Flags() // Flags after first argument are considered args and passed to command @@ -54,7 +55,7 @@ func newRunCommand() *cobra.Command { func run(cmd *cobra.Command, args []string) error { ctx := cmd.Context() - cfg, projectDir, err := config.GetConfig() + cfg, projectDir, err := config.GetConfig(configFilename) if err != nil { return err } diff --git a/pkg/cli/serve.go b/pkg/cli/serve.go index 9857071de0..a525b5f2cd 100644 --- a/pkg/cli/serve.go +++ b/pkg/cli/serve.go @@ -34,6 +34,7 @@ Generate and run an HTTP server based on the declared model inputs and outputs.` addUseCogBaseImageFlag(cmd) addGpusFlag(cmd) addFastFlag(cmd) + addConfigFlag(cmd) cmd.Flags().IntVarP(&port, "port", "p", port, "Port on which to listen") @@ -43,7 +44,7 @@ Generate and run an HTTP server based on the declared model inputs and outputs.` func cmdServe(cmd *cobra.Command, arg []string) error { ctx := cmd.Context() - cfg, projectDir, err := config.GetConfig() + cfg, projectDir, err := config.GetConfig(configFilename) if err != nil { return err } diff --git a/pkg/cli/train.go b/pkg/cli/train.go index 7b239dd267..743b1e59b0 100644 --- a/pkg/cli/train.go +++ b/pkg/cli/train.go @@ -44,6 +44,7 @@ Otherwise, it will build the model in the current directory and train it.`, addGpusFlag(cmd) addUseCogBaseImageFlag(cmd) addFastFlag(cmd) + addConfigFlag(cmd) cmd.Flags().StringArrayVarP(&trainInputFlags, "input", "i", []string{}, "Inputs, in the form name=value. if value is prefixed with @, then it is read from a file on disk. E.g. -i path=@image.jpg") cmd.Flags().StringArrayVarP(&trainEnvFlags, "env", "e", []string{}, "Environment variables, in the form name=value") @@ -61,7 +62,7 @@ func cmdTrain(cmd *cobra.Command, args []string) error { volumes := []docker.Volume{} gpus := gpusFlag - cfg, projectDir, err := config.GetConfig() + cfg, projectDir, err := config.GetConfig(configFilename) if err != nil { return err } diff --git a/pkg/config/load.go b/pkg/config/load.go index c7ba9a5937..3673e051bd 100644 --- a/pkg/config/load.go +++ b/pkg/config/load.go @@ -7,30 +7,30 @@ import ( "path/filepath" "github.com/replicate/cog/pkg/errors" - "github.com/replicate/cog/pkg/global" "github.com/replicate/cog/pkg/util/files" ) const maxSearchDepth = 100 // Returns the project's root directory, or the directory specified by the --project-dir flag -func GetProjectDir() (string, error) { +func GetProjectDir(configFilename string) (string, error) { cwd, err := os.Getwd() if err != nil { return "", err } - return findProjectRootDir(cwd) + return findProjectRootDir(cwd, configFilename) } // Loads and instantiates a Config object // customDir can be specified to override the default - current working directory -func GetConfig() (*Config, string, error) { +func GetConfig(configFilename string) (*Config, string, error) { // Find the root project directory - rootDir, err := GetProjectDir() + rootDir, err := GetProjectDir(configFilename) + if err != nil { return nil, "", err } - configPath := path.Join(rootDir, global.ConfigFilename) + configPath := path.Join(rootDir, configFilename) // Then try to load the config file from there config, err := loadConfigFromFile(configPath) @@ -51,7 +51,7 @@ func loadConfigFromFile(file string) (*Config, error) { } if !exists { - return nil, fmt.Errorf("%s does not exist in %s. Are you in the right directory?", global.ConfigFilename, filepath.Dir(file)) + return nil, fmt.Errorf("%s does not exist in %s. Are you in the right directory?", filepath.Base(file), filepath.Dir(file)) } contents, err := os.ReadFile(file) @@ -69,8 +69,8 @@ func loadConfigFromFile(file string) (*Config, error) { } // Given a directory, find the cog config file in that directory -func findConfigPathInDirectory(dir string) (configPath string, err error) { - filePath := path.Join(dir, global.ConfigFilename) +func findConfigPathInDirectory(dir string, configFilename string) (configPath string, err error) { + filePath := path.Join(dir, configFilename) exists, err := files.Exists(filePath) if err != nil { return "", fmt.Errorf("Failed to scan directory %s for %s: %s", dir, filePath, err) @@ -78,21 +78,21 @@ func findConfigPathInDirectory(dir string) (configPath string, err error) { return filePath, nil } - return "", errors.ConfigNotFound(fmt.Sprintf("%s not found in %s", global.ConfigFilename, dir)) + return "", errors.ConfigNotFound(fmt.Sprintf("%s not found in %s", configFilename, dir)) } // Walk up the directory tree to find the root of the project. // The project root is defined as the directory housing a `cog.yaml` file. -func findProjectRootDir(startDir string) (string, error) { +func findProjectRootDir(startDir string, configFilename string) (string, error) { dir := startDir for i := 0; i < maxSearchDepth; i++ { - switch _, err := findConfigPathInDirectory(dir); { + switch _, err := findConfigPathInDirectory(dir, configFilename); { case err != nil && !errors.IsConfigNotFound(err): return "", err case err == nil: return dir, nil case dir == "." || dir == "/": - return "", errors.ConfigNotFound(fmt.Sprintf("%s not found in %s (or in any parent directories)", global.ConfigFilename, startDir)) + return "", errors.ConfigNotFound(fmt.Sprintf("%s not found in %s (or in any parent directories)", configFilename, startDir)) } dir = filepath.Dir(dir) diff --git a/pkg/config/load_test.go b/pkg/config/load_test.go index 78d15904b6..7445a373f8 100644 --- a/pkg/config/load_test.go +++ b/pkg/config/load_test.go @@ -28,7 +28,7 @@ func TestFindProjectRootDirShouldFindParentDir(t *testing.T) { err = os.MkdirAll(subdir, 0o700) require.NoError(t, err) - foundDir, err := findProjectRootDir(subdir) + foundDir, err := findProjectRootDir(subdir, "cog.yaml") require.NoError(t, err) require.Equal(t, foundDir, projectDir) } @@ -40,6 +40,6 @@ func TestFindProjectRootDirShouldReturnErrIfNoConfig(t *testing.T) { err := os.MkdirAll(subdir, 0o700) require.NoError(t, err) - _, err = findProjectRootDir(subdir) + _, err = findProjectRootDir(subdir, "cog.yaml") require.Error(t, err) } diff --git a/pkg/global/global.go b/pkg/global/global.go index 2b90bb930f..0813de445d 100644 --- a/pkg/global/global.go +++ b/pkg/global/global.go @@ -6,7 +6,6 @@ var ( BuildTime = "none" Debug = false ProfilingEnabled = false - ConfigFilename = "cog.yaml" ReplicateRegistryHost = "r8.im" ReplicateWebsiteHost = "replicate.com" LabelNamespace = "run.cog." diff --git a/pkg/migrate/migrator.go b/pkg/migrate/migrator.go index c1faa96629..e37121f2a4 100644 --- a/pkg/migrate/migrator.go +++ b/pkg/migrate/migrator.go @@ -3,5 +3,5 @@ package migrate import "context" type Migrator interface { - Migrate(ctx context.Context) error + Migrate(ctx context.Context, configFilename string) error } diff --git a/pkg/migrate/migrator_v1_v1fast.go b/pkg/migrate/migrator_v1_v1fast.go index e348a7549a..94a524a64a 100644 --- a/pkg/migrate/migrator_v1_v1fast.go +++ b/pkg/migrate/migrator_v1_v1fast.go @@ -16,7 +16,6 @@ import ( "github.com/replicate/cog/pkg/config" "github.com/replicate/cog/pkg/dockerfile" - "github.com/replicate/cog/pkg/global" "github.com/replicate/cog/pkg/requirements" "github.com/replicate/cog/pkg/util" "github.com/replicate/cog/pkg/util/console" @@ -40,8 +39,8 @@ func NewMigratorV1ToV1Fast(interactive bool) *MigratorV1ToV1Fast { } } -func (g *MigratorV1ToV1Fast) Migrate(ctx context.Context) error { - cfg, projectDir, err := config.GetConfig() +func (g *MigratorV1ToV1Fast) Migrate(ctx context.Context, configFilename string) error { + cfg, projectDir, err := config.GetConfig(configFilename) if err != nil { return err } @@ -57,7 +56,7 @@ func (g *MigratorV1ToV1Fast) Migrate(ctx context.Context) error { if err != nil { return err } - err = g.flushConfig(cfg, projectDir) + err = g.flushConfig(cfg, projectDir, configFilename) return err } @@ -167,7 +166,7 @@ func (g *MigratorV1ToV1Fast) checkPythonCode(ctx context.Context, cfg *config.Co return nil } -func (g *MigratorV1ToV1Fast) flushConfig(cfg *config.Config, dir string) error { +func (g *MigratorV1ToV1Fast) flushConfig(cfg *config.Config, dir string, configFilename string) error { if cfg.Build == nil { cfg.Build = config.DefaultConfig().Build } @@ -182,7 +181,7 @@ func (g *MigratorV1ToV1Fast) flushConfig(cfg *config.Config, dir string) error { } configStr := string(data) - configFilepath := filepath.Join(dir, global.ConfigFilename) + configFilepath := filepath.Join(dir, configFilename) file, err := os.Open(configFilepath) if err != nil { return err diff --git a/pkg/migrate/migrator_v1_v1fast_test.go b/pkg/migrate/migrator_v1_v1fast_test.go index a3f8bfb570..cd85623724 100644 --- a/pkg/migrate/migrator_v1_v1fast_test.go +++ b/pkg/migrate/migrator_v1_v1fast_test.go @@ -8,7 +8,6 @@ import ( "github.com/stretchr/testify/require" - "github.com/replicate/cog/pkg/global" "github.com/replicate/cog/pkg/requirements" ) @@ -25,7 +24,7 @@ func TestMigrate(t *testing.T) { require.NoError(t, err) // Write our test configs/code - configFilepath := filepath.Join(dir, global.ConfigFilename) + configFilepath := filepath.Join(dir, "cog.yaml") file, err := os.Create(configFilepath) require.NoError(t, err) _, err = file.WriteString(`build: @@ -56,7 +55,7 @@ class Predictor(BasePredictor): // Perform the migration migrator := NewMigratorV1ToV1Fast(false) - err = migrator.Migrate(t.Context()) + err = migrator.Migrate(t.Context(), "cog.yaml") require.NoError(t, err) // Check config output