Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion pkg/cli/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ import (
"github.com/spf13/cobra"
"github.com/spf13/pflag"

"github.com/replicate/cog/pkg/coglog"
"github.com/replicate/cog/pkg/config"
"github.com/replicate/cog/pkg/docker"
"github.com/replicate/cog/pkg/http"
"github.com/replicate/cog/pkg/image"
"github.com/replicate/cog/pkg/util/console"
)
Expand Down Expand Up @@ -57,8 +60,17 @@ func newBuildCommand() *cobra.Command {
func buildCommand(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()

command := docker.NewDockerCommand()
client, err := http.ProvideHTTPClient(ctx, command)
if err != nil {
return err
}
logClient := coglog.NewClient(client)
logCtx := logClient.StartBuild(buildFast, buildLocalImage)

cfg, projectDir, err := config.GetConfig(projectDirFlag)
if err != nil {
logClient.EndBuild(ctx, err, logCtx)
return err
}
if cfg.Build.Fast {
Expand All @@ -75,14 +87,17 @@ func buildCommand(cmd *cobra.Command, args []string) error {

err = config.ValidateModelPythonVersion(cfg)
if err != nil {
logClient.EndBuild(ctx, err, logCtx)
return err
}

if err := image.Build(ctx, cfg, projectDir, imageName, buildSecrets, buildNoCache, buildSeparateWeights, buildUseCudaBaseImage, buildProgressOutput, buildSchemaFile, buildDockerfileFile, DetermineUseCogBaseImage(cmd), buildStrip, buildPrecompile, buildFast, nil, buildLocalImage); err != nil {
if err := image.Build(ctx, cfg, projectDir, imageName, buildSecrets, buildNoCache, buildSeparateWeights, buildUseCudaBaseImage, buildProgressOutput, buildSchemaFile, buildDockerfileFile, DetermineUseCogBaseImage(cmd), buildStrip, buildPrecompile, buildFast, nil, buildLocalImage, command); err != nil {
logClient.EndBuild(ctx, err, logCtx)
return err
}

console.Infof("\nImage built as %s", imageName)
logClient.EndBuild(ctx, nil, logCtx)

return nil
}
Expand Down
37 changes: 29 additions & 8 deletions pkg/cli/push.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@ import (

"github.com/replicate/go/uuid"

"github.com/replicate/cog/pkg/coglog"
"github.com/replicate/cog/pkg/config"
"github.com/replicate/cog/pkg/docker"
"github.com/replicate/cog/pkg/global"
"github.com/replicate/cog/pkg/http"
"github.com/replicate/cog/pkg/image"
"github.com/replicate/cog/pkg/util/console"
)
Expand Down Expand Up @@ -44,8 +46,17 @@ func newPushCommand() *cobra.Command {
func push(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()

command := docker.NewDockerCommand()
client, err := http.ProvideHTTPClient(ctx, command)
if err != nil {
return err
}
logClient := coglog.NewClient(client)
logCtx := logClient.StartPush(buildFast, buildLocalImage)

cfg, projectDir, err := config.GetConfig(projectDirFlag)
if err != nil {
logClient.EndPush(ctx, err, logCtx)
return err
}
if cfg.Build.Fast {
Expand All @@ -58,17 +69,23 @@ func push(cmd *cobra.Command, args []string) error {
}

if imageName == "" {
return fmt.Errorf("To push images, you must either set the 'image' option in cog.yaml or pass an image name as an argument. For example, 'cog push r8.im/your-username/hotdog-detector'")
err = fmt.Errorf("To push images, you must either set the 'image' option in cog.yaml or pass an image name as an argument. For example, 'cog push r8.im/your-username/hotdog-detector'")
logClient.EndPush(ctx, err, logCtx)
return err
}

replicatePrefix := fmt.Sprintf("%s/", global.ReplicateRegistryHost)
if strings.HasPrefix(imageName, replicatePrefix) {
if err := docker.ManifestInspect(ctx, imageName); err != nil && strings.Contains(err.Error(), `"code":"NAME_UNKNOWN"`) {
return fmt.Errorf("Unable to find Replicate existing model for %s. Go to replicate.com and create a new model before pushing.", imageName)
err = fmt.Errorf("Unable to find Replicate existing model for %s. Go to replicate.com and create a new model before pushing.", imageName)
logClient.EndPush(ctx, err, logCtx)
return err
}
} else {
if buildLocalImage {
return fmt.Errorf("Unable to push a local image model to a non replicate host, please disable the local image flag before pushing to this host.")
err = fmt.Errorf("Unable to push a local image model to a non replicate host, please disable the local image flag before pushing to this host.")
logClient.EndPush(ctx, err, logCtx)
return err
}
}

Expand All @@ -83,7 +100,7 @@ func push(cmd *cobra.Command, args []string) error {

startBuildTime := time.Now()

if err := image.Build(ctx, cfg, projectDir, imageName, buildSecrets, buildNoCache, buildSeparateWeights, buildUseCudaBaseImage, buildProgressOutput, buildSchemaFile, buildDockerfileFile, DetermineUseCogBaseImage(cmd), buildStrip, buildPrecompile, buildFast, annotations, buildLocalImage); err != nil {
if err := image.Build(ctx, cfg, projectDir, imageName, buildSecrets, buildNoCache, buildSeparateWeights, buildUseCudaBaseImage, buildProgressOutput, buildSchemaFile, buildDockerfileFile, DetermineUseCogBaseImage(cmd), buildStrip, buildPrecompile, buildFast, annotations, buildLocalImage, command); err != nil {
return err
}

Expand All @@ -94,14 +111,13 @@ func push(cmd *cobra.Command, args []string) error {
console.Info("Fast push enabled.")
}

command := docker.NewDockerCommand()
err = docker.Push(ctx, imageName, buildFast, projectDir, command, docker.BuildInfo{
BuildTime: buildDuration,
BuildID: buildID.String(),
})
}, client)
if err != nil {
if strings.Contains(err.Error(), "404") {
return fmt.Errorf("Unable to find existing Replicate model for %s. "+
err = fmt.Errorf("Unable to find existing Replicate model for %s. "+
"Go to replicate.com and create a new model before pushing."+
"\n\n"+
"If the model already exists, you may be getting this error "+
Expand All @@ -110,15 +126,20 @@ func push(cmd *cobra.Command, args []string) error {
"or `sudo cog push` instead of `cog push`, "+
"which causes Docker to use the wrong Docker credentials.",
imageName)
logClient.EndPush(ctx, err, logCtx)
return err
}
return fmt.Errorf("Failed to push image: %w", err)
err = fmt.Errorf("Failed to push image: %w", err)
logClient.EndPush(ctx, err, logCtx)
return err
}

console.Infof("Image '%s' pushed", imageName)
if strings.HasPrefix(imageName, replicatePrefix) {
replicatePage := fmt.Sprintf("https://%s", strings.Replace(imageName, global.ReplicateRegistryHost, global.ReplicateWebsiteHost, 1))
console.Infof("\nRun your model on Replicate:\n %s", replicatePage)
}
logClient.EndPush(ctx, nil, logCtx)

return nil
}
163 changes: 163 additions & 0 deletions pkg/coglog/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
package coglog

import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/url"
"strconv"
"strings"
"time"

"github.com/replicate/cog/pkg/env"
"github.com/replicate/cog/pkg/util/console"
)

type Client struct {
client *http.Client
}

type BuildLogContext struct {
started time.Time
fast bool
localImage bool
}

type PushLogContext struct {
started time.Time
fast bool
localImage bool
}

type buildLog struct {
DurationMs float32 `json:"length_ms"`
BuildError *string `json:"error"`
Fast bool `json:"fast"`
LocalImage bool `json:"local_image"`
}

type pushLog struct {
DurationMs float32 `json:"length_ms"`
BuildError *string `json:"error"`
Fast bool `json:"fast"`
LocalImage bool `json:"local_image"`
}

func NewClient(client *http.Client) *Client {
return &Client{
client: client,
}
}

func (c *Client) StartBuild(fast bool, localImage bool) BuildLogContext {
logContext := BuildLogContext{
started: time.Now(),
fast: fast,
localImage: localImage,
}
return logContext
}

func (c *Client) EndBuild(ctx context.Context, err error, logContext BuildLogContext) bool {
var errorStr *string = nil
if err != nil {
errStr := err.Error()
errorStr = &errStr
}
buildLog := buildLog{
DurationMs: float32(time.Now().Sub(logContext.started).Milliseconds()),
BuildError: errorStr,
Fast: logContext.fast,
LocalImage: logContext.localImage,
}

jsonData, err := json.Marshal(buildLog)
if err != nil {
console.Warn("Failed to marshal JSON for build log: " + err.Error())
return false
}

err = c.postLog(ctx, jsonData)
if err != nil {
console.Warn(err.Error())
return false
}

return true
}

func (c *Client) StartPush(fast bool, localImage bool) PushLogContext {
logContext := PushLogContext{
started: time.Now(),
fast: fast,
localImage: localImage,
}
return logContext
}

func (c *Client) EndPush(ctx context.Context, err error, logContext PushLogContext) bool {
var errorStr *string = nil
if err != nil {
errStr := err.Error()
errorStr = &errStr
}
pushLog := pushLog{
DurationMs: float32(time.Now().Sub(logContext.started).Milliseconds()),
BuildError: errorStr,
Fast: logContext.fast,
LocalImage: logContext.localImage,
}

jsonData, err := json.Marshal(pushLog)
if err != nil {
console.Warn("Failed to marshal JSON for build log: " + err.Error())
return false
}

err = c.postLog(ctx, jsonData)
if err != nil {
console.Warn(err.Error())
return false
}

return true
}

func (c *Client) postLog(ctx context.Context, jsonData []byte) error {
disabled, err := DisableFromEnvironment()
if err != nil {
return err
}
if disabled {
return errors.New("Cog logging disabled")
}

url := buildURL()
req, err := http.NewRequestWithContext(ctx, http.MethodPut, url.String(), bytes.NewReader(jsonData))
if err != nil {
return err
}
resp, err := c.client.Do(req)
if err != nil {
return err
}
if resp.StatusCode != http.StatusOK {
return errors.New("Bad response from build log: " + strconv.Itoa(resp.StatusCode))
}
return nil
}

func baseURL() url.URL {
return url.URL{
Scheme: env.SchemeFromEnvironment(),
Host: HostFromEnvironment(),
}
}

func buildURL() url.URL {
url := baseURL()
url.Path = strings.Join([]string{"", "v1", "build"}, "/")
return url
}
54 changes: 54 additions & 0 deletions pkg/coglog/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package coglog

import (
"net/http"
"net/http/httptest"
"net/url"
"testing"

"github.com/stretchr/testify/require"

"github.com/replicate/cog/pkg/env"
)

func TestLogBuild(t *testing.T) {
// Setup mock http server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
url, err := url.Parse(server.URL)
require.NoError(t, err)
t.Setenv(env.SchemeEnvVarName, url.Scheme)
t.Setenv(CoglogHostEnvVarName, url.Host)

client := NewClient(http.DefaultClient)
logContext := client.StartBuild(false, false)
success := client.EndBuild(t.Context(), nil, logContext)
require.True(t, success)
}

func TestLogBuildDisabled(t *testing.T) {
t.Setenv(CoglogDisableEnvVarName, "true")
client := NewClient(http.DefaultClient)
logContext := client.StartBuild(false, false)
success := client.EndBuild(t.Context(), nil, logContext)
require.False(t, success)
}

func TestLogPush(t *testing.T) {
// Setup mock http server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
url, err := url.Parse(server.URL)
require.NoError(t, err)
t.Setenv(env.SchemeEnvVarName, url.Scheme)
t.Setenv(CoglogHostEnvVarName, url.Host)

client := NewClient(http.DefaultClient)
logContext := client.StartPush(false, false)
success := client.EndPush(t.Context(), nil, logContext)
require.True(t, success)
}
25 changes: 25 additions & 0 deletions pkg/coglog/env.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package coglog

import (
"os"
"strconv"
)

const CoglogHostEnvVarName = "R8_COGLOG_HOST"
const CoglogDisableEnvVarName = "R8_COGLOG_DISABLE"

func HostFromEnvironment() string {
host := os.Getenv(CoglogHostEnvVarName)
if host == "" {
host = "coglog.replicate.delivery"
}
return host
}

func DisableFromEnvironment() (bool, error) {
disable := os.Getenv(CoglogDisableEnvVarName)
if disable == "" {
disable = "false"
}
return strconv.ParseBool(disable)
}
Loading