diff --git a/pkg/cli/build.go b/pkg/cli/build.go index ea17bee44f..2748ffd512 100644 --- a/pkg/cli/build.go +++ b/pkg/cli/build.go @@ -68,7 +68,7 @@ func buildCommand(cmd *cobra.Command, args []string) error { logClient := coglog.NewClient(client) logCtx := logClient.StartBuild(buildFast, buildLocalImage) - cfg, projectDir, err := config.GetConfig(projectDirFlag) + cfg, projectDir, err := config.GetConfig() if err != nil { logClient.EndBuild(ctx, err, logCtx) return err diff --git a/pkg/cli/debug.go b/pkg/cli/debug.go index 76c96b7935..3b5eccc1e4 100644 --- a/pkg/cli/debug.go +++ b/pkg/cli/debug.go @@ -37,7 +37,7 @@ func newDebugCommand() *cobra.Command { func cmdDockerfile(cmd *cobra.Command, args []string) error { ctx := cmd.Context() - cfg, projectDir, err := config.GetConfig(projectDirFlag) + cfg, projectDir, err := config.GetConfig() if err != nil { return err } diff --git a/pkg/cli/migrate.go b/pkg/cli/migrate.go new file mode 100644 index 0000000000..ac794e3abf --- /dev/null +++ b/pkg/cli/migrate.go @@ -0,0 +1,45 @@ +package cli + +import ( + "github.com/spf13/cobra" + + "github.com/replicate/cog/pkg/migrate" +) + +var migrateAccept bool + +func newMigrateCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "migrate", + Short: "Run a migration", + Long: `Run a migration. + +This will attempt to migrate your cog project to be compatible with fast boots.`, + RunE: cmdMigrate, + Args: cobra.MaximumNArgs(0), + Hidden: true, + } + + addYesFlag(cmd) + + return cmd +} + +func cmdMigrate(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + migrator, err := migrate.NewMigrator(migrate.MigrationV1, migrate.MigrationV1Fast, !migrateAccept) + if err != nil { + return err + } + err = migrator.Migrate(ctx) + if err != nil { + return err + } + + return nil +} + +func addYesFlag(cmd *cobra.Command) { + const acceptFlag = "y" + cmd.Flags().BoolVar(&migrateAccept, acceptFlag, false, "Whether to disable interaction and automatically accept the changes.") +} diff --git a/pkg/cli/predict.go b/pkg/cli/predict.go index fc29a94d57..2018c3bc4d 100644 --- a/pkg/cli/predict.go +++ b/pkg/cli/predict.go @@ -78,7 +78,7 @@ func cmdPredict(cmd *cobra.Command, args []string) error { if len(args) == 0 { // Build image - cfg, projectDir, err := config.GetConfig(projectDirFlag) + cfg, projectDir, err := config.GetConfig() if err != nil { return err } diff --git a/pkg/cli/push.go b/pkg/cli/push.go index 3b8e07ac0a..651879e3da 100644 --- a/pkg/cli/push.go +++ b/pkg/cli/push.go @@ -54,7 +54,7 @@ func push(cmd *cobra.Command, args []string) error { logClient := coglog.NewClient(client) logCtx := logClient.StartPush(buildFast, buildLocalImage) - cfg, projectDir, err := config.GetConfig(projectDirFlag) + cfg, projectDir, err := config.GetConfig() if err != nil { logClient.EndPush(ctx, err, logCtx) return err diff --git a/pkg/cli/root.go b/pkg/cli/root.go index 1f55d39cde..471e3c217e 100644 --- a/pkg/cli/root.go +++ b/pkg/cli/root.go @@ -10,8 +10,6 @@ import ( "github.com/replicate/cog/pkg/util/console" ) -var projectDirFlag string - func NewRootCommand() (*cobra.Command, error) { rootCmd := cobra.Command{ Use: "cog", @@ -47,6 +45,7 @@ https://github.com/replicate/cog`, newRunCommand(), newServeCommand(), newTrainCommand(), + newMigrateCommand(), ) return &rootCmd, nil diff --git a/pkg/cli/run.go b/pkg/cli/run.go index 3636e27c53..9cce45f0b3 100644 --- a/pkg/cli/run.go +++ b/pkg/cli/run.go @@ -54,7 +54,7 @@ func newRunCommand() *cobra.Command { func run(cmd *cobra.Command, args []string) error { ctx := cmd.Context() - cfg, projectDir, err := config.GetConfig(projectDirFlag) + cfg, projectDir, err := config.GetConfig() if err != nil { return err } diff --git a/pkg/cli/serve.go b/pkg/cli/serve.go index 662d8e7e31..9857071de0 100644 --- a/pkg/cli/serve.go +++ b/pkg/cli/serve.go @@ -43,7 +43,7 @@ Generate and run an HTTP server based on the declared model inputs and outputs.` func cmdServe(cmd *cobra.Command, arg []string) error { ctx := cmd.Context() - cfg, projectDir, err := config.GetConfig(projectDirFlag) + cfg, projectDir, err := config.GetConfig() if err != nil { return err } diff --git a/pkg/cli/train.go b/pkg/cli/train.go index a0deba590c..7b239dd267 100644 --- a/pkg/cli/train.go +++ b/pkg/cli/train.go @@ -61,7 +61,7 @@ func cmdTrain(cmd *cobra.Command, args []string) error { volumes := []docker.Volume{} gpus := gpusFlag - cfg, projectDir, err := config.GetConfig(projectDirFlag) + cfg, projectDir, err := config.GetConfig() if err != nil { return err } diff --git a/pkg/config/config.go b/pkg/config/config.go index e2e1342a82..70c624a1fe 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -41,21 +41,21 @@ type RunItem struct { Type string `json:"type,omitempty" yaml:"type"` ID string `json:"id,omitempty" yaml:"id"` Target string `json:"target,omitempty" yaml:"target"` - } `json:"mounts,omitempty" yaml:"mounts"` + } `json:"mounts,omitempty" yaml:"mounts,omitempty"` } type Build struct { - GPU bool `json:"gpu,omitempty" yaml:"gpu"` + GPU bool `json:"gpu,omitempty" yaml:"gpu,omitempty"` PythonVersion string `json:"python_version,omitempty" yaml:"python_version"` PythonRequirements string `json:"python_requirements,omitempty" yaml:"python_requirements"` - PythonPackages []string `json:"python_packages,omitempty" yaml:"python_packages"` // Deprecated, but included for backwards compatibility - Run []RunItem `json:"run,omitempty" yaml:"run"` - SystemPackages []string `json:"system_packages,omitempty" yaml:"system_packages"` - PreInstall []string `json:"pre_install,omitempty" yaml:"pre_install"` // Deprecated, but included for backwards compatibility - CUDA string `json:"cuda,omitempty" yaml:"cuda"` - CuDNN string `json:"cudnn,omitempty" yaml:"cudnn"` + PythonPackages []string `json:"python_packages,omitempty" yaml:"python_packages,omitempty"` // Deprecated, but included for backwards compatibility + Run []RunItem `json:"run,omitempty" yaml:"run,omitempty"` + SystemPackages []string `json:"system_packages,omitempty" yaml:"system_packages,omitempty"` + PreInstall []string `json:"pre_install,omitempty" yaml:"pre_install,omitempty"` // Deprecated, but included for backwards compatibility + CUDA string `json:"cuda,omitempty" yaml:"cuda,omitempty"` + CuDNN string `json:"cudnn,omitempty" yaml:"cudnn,omitempty"` Fast bool `json:"fast,omitempty" yaml:"fast"` - PythonOverrides string `json:"python_overrides,omitempty" yaml:"python_overrides"` + PythonOverrides string `json:"python_overrides,omitempty" yaml:"python_overrides,omitempty"` pythonRequirementsContent []string } @@ -71,10 +71,10 @@ type Example struct { type Config struct { Build *Build `json:"build" yaml:"build"` - Image string `json:"image,omitempty" yaml:"image"` + Image string `json:"image,omitempty" yaml:"image,omitempty"` Predict string `json:"predict,omitempty" yaml:"predict"` - Train string `json:"train,omitempty" yaml:"train"` - Concurrency *Concurrency `json:"concurrency,omitempty" yaml:"concurrency"` + Train string `json:"train,omitempty" yaml:"train,omitempty"` + Concurrency *Concurrency `json:"concurrency,omitempty" yaml:"concurrency,omitempty"` } func DefaultConfig() *Config { diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 75f4c46a71..121a7a1e6f 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -719,3 +719,15 @@ build: _, err := FromYAML([]byte(yamlString)) require.NoError(t, err) } + +func TestConfigMarshal(t *testing.T) { + cfg := DefaultConfig() + data, err := yaml.Marshal(cfg) + require.NoError(t, err) + require.Equal(t, `build: + python_version: "3.12" + python_requirements: "" + fast: false +predict: "" +`, string(data)) +} diff --git a/pkg/config/load.go b/pkg/config/load.go index 9a56cd1321..c7ba9a5937 100644 --- a/pkg/config/load.go +++ b/pkg/config/load.go @@ -14,11 +14,7 @@ import ( const maxSearchDepth = 100 // Returns the project's root directory, or the directory specified by the --project-dir flag -func GetProjectDir(customDir string) (string, error) { - if customDir != "" { - return customDir, nil - } - +func GetProjectDir() (string, error) { cwd, err := os.Getwd() if err != nil { return "", err @@ -28,9 +24,9 @@ func GetProjectDir(customDir string) (string, error) { // Loads and instantiates a Config object // customDir can be specified to override the default - current working directory -func GetConfig(customDir string) (*Config, string, error) { +func GetConfig() (*Config, string, error) { // Find the root project directory - rootDir, err := GetProjectDir(customDir) + rootDir, err := GetProjectDir() if err != nil { return nil, "", err } diff --git a/pkg/config/load_test.go b/pkg/config/load_test.go index 2bc34760e0..78d15904b6 100644 --- a/pkg/config/load_test.go +++ b/pkg/config/load_test.go @@ -18,27 +18,6 @@ build: predict: "predict.py:SomePredictor" ` -func TestGetProjectDirWithFlagSet(t *testing.T) { - projectDirFlag := "foo" - - projectDir, err := GetProjectDir(projectDirFlag) - require.NoError(t, err) - require.Equal(t, projectDir, projectDirFlag) -} - -func TestGetConfigShouldLoadFromCustomDir(t *testing.T) { - dir := t.TempDir() - - err := os.WriteFile(path.Join(dir, "cog.yaml"), []byte(testConfig), 0o644) - require.NoError(t, err) - err = os.WriteFile(path.Join(dir, "requirements.txt"), []byte("torch==1.0.0"), 0o644) - require.NoError(t, err) - conf, _, err := GetConfig(dir) - require.NoError(t, err) - require.Equal(t, conf.Predict, "predict.py:SomePredictor") - require.Equal(t, conf.Build.PythonVersion, "3.8") -} - func TestFindProjectRootDirShouldFindParentDir(t *testing.T) { projectDir := t.TempDir() diff --git a/pkg/dockerfile/cog_embed.go b/pkg/dockerfile/cog_embed.go index 3711f460e3..2f795f1e55 100644 --- a/pkg/dockerfile/cog_embed.go +++ b/pkg/dockerfile/cog_embed.go @@ -1,6 +1,35 @@ package dockerfile -import "embed" +import ( + "embed" + "fmt" + "path/filepath" +) + +const EmbedDir = "embed" //go:embed embed/*.whl var CogEmbed embed.FS + +func WheelFilename() (string, error) { + files, err := CogEmbed.ReadDir(EmbedDir) + if err != nil { + return "", err + } + if len(files) != 1 { + return "", fmt.Errorf("should only have one cog wheel embedded") + } + return files[0].Name(), nil +} + +func ReadWheelFile() ([]byte, string, error) { + filename, err := WheelFilename() + if err != nil { + return nil, "", err + } + data, err := CogEmbed.ReadFile(filepath.Join(EmbedDir, filename)) + if err != nil { + return nil, "", err + } + return data, filename, err +} diff --git a/pkg/dockerfile/cog_embed_test.go b/pkg/dockerfile/cog_embed_test.go new file mode 100644 index 0000000000..9cd46274c9 --- /dev/null +++ b/pkg/dockerfile/cog_embed_test.go @@ -0,0 +1,14 @@ +package dockerfile + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestWheelFilename(t *testing.T) { + filename, err := WheelFilename() + require.NoError(t, err) + require.True(t, strings.HasPrefix(filename, "cog-")) +} diff --git a/pkg/dockerfile/standard_generator.go b/pkg/dockerfile/standard_generator.go index eb28a47c86..22b6b7c33a 100644 --- a/pkg/dockerfile/standard_generator.go +++ b/pkg/dockerfile/standard_generator.go @@ -415,15 +415,7 @@ RUN rm -rf /usr/bin/python3 && ln -s ` + "`realpath \\`pyenv which python\\`` /u } func (g *StandardGenerator) installCog() (string, error) { - files, err := CogEmbed.ReadDir("embed") - if err != nil { - return "", err - } - if len(files) != 1 { - return "", fmt.Errorf("should only have one cog wheel embedded") - } - filename := files[0].Name() - data, err := CogEmbed.ReadFile("embed/" + filename) + data, filename, err := ReadWheelFile() if err != nil { return "", err } diff --git a/pkg/migrate/factory.go b/pkg/migrate/factory.go new file mode 100644 index 0000000000..38db80199e --- /dev/null +++ b/pkg/migrate/factory.go @@ -0,0 +1,21 @@ +package migrate + +import ( + "errors" + "fmt" +) + +func NewMigrator(from Migration, to Migration, interactive bool) (Migrator, error) { + if from == MigrationV1 && to == MigrationV1Fast { + return NewMigratorV1ToV1Fast(interactive), nil + } + fromStr, err := MigrationToStr(from) + if err != nil { + return nil, err + } + toStr, err := MigrationToStr(to) + if err != nil { + return nil, err + } + return nil, errors.New(fmt.Sprintf("Unable to find a migrator from %s to %s.", fromStr, toStr)) +} diff --git a/pkg/migrate/factory_test.go b/pkg/migrate/factory_test.go new file mode 100644 index 0000000000..a3903a144d --- /dev/null +++ b/pkg/migrate/factory_test.go @@ -0,0 +1,13 @@ +package migrate + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewMigrator(t *testing.T) { + migrator, err := NewMigrator(MigrationV1, MigrationV1Fast, false) + require.NoError(t, err) + require.NotNil(t, migrator) +} diff --git a/pkg/migrate/migrations.go b/pkg/migrate/migrations.go new file mode 100644 index 0000000000..b38504c035 --- /dev/null +++ b/pkg/migrate/migrations.go @@ -0,0 +1,23 @@ +package migrate + +import ( + "errors" + "fmt" +) + +type Migration int + +const ( + MigrationV1 Migration = iota + MigrationV1Fast +) + +func MigrationToStr(migration Migration) (string, error) { + switch migration { + case MigrationV1: + return "v1", nil + case MigrationV1Fast: + return "v1fast", nil + } + return "", errors.New(fmt.Sprintf("Unrecognized Migration: %d", migration)) +} diff --git a/pkg/migrate/migrations_test.go b/pkg/migrate/migrations_test.go new file mode 100644 index 0000000000..ca4e513cdc --- /dev/null +++ b/pkg/migrate/migrations_test.go @@ -0,0 +1,13 @@ +package migrate + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestMigrationToStr(t *testing.T) { + migrationStr, err := MigrationToStr(MigrationV1) + require.NoError(t, err) + require.Equal(t, migrationStr, "v1") +} diff --git a/pkg/migrate/migrator.go b/pkg/migrate/migrator.go new file mode 100644 index 0000000000..c1faa96629 --- /dev/null +++ b/pkg/migrate/migrator.go @@ -0,0 +1,7 @@ +package migrate + +import "context" + +type Migrator interface { + Migrate(ctx context.Context) error +} diff --git a/pkg/migrate/migrator_v1_v1fast.go b/pkg/migrate/migrator_v1_v1fast.go new file mode 100644 index 0000000000..e348a7549a --- /dev/null +++ b/pkg/migrate/migrator_v1_v1fast.go @@ -0,0 +1,312 @@ +package migrate + +import ( + "bytes" + "context" + "errors" + "io" + "os" + "os/exec" + "path/filepath" + "strings" + + "archive/zip" + + "gopkg.in/yaml.v2" + + "github.com/replicate/cog/pkg/config" + "github.com/replicate/cog/pkg/dockerfile" + "github.com/replicate/cog/pkg/global" + "github.com/replicate/cog/pkg/requirements" + "github.com/replicate/cog/pkg/util" + "github.com/replicate/cog/pkg/util/console" + "github.com/replicate/cog/pkg/util/files" +) + +const CogRequirementsFile = "cog_requirements.txt" +const MigrateV1V1FastPythonFile = "migrate_v1_v1fast.py" + +var IgnoredRunCommands = map[string]bool{ + "curl -o /usr/local/bin/pget -L \\\"https://github.com/replicate/pget/releases/latest/download/pget_$(uname -s)_$(uname -m)\\\" && chmod +x /usr/local/bin/pget": true, +} + +type MigratorV1ToV1Fast struct { + Interactive bool +} + +func NewMigratorV1ToV1Fast(interactive bool) *MigratorV1ToV1Fast { + return &MigratorV1ToV1Fast{ + Interactive: interactive, + } +} + +func (g *MigratorV1ToV1Fast) Migrate(ctx context.Context) error { + cfg, projectDir, err := config.GetConfig() + if err != nil { + return err + } + err = g.checkPythonRequirements(cfg, projectDir) + if err != nil { + return err + } + err = g.checkRunCommands(cfg) + if err != nil { + return err + } + err = g.checkPythonCode(ctx, cfg, projectDir) + if err != nil { + return err + } + err = g.flushConfig(cfg, projectDir) + return err +} + +func (g *MigratorV1ToV1Fast) checkPythonRequirements(cfg *config.Config, dir string) error { + if cfg.Build == nil { + return nil + } + if len(cfg.Build.PythonPackages) == 0 { + return nil + } + console.Info("You have python_packages in your configuration, this is now deprecated and replaced with python_requirements.") + accept := true + if g.Interactive { + interactive := &console.InteractiveBool{ + Prompt: "Would you like to move your python_packages to a requirements.txt?", + Default: true, + NonDefaultFlag: "--y", + } + iAccept, err := interactive.Read() + if err != nil { + return err + } + accept = iAccept + } + if !accept { + console.Error("Skipping python_packages to python_requirements migration, this will cause issues on builds for fast boots.") + return nil + } + requirementsFile := filepath.Join(dir, requirements.RequirementsFile) + exists, err := files.Exists(requirementsFile) + if err != nil { + return err + } + if exists { + // If requirements.txt exists, we will write to an alternative requirements file to prevent overloading. + requirementsFile = filepath.Join(dir, CogRequirementsFile) + } + requirementsContent := strings.Join(cfg.Build.PythonPackages, "\n") + console.Infof("Writing python_packages to %s.", requirementsFile) + file, err := os.Create(requirementsFile) + if err != nil { + return err + } + defer file.Close() + _, err = file.WriteString(requirementsContent) + if err != nil { + return err + } + cfg.Build.PythonPackages = []string{} + cfg.Build.PythonRequirements = filepath.Base(requirementsFile) + return nil +} + +func (g *MigratorV1ToV1Fast) checkRunCommands(cfg *config.Config) error { + if cfg.Build == nil { + return nil + } + if len(cfg.Build.Run) == 0 { + return nil + } + // Filter run commands we can safely remove + safelyRemove := true + for _, runCommand := range cfg.Build.Run { + _, ok := IgnoredRunCommands[runCommand.Command] + if !ok { + console.Warnf("Failed to safely remove \"%s\"", runCommand.Command) + safelyRemove = false + break + } + } + if safelyRemove { + console.Info("Safely removing run commands.") + cfg.Build.Run = []config.RunItem{} + return nil + } + accept := true + if g.Interactive { + interactive := &console.InteractiveBool{ + Prompt: "You have run commands we do not recognize in your configuration, do you want us to remove them?", + Default: true, + NonDefaultFlag: "--y", + } + iAccept, err := interactive.Read() + if err != nil { + return err + } + accept = iAccept + } + if !accept { + console.Error("Skipping removing run commands, this will cause issues on builds for fast boots.") + } else { + console.Info("Removing run commands.") + cfg.Build.Run = []config.RunItem{} + } + return nil +} + +func (g *MigratorV1ToV1Fast) checkPythonCode(ctx context.Context, cfg *config.Config, dir string) error { + err := g.checkPredictor(ctx, cfg.Predict, dir) + if err != nil { + return err + } + err = g.checkPredictor(ctx, cfg.Train, dir) + if err != nil { + return err + } + return nil +} + +func (g *MigratorV1ToV1Fast) flushConfig(cfg *config.Config, dir string) error { + if cfg.Build == nil { + cfg.Build = config.DefaultConfig().Build + } + cfg.Build.Fast = true + err := cfg.ValidateAndComplete("") + if err != nil { + return err + } + data, err := yaml.Marshal(cfg) + if err != nil { + return err + } + configStr := string(data) + + configFilepath := filepath.Join(dir, global.ConfigFilename) + file, err := os.Open(configFilepath) + if err != nil { + return err + } + content, err := io.ReadAll(file) + file.Close() + if err != nil { + return err + } + if configStr == string(content) { + return nil + } + + console.Infof("New cog.yaml:\n%s\n", configStr) + + accept := true + if g.Interactive { + interactive := &console.InteractiveBool{ + Prompt: "Do you want to apply the above config changes?", + Default: true, + NonDefaultFlag: "--y", + } + iAccept, err := interactive.Read() + if err != nil { + return err + } + accept = iAccept + } + if !accept { + console.Error("Skipping config changes, this may cause issues on builds for fast boots.") + return nil + } + + file, err = os.Create(configFilepath) + if err != nil { + return err + } + defer file.Close() + console.Infof("Writing config changes to %s.", configFilepath) + _, err = file.WriteString(configStr) + if err != nil { + return util.WrapError(err, "Failed to write config changes") + } + + return nil +} + +func (g *MigratorV1ToV1Fast) checkPredictor(ctx context.Context, predictor string, dir string) error { + if predictor == "" { + return nil + } + zippedBytes, _, err := dockerfile.ReadWheelFile() + if err != nil { + return err + } + reader := bytes.NewReader(zippedBytes) + zipReader, err := zip.NewReader(reader, int64(len(zippedBytes))) + if err != nil { + return err + } + for _, file := range zipReader.File { + if filepath.Base(file.Name) != MigrateV1V1FastPythonFile { + continue + } + return g.runPythonScript(ctx, file, predictor, dir) + } + + return errors.New("Could not find " + MigrateV1V1FastPythonFile) +} + +func (g *MigratorV1ToV1Fast) runPythonScript(ctx context.Context, file *zip.File, predictor string, dir string) error { + splitPredictor := strings.Split(predictor, ":") + pythonFilename := splitPredictor[0] + pythonPredictor := splitPredictor[1] + + fileReader, err := file.Open() + if err != nil { + return err + } + defer fileReader.Close() + extractedData, err := io.ReadAll(fileReader) + if err != nil { + return err + } + pythonCode := string(extractedData) + cmd := exec.CommandContext(ctx, "python3", "-c", pythonCode, pythonFilename, pythonPredictor) + var out strings.Builder + cmd.Stdout = &out + cmd.Stderr = os.Stderr + err = cmd.Run() + if err != nil { + return err + } + newContent := out.String() + if newContent == "" { + return nil + } + accept := true + if g.Interactive { + interactive := &console.InteractiveBool{ + Prompt: "Do you want to apply the above code changes?", + Default: true, + NonDefaultFlag: "--y", + } + iAccept, err := interactive.Read() + if err != nil { + return err + } + accept = iAccept + } + if !accept { + console.Error("Skipping code changes, this will cause issues on builds for fast boots.") + return nil + } + pythonFilepath := filepath.Join(dir, pythonFilename) + pythonFile, err := os.Create(pythonFilepath) + if err != nil { + return util.WrapError(err, "Could not open python predictor file") + } + defer pythonFile.Close() + console.Infof("Writing code changes to %s.", pythonFilepath) + _, err = pythonFile.WriteString(newContent) + if err != nil { + return util.WrapError(err, "Failed to write to python predictor file") + } + return nil +} diff --git a/pkg/migrate/migrator_v1_v1fast_test.go b/pkg/migrate/migrator_v1_v1fast_test.go new file mode 100644 index 0000000000..a3f8bfb570 --- /dev/null +++ b/pkg/migrate/migrator_v1_v1fast_test.go @@ -0,0 +1,100 @@ +package migrate + +import ( + "io" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/replicate/cog/pkg/global" + "github.com/replicate/cog/pkg/requirements" +) + +func TestMigrate(t *testing.T) { + // Set our new working directory to a temp directory + originalDir, err := os.Getwd() + require.NoError(t, err) + defer func() { + err := os.Chdir(originalDir) + require.NoError(t, err) + }() + dir := t.TempDir() + err = os.Chdir(dir) + require.NoError(t, err) + + // Write our test configs/code + configFilepath := filepath.Join(dir, global.ConfigFilename) + file, err := os.Create(configFilepath) + require.NoError(t, err) + _, err = file.WriteString(`build: + python_version: "3.11" + fast: true + python_packages: + - "pillow" + run: + - command: curl -o /usr/local/bin/pget -L \"https://github.com/replicate/pget/releases/latest/download/pget_$(uname -s)_$(uname -m)\" && chmod +x /usr/local/bin/pget +predict: "predict.py:Predictor" +`) + require.NoError(t, err) + err = file.Close() + require.NoError(t, err) + pythonFilepath := filepath.Join(dir, "predict.py") + file, err = os.Create(pythonFilepath) + require.NoError(t, err) + _, err = file.WriteString(`from cog import BasePredictor, Input + + +class Predictor(BasePredictor): + def predict(self, s: str = Input(description="My Input Description", default=None)) -> str: + return "hello " + s +`) + require.NoError(t, err) + err = file.Close() + require.NoError(t, err) + + // Perform the migration + migrator := NewMigratorV1ToV1Fast(false) + err = migrator.Migrate(t.Context()) + require.NoError(t, err) + + // Check config output + file, err = os.Open(configFilepath) + require.NoError(t, err) + content, err := io.ReadAll(file) + require.NoError(t, err) + require.Equal(t, `build: + python_version: "3.11" + python_requirements: requirements.txt + fast: true +predict: predict.py:Predictor +`, string(content)) + err = file.Close() + require.NoError(t, err) + + // Check python code output + file, err = os.Open(pythonFilepath) + require.NoError(t, err) + content, err = io.ReadAll(file) + require.NoError(t, err) + require.Equal(t, `from typing import Optional +from cog import BasePredictor, Input + + +class Predictor(BasePredictor): + def predict(self, s: Optional[str] = Input(description="My Input Description", default=None)) -> str: + return "hello " + s +`, string(content)) + err = file.Close() + require.NoError(t, err) + + // Check requirements.txt + file, err = os.Open(filepath.Join(dir, requirements.RequirementsFile)) + require.NoError(t, err) + content, err = io.ReadAll(file) + require.NoError(t, err) + require.Equal(t, `pillow`, string(content)) + err = file.Close() + require.NoError(t, err) +} diff --git a/python/cog/command/migrate_v1_v1fast.py b/python/cog/command/migrate_v1_v1fast.py new file mode 100644 index 0000000000..a5e66f3970 --- /dev/null +++ b/python/cog/command/migrate_v1_v1fast.py @@ -0,0 +1,83 @@ +import ast +import sys +from typing import Any, List, Optional, Type, TypeVar + +T = TypeVar("T", covariant=True) + + +def find(nodes: List[Any], tpe: Type[T], attr: str, name: str) -> Optional[T]: + for n in nodes: + if type(n) is tpe and getattr(n, attr) == name: + return n + return None + + +def check(file: str, predictor: str) -> None: + with open(file, "r") as f: + content = f.read() + lines = content.splitlines() + root = ast.parse(content) + + p = find(root.body, ast.ClassDef, "name", predictor) + if p is None: + return + fn = find(p.body, ast.FunctionDef, "name", "predict") + if fn is None: + fn = find(p.body, ast.AsyncFunctionDef, "name", "predict") # type: ignore + args_and_defaults = zip(fn.args.args[-len(fn.args.defaults) :], fn.args.defaults) # type: ignore + none_defaults = [] + for a, d in args_and_defaults: + if type(a.annotation) is not ast.Name: + continue + if type(d) is not ast.Call or d.func.id != "Input": # type: ignore + continue + v = find(d.keywords, ast.keyword, "arg", "default") + if v is None or type(v.value) is not ast.Constant: + continue + if v.value.value is None: + pos = f"{file}:{a.lineno}:{a.col_offset}" + + # Add `Optional[]` to type annotation + # No need to remove `default=None` since `x: Optional[T] = Input(default=None)` is valid + ta = a.annotation + line = lines[ta.lineno - 1] + parts = ( + line[: ta.col_offset], + line[ta.col_offset : ta.end_col_offset], + line[ta.end_col_offset :], + ) + line = f"{parts[0]}Optional[{parts[1]}]{parts[2]}" + lines[ta.lineno - 1] = line + + none_defaults.append(f"{pos}: {a.arg}: {ta.id}={ast.unparse(d)}") + + if len(none_defaults) > 0: + print( + "Default value of None without explicit Optional[T] type hint is ambiguous and deprecated, for example:", + file=sys.stderr, + ) + print("- x: str=Input(default=None)", file=sys.stderr) + print("+ x: Optional[str]=Input(default=None)", file=sys.stderr) + print(file=sys.stderr) + for line in none_defaults: + print(line, file=sys.stderr) + + # Check for `from typing import Optional` + imports = find(root.body, ast.ImportFrom, "module", "typing") + if imports is None or "Optional" not in [n.name for n in imports.names]: + # Missing import, add it at beginning of file or before first import + # Skip `#!/usr/bin/env python3` or comments + lno = 1 + while lines[lno - 1].startswith("#"): + lno += 1 + for n in root.body: + if type(n) in {ast.Import, ast.ImportFrom}: + lno = n.lineno + break + lines = ( + lines[: lno - 1] + ["from typing import Optional"] + lines[lno - 1 :] + ) + print("\n".join(lines)) + + +check(sys.argv[1], sys.argv[2]) diff --git a/test-integration/test_integration/fixtures/migration-project/cog.yaml b/test-integration/test_integration/fixtures/migration-project/cog.yaml new file mode 100644 index 0000000000..abb72fe389 --- /dev/null +++ b/test-integration/test_integration/fixtures/migration-project/cog.yaml @@ -0,0 +1,8 @@ +build: + python_version: "3.11" + fast: true + python_packages: + - "pillow" + run: + - command: curl -o /usr/local/bin/pget -L \"https://github.com/replicate/pget/releases/latest/download/pget_$(uname -s)_$(uname -m)\" && chmod +x /usr/local/bin/pget +predict: "predict.py:Predictor" diff --git a/test-integration/test_integration/fixtures/migration-project/predict.py b/test-integration/test_integration/fixtures/migration-project/predict.py new file mode 100644 index 0000000000..b50e415c63 --- /dev/null +++ b/test-integration/test_integration/fixtures/migration-project/predict.py @@ -0,0 +1,6 @@ +from cog import BasePredictor, Input + + +class Predictor(BasePredictor): + def predict(self, s: str = Input(description="My Input Description", default=None)) -> str: + return "hello " + s diff --git a/test-integration/test_integration/test_migrate.py b/test-integration/test_integration/test_migrate.py new file mode 100644 index 0000000000..236cb69773 --- /dev/null +++ b/test-integration/test_integration/test_migrate.py @@ -0,0 +1,44 @@ +import os +import pathlib +import shutil +import subprocess +from pathlib import Path + +DEFAULT_TIMEOUT = 60 + + +def test_migrate(tmpdir_factory): + project_dir = Path(__file__).parent / "fixtures/migration-project" + out_dir = pathlib.Path(tmpdir_factory.mktemp("project")) + shutil.copytree(project_dir, out_dir, dirs_exist_ok=True) + result = subprocess.run( + [ + "cog", + "migrate", + "--y", + ], + cwd=out_dir, + check=True, + capture_output=True, + text=True, + timeout=DEFAULT_TIMEOUT, + ) + assert result.returncode == 0 + with open(os.path.join(out_dir, "cog.yaml"), encoding="utf8") as handle: + assert handle.read(), """build: + python_version: "3.11" + python_requirements: requirements.txt + fast: true +predict: predict.py:Predictor +""" + with open(os.path.join(out_dir, "predict.py"), encoding="utf8") as handle: + assert handle.read(), """from typing import Optional +from cog import BasePredictor, Input + + +class Predictor(BasePredictor): + def predict(self, s: Optional[str] = Input(description="My Input Description", default=None)) -> str: + return "hello " + s +""" + with open(os.path.join(out_dir, "requirements.txt"), encoding="utf8") as handle: + assert handle.read(), "pillow"