Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/cli/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func buildCommand(cmd *cobra.Command, args []string) error {
logClient := coglog.NewClient(client)
logCtx := logClient.StartBuild(buildFast, buildLocalImage)

cfg, projectDir, err := config.GetConfig(projectDirFlag)
cfg, projectDir, err := config.GetConfig()
if err != nil {
logClient.EndBuild(ctx, err, logCtx)
return err
Expand Down
2 changes: 1 addition & 1 deletion pkg/cli/debug.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func newDebugCommand() *cobra.Command {
func cmdDockerfile(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()

cfg, projectDir, err := config.GetConfig(projectDirFlag)
cfg, projectDir, err := config.GetConfig()
if err != nil {
return err
}
Expand Down
45 changes: 45 additions & 0 deletions pkg/cli/migrate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package cli

import (
"github.com/spf13/cobra"

"github.com/replicate/cog/pkg/migrate"
)

var migrateAccept bool

func newMigrateCommand() *cobra.Command {
cmd := &cobra.Command{
Use: "migrate",
Short: "Run a migration",
Long: `Run a migration.

This will attempt to migrate your cog project to be compatible with fast boots.`,
RunE: cmdMigrate,
Args: cobra.MaximumNArgs(0),
Hidden: true,
}

addYesFlag(cmd)

return cmd
}

func cmdMigrate(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()
migrator, err := migrate.NewMigrator(migrate.MigrationV1, migrate.MigrationV1Fast, !migrateAccept)
if err != nil {
return err
}
err = migrator.Migrate(ctx)
if err != nil {
return err
}

return nil
}

func addYesFlag(cmd *cobra.Command) {
const acceptFlag = "y"
cmd.Flags().BoolVar(&migrateAccept, acceptFlag, false, "Whether to disable interaction and automatically accept the changes.")
}
2 changes: 1 addition & 1 deletion pkg/cli/predict.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func cmdPredict(cmd *cobra.Command, args []string) error {
if len(args) == 0 {
// Build image

cfg, projectDir, err := config.GetConfig(projectDirFlag)
cfg, projectDir, err := config.GetConfig()
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/cli/push.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func push(cmd *cobra.Command, args []string) error {
logClient := coglog.NewClient(client)
logCtx := logClient.StartPush(buildFast, buildLocalImage)

cfg, projectDir, err := config.GetConfig(projectDirFlag)
cfg, projectDir, err := config.GetConfig()
if err != nil {
logClient.EndPush(ctx, err, logCtx)
return err
Expand Down
3 changes: 1 addition & 2 deletions pkg/cli/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ import (
"github.com/replicate/cog/pkg/util/console"
)

var projectDirFlag string

func NewRootCommand() (*cobra.Command, error) {
rootCmd := cobra.Command{
Use: "cog",
Expand Down Expand Up @@ -47,6 +45,7 @@ https://github.com/replicate/cog`,
newRunCommand(),
newServeCommand(),
newTrainCommand(),
newMigrateCommand(),
)

return &rootCmd, nil
Expand Down
2 changes: 1 addition & 1 deletion pkg/cli/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func newRunCommand() *cobra.Command {
func run(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()

cfg, projectDir, err := config.GetConfig(projectDirFlag)
cfg, projectDir, err := config.GetConfig()
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/cli/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ Generate and run an HTTP server based on the declared model inputs and outputs.`
func cmdServe(cmd *cobra.Command, arg []string) error {
ctx := cmd.Context()

cfg, projectDir, err := config.GetConfig(projectDirFlag)
cfg, projectDir, err := config.GetConfig()
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/cli/train.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func cmdTrain(cmd *cobra.Command, args []string) error {
volumes := []docker.Volume{}
gpus := gpusFlag

cfg, projectDir, err := config.GetConfig(projectDirFlag)
cfg, projectDir, err := config.GetConfig()
if err != nil {
return err
}
Expand Down
24 changes: 12 additions & 12 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,21 +41,21 @@ type RunItem struct {
Type string `json:"type,omitempty" yaml:"type"`
ID string `json:"id,omitempty" yaml:"id"`
Target string `json:"target,omitempty" yaml:"target"`
} `json:"mounts,omitempty" yaml:"mounts"`
} `json:"mounts,omitempty" yaml:"mounts,omitempty"`
}

type Build struct {
GPU bool `json:"gpu,omitempty" yaml:"gpu"`
GPU bool `json:"gpu,omitempty" yaml:"gpu,omitempty"`
PythonVersion string `json:"python_version,omitempty" yaml:"python_version"`
PythonRequirements string `json:"python_requirements,omitempty" yaml:"python_requirements"`
PythonPackages []string `json:"python_packages,omitempty" yaml:"python_packages"` // Deprecated, but included for backwards compatibility
Run []RunItem `json:"run,omitempty" yaml:"run"`
SystemPackages []string `json:"system_packages,omitempty" yaml:"system_packages"`
PreInstall []string `json:"pre_install,omitempty" yaml:"pre_install"` // Deprecated, but included for backwards compatibility
CUDA string `json:"cuda,omitempty" yaml:"cuda"`
CuDNN string `json:"cudnn,omitempty" yaml:"cudnn"`
PythonPackages []string `json:"python_packages,omitempty" yaml:"python_packages,omitempty"` // Deprecated, but included for backwards compatibility
Run []RunItem `json:"run,omitempty" yaml:"run,omitempty"`
SystemPackages []string `json:"system_packages,omitempty" yaml:"system_packages,omitempty"`
PreInstall []string `json:"pre_install,omitempty" yaml:"pre_install,omitempty"` // Deprecated, but included for backwards compatibility
CUDA string `json:"cuda,omitempty" yaml:"cuda,omitempty"`
CuDNN string `json:"cudnn,omitempty" yaml:"cudnn,omitempty"`
Fast bool `json:"fast,omitempty" yaml:"fast"`
PythonOverrides string `json:"python_overrides,omitempty" yaml:"python_overrides"`
PythonOverrides string `json:"python_overrides,omitempty" yaml:"python_overrides,omitempty"`

pythonRequirementsContent []string
}
Expand All @@ -71,10 +71,10 @@ type Example struct {

type Config struct {
Build *Build `json:"build" yaml:"build"`
Image string `json:"image,omitempty" yaml:"image"`
Image string `json:"image,omitempty" yaml:"image,omitempty"`
Predict string `json:"predict,omitempty" yaml:"predict"`
Train string `json:"train,omitempty" yaml:"train"`
Concurrency *Concurrency `json:"concurrency,omitempty" yaml:"concurrency"`
Train string `json:"train,omitempty" yaml:"train,omitempty"`
Concurrency *Concurrency `json:"concurrency,omitempty" yaml:"concurrency,omitempty"`
}

func DefaultConfig() *Config {
Expand Down
12 changes: 12 additions & 0 deletions pkg/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -719,3 +719,15 @@ build:
_, err := FromYAML([]byte(yamlString))
require.NoError(t, err)
}

func TestConfigMarshal(t *testing.T) {
cfg := DefaultConfig()
data, err := yaml.Marshal(cfg)
require.NoError(t, err)
require.Equal(t, `build:
python_version: "3.12"
python_requirements: ""
fast: false
predict: ""
`, string(data))
}
10 changes: 3 additions & 7 deletions pkg/config/load.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,7 @@ import (
const maxSearchDepth = 100

// Returns the project's root directory, or the directory specified by the --project-dir flag
func GetProjectDir(customDir string) (string, error) {
if customDir != "" {
return customDir, nil
}

func GetProjectDir() (string, error) {
cwd, err := os.Getwd()
if err != nil {
return "", err
Expand All @@ -28,9 +24,9 @@ func GetProjectDir(customDir string) (string, error) {

// Loads and instantiates a Config object
// customDir can be specified to override the default - current working directory
func GetConfig(customDir string) (*Config, string, error) {
func GetConfig() (*Config, string, error) {
// Find the root project directory
rootDir, err := GetProjectDir(customDir)
rootDir, err := GetProjectDir()
if err != nil {
return nil, "", err
}
Expand Down
21 changes: 0 additions & 21 deletions pkg/config/load_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,6 @@ build:
predict: "predict.py:SomePredictor"
`

func TestGetProjectDirWithFlagSet(t *testing.T) {
projectDirFlag := "foo"

projectDir, err := GetProjectDir(projectDirFlag)
require.NoError(t, err)
require.Equal(t, projectDir, projectDirFlag)
}

func TestGetConfigShouldLoadFromCustomDir(t *testing.T) {
dir := t.TempDir()

err := os.WriteFile(path.Join(dir, "cog.yaml"), []byte(testConfig), 0o644)
require.NoError(t, err)
err = os.WriteFile(path.Join(dir, "requirements.txt"), []byte("torch==1.0.0"), 0o644)
require.NoError(t, err)
conf, _, err := GetConfig(dir)
require.NoError(t, err)
require.Equal(t, conf.Predict, "predict.py:SomePredictor")
require.Equal(t, conf.Build.PythonVersion, "3.8")
}

func TestFindProjectRootDirShouldFindParentDir(t *testing.T) {
projectDir := t.TempDir()

Expand Down
31 changes: 30 additions & 1 deletion pkg/dockerfile/cog_embed.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,35 @@
package dockerfile

import "embed"
import (
"embed"
"fmt"
"path/filepath"
)

const EmbedDir = "embed"

//go:embed embed/*.whl
var CogEmbed embed.FS

func WheelFilename() (string, error) {
files, err := CogEmbed.ReadDir(EmbedDir)
if err != nil {
return "", err
}
if len(files) != 1 {
return "", fmt.Errorf("should only have one cog wheel embedded")
}
return files[0].Name(), nil
}

func ReadWheelFile() ([]byte, string, error) {
filename, err := WheelFilename()
if err != nil {
return nil, "", err
}
data, err := CogEmbed.ReadFile(filepath.Join(EmbedDir, filename))
if err != nil {
return nil, "", err
}
return data, filename, err
}
14 changes: 14 additions & 0 deletions pkg/dockerfile/cog_embed_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package dockerfile

import (
"strings"
"testing"

"github.com/stretchr/testify/require"
)

func TestWheelFilename(t *testing.T) {
filename, err := WheelFilename()
require.NoError(t, err)
require.True(t, strings.HasPrefix(filename, "cog-"))
}
10 changes: 1 addition & 9 deletions pkg/dockerfile/standard_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -415,15 +415,7 @@ RUN rm -rf /usr/bin/python3 && ln -s ` + "`realpath \\`pyenv which python\\`` /u
}

func (g *StandardGenerator) installCog() (string, error) {
files, err := CogEmbed.ReadDir("embed")
if err != nil {
return "", err
}
if len(files) != 1 {
return "", fmt.Errorf("should only have one cog wheel embedded")
}
filename := files[0].Name()
data, err := CogEmbed.ReadFile("embed/" + filename)
data, filename, err := ReadWheelFile()
if err != nil {
return "", err
}
Expand Down
21 changes: 21 additions & 0 deletions pkg/migrate/factory.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package migrate

import (
"errors"
"fmt"
)

func NewMigrator(from Migration, to Migration, interactive bool) (Migrator, error) {
if from == MigrationV1 && to == MigrationV1Fast {
return NewMigratorV1ToV1Fast(interactive), nil
}
fromStr, err := MigrationToStr(from)
if err != nil {
return nil, err
}
toStr, err := MigrationToStr(to)
if err != nil {
return nil, err
}
return nil, errors.New(fmt.Sprintf("Unable to find a migrator from %s to %s.", fromStr, toStr))
}
13 changes: 13 additions & 0 deletions pkg/migrate/factory_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package migrate

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestNewMigrator(t *testing.T) {
migrator, err := NewMigrator(MigrationV1, MigrationV1Fast, false)
require.NoError(t, err)
require.NotNil(t, migrator)
}
23 changes: 23 additions & 0 deletions pkg/migrate/migrations.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package migrate

import (
"errors"
"fmt"
)

type Migration int

const (
MigrationV1 Migration = iota
MigrationV1Fast
)

func MigrationToStr(migration Migration) (string, error) {
switch migration {
case MigrationV1:
return "v1", nil
case MigrationV1Fast:
return "v1fast", nil
}
return "", errors.New(fmt.Sprintf("Unrecognized Migration: %d", migration))
}
13 changes: 13 additions & 0 deletions pkg/migrate/migrations_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package migrate

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestMigrationToStr(t *testing.T) {
migrationStr, err := MigrationToStr(MigrationV1)
require.NoError(t, err)
require.Equal(t, migrationStr, "v1")
}
7 changes: 7 additions & 0 deletions pkg/migrate/migrator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package migrate

import "context"

type Migrator interface {
Migrate(ctx context.Context) error
}
Loading