Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions pkg/cli/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type VerifyResponse struct {
}

func newLoginCommand() *cobra.Command {
var cmd = &cobra.Command{
cmd := &cobra.Command{
Use: "login",
SuggestFor: []string{"auth", "authenticate", "authorize"},
Short: "Log in to Replicate Docker registry",
Expand All @@ -34,19 +34,15 @@ func newLoginCommand() *cobra.Command {
}

cmd.Flags().Bool("token-stdin", false, "Pass login token on stdin instead of opening a browser. You can find your Replicate login token at https://replicate.com/auth/token")
cmd.Flags().String("registry", global.ReplicateRegistryHost, "Registry host")
_ = cmd.Flags().MarkHidden("registry")

return cmd
}

func login(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()

registryHost, err := cmd.Flags().GetString("registry")
if err != nil {
return err
}
// Use global registry host (can be set via --registry flag or COG_REGISTRY_HOST env var)
registryHost := global.ReplicateRegistryHost
tokenStdin, err := cmd.Flags().GetBool("token-stdin")
if err != nil {
return err
Expand Down
2 changes: 2 additions & 0 deletions pkg/cli/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,7 @@ func setPersistentFlags(cmd *cobra.Command) {
cmd.PersistentFlags().BoolVar(&global.Debug, "debug", false, "Show debugging output")
cmd.PersistentFlags().BoolVar(&global.ProfilingEnabled, "profile", false, "Enable profiling")
cmd.PersistentFlags().Bool("version", false, "Show version of Cog")
cmd.PersistentFlags().StringVar(&global.ReplicateRegistryHost, "registry", global.ReplicateRegistryHost, "Registry host")
_ = cmd.PersistentFlags().MarkHidden("profile")
_ = cmd.PersistentFlags().MarkHidden("registry")
}
24 changes: 16 additions & 8 deletions pkg/docker/api_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/replicate/go/types/ptr"

"github.com/replicate/cog/pkg/docker/command"
"github.com/replicate/cog/pkg/global"
"github.com/replicate/cog/pkg/util/console"
)

Expand Down Expand Up @@ -70,17 +71,13 @@ func NewAPIClient(ctx context.Context, opts ...Option) (*apiClient, error) {
return nil, fmt.Errorf("error pinging docker daemon: %w", err)
}

authConfig := make(map[string]registry.AuthConfig)
userInfo, err := loadUserInformation(ctx, "r8.im")
// Load authentication for configured registry and any other registries that might be needed
authConfig, err := loadRegistryAuths(ctx, global.ReplicateRegistryHost)
if err != nil {
return nil, fmt.Errorf("error loading user information: %w, you may need to authenticate using cog login", err)
}
authConfig["r8.im"] = registry.AuthConfig{
Username: userInfo.Username,
Password: userInfo.Token,
ServerAddress: "r8.im",
}

// Add any additional auth configs passed via options
for _, opt := range clientOptions.authConfigs {
authConfig[opt.ServerAddress] = opt
}
Expand Down Expand Up @@ -209,8 +206,19 @@ func (c *apiClient) Push(ctx context.Context, imageRef string) error {

// eagerly set auth config, or do it async
var authConfig registry.AuthConfig
if auth, ok := c.authConfig[parsedName.Context().RegistryStr()]; ok {
registryHost := parsedName.Context().RegistryStr()
if auth, ok := c.authConfig[registryHost]; ok {
authConfig = auth
} else {
// Dynamically load authentication for this registry if not already loaded
authConfigs, err := loadRegistryAuths(ctx, registryHost)
if err == nil {
if auth, ok := authConfigs[registryHost]; ok {
authConfig = auth
// Cache the auth config for future use
c.authConfig[registryHost] = auth
}
}
}

var opts image.PushOptions
Expand Down
74 changes: 44 additions & 30 deletions pkg/docker/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/docker/docker/api/types/registry"

"github.com/replicate/cog/pkg/docker/command"
"github.com/replicate/cog/pkg/global"
"github.com/replicate/cog/pkg/util/console"
)

Expand Down Expand Up @@ -47,49 +48,62 @@ func loadAuthFromConfig(conf *configfile.ConfigFile, registryHost string) (types

func loadRegistryAuths(ctx context.Context, registryHosts ...string) (map[string]registry.AuthConfig, error) {
conf := config.LoadDefaultConfigFile(os.Stderr)

out := make(map[string]registry.AuthConfig)

for _, host := range registryHosts {
console.Debugf("=== loadRegistryAuths %s", host)
// check the credentials store first if set
if conf.CredentialsStore != "" {
console.Debugf("=== loadRegistryAuths %s: credentials store set", host)
credsHelper, err := loadAuthFromCredentialsStore(ctx, conf.CredentialsStore, host)
if err != nil {
console.Debugf("=== loadRegistryAuths %s: error loading credentials store: %s", host, err)
return nil, err
}
console.Debugf("=== loadRegistryAuths %s: credentials store loaded", host)
out[host] = registry.AuthConfig{
Username: credsHelper.Username,
Password: credsHelper.Secret,
ServerAddress: host,
}
// Try loading auth for the requested host
auth, err := tryLoadAuthForHost(ctx, conf, host)
if err == nil && auth != nil {
out[host] = *auth
continue
}

// next, check if the auth config exists in the config file
if auth, ok := conf.AuthConfigs[host]; ok {
console.Debugf("=== loadRegistryAuths %s: auth config found in config file", host)
out[host] = registry.AuthConfig{
Username: auth.Username,
Password: auth.Password,
Auth: auth.Auth,
Email: auth.Email,
ServerAddress: host,
IdentityToken: auth.IdentityToken,
RegistryToken: auth.RegistryToken,
// FALLBACK: If requesting alternate registry and no auth found,
// try reusing r8.im credentials
if host != global.DefaultReplicateRegistryHost {
auth, err := tryLoadAuthForHost(ctx, conf, global.DefaultReplicateRegistryHost)
if err == nil && auth != nil {
// Reuse credentials for the alternate registry
auth.ServerAddress = host // Update to new host
out[host] = *auth
console.Infof("Using existing %s credentials for %s", global.DefaultReplicateRegistryHost, host)
continue
}
continue
}

console.Debugf("=== loadRegistryAuths %s: no auth config found", host)
}

return out, nil
}

func tryLoadAuthForHost(ctx context.Context, conf *configfile.ConfigFile, host string) (*registry.AuthConfig, error) {
// Try credentials store first (e.g., osxkeychain, pass)
if conf.CredentialsStore != "" {
credsHelper, err := loadAuthFromCredentialsStore(ctx, conf.CredentialsStore, host)
if err == nil {
return &registry.AuthConfig{
Username: credsHelper.Username,
Password: credsHelper.Secret,
ServerAddress: host,
}, nil
}
}

// Fallback to config file
if auth, ok := conf.AuthConfigs[host]; ok {
return &registry.AuthConfig{
Username: auth.Username,
Password: auth.Password,
Auth: auth.Auth,
Email: auth.Email,
ServerAddress: host,
IdentityToken: auth.IdentityToken,
RegistryToken: auth.RegistryToken,
}, nil
}

return nil, fmt.Errorf("no credentials found for %s", host)
}

func loadAuthFromCredentialsStore(ctx context.Context, credsStore string, registryHost string) (*CredentialHelperInput, error) {
var out strings.Builder
binary := dockerCredentialBinary(credsStore)
Expand Down
168 changes: 168 additions & 0 deletions pkg/docker/credentials_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
package docker

import (
"context"
"path/filepath"
"testing"

"github.com/docker/cli/cli/config/configfile"
"github.com/docker/cli/cli/config/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/replicate/cog/pkg/global"
)

func TestLoadRegistryAuths_Fallback(t *testing.T) {
ctx := context.Background()

t.Run("uses credentials for requested host when available", func(t *testing.T) {
// Create a mock config with credentials for the requested host
conf := &configfile.ConfigFile{
AuthConfigs: map[string]types.AuthConfig{
"registry.example.com": {
Username: "user1",
Password: "pass1",
},
},
}

auth, err := tryLoadAuthForHost(ctx, conf, "registry.example.com")
require.NoError(t, err)
require.NotNil(t, auth)
assert.Equal(t, "user1", auth.Username)
assert.Equal(t, "pass1", auth.Password)
assert.Equal(t, "registry.example.com", auth.ServerAddress)
})

t.Run("falls back to default registry credentials when alternate registry has no credentials", func(t *testing.T) {
// Set up a temporary docker config file
tmpDir := t.TempDir()
dockerConfigPath := filepath.Join(tmpDir, "config.json")

// Create a config file with credentials only for the default registry
conf := &configfile.ConfigFile{
Filename: dockerConfigPath,
AuthConfigs: map[string]types.AuthConfig{
global.DefaultReplicateRegistryHost: {
Username: "defaultuser",
Password: "defaultpass",
},
},
}
require.NoError(t, conf.Save())

// Point Docker to our test config
t.Setenv("DOCKER_CONFIG", tmpDir)

// Try loading auth for an alternate registry that doesn't have credentials
auths, err := loadRegistryAuths(ctx, "registry.example.com")
require.NoError(t, err)
require.NotNil(t, auths)

// Should have fallen back to default registry credentials
auth, ok := auths["registry.example.com"]
require.True(t, ok, "should have auth for registry.example.com")
assert.Equal(t, "defaultuser", auth.Username)
assert.Equal(t, "defaultpass", auth.Password)
assert.Equal(t, "registry.example.com", auth.ServerAddress, "server address should be updated to the requested host")
})

t.Run("does not fallback when requesting default registry", func(t *testing.T) {
// This test uses tryLoadAuthForHost directly to avoid credential store issues
conf := &configfile.ConfigFile{
AuthConfigs: map[string]types.AuthConfig{},
}

// Try loading auth for the default registry
auth, err := tryLoadAuthForHost(ctx, conf, global.DefaultReplicateRegistryHost)
require.Error(t, err, "should error when no credentials found")
assert.Nil(t, auth)
assert.Contains(t, err.Error(), "no credentials found")
})

t.Run("prefers direct credentials over fallback", func(t *testing.T) {
// Create a mock config with credentials for both registries
conf := &configfile.ConfigFile{
AuthConfigs: map[string]types.AuthConfig{
global.DefaultReplicateRegistryHost: {
Username: "defaultuser",
Password: "defaultpass",
},
"registry.example.com": {
Username: "directuser",
Password: "directpass",
},
},
}

// Try loading auth for the alternate registry
auth, err := tryLoadAuthForHost(ctx, conf, "registry.example.com")
require.NoError(t, err)
require.NotNil(t, auth)

// Should use direct credentials, not fallback
assert.Equal(t, "directuser", auth.Username)
assert.Equal(t, "directpass", auth.Password)
assert.Equal(t, "registry.example.com", auth.ServerAddress)
})

t.Run("returns empty map when no credentials available", func(t *testing.T) {
// This test uses tryLoadAuthForHost to avoid credential store issues
// The loadRegistryAuths function doesn't error when no credentials are found,
// it just returns an empty map
conf := &configfile.ConfigFile{
AuthConfigs: map[string]types.AuthConfig{},
}

// Try loading auth for an alternate registry (will fail)
auth1, err := tryLoadAuthForHost(ctx, conf, "registry.example.com")
require.Error(t, err)
assert.Nil(t, auth1)

// Try loading auth for default registry (will also fail)
auth2, err := tryLoadAuthForHost(ctx, conf, global.DefaultReplicateRegistryHost)
require.Error(t, err)
assert.Nil(t, auth2)

// Since both fail, loadRegistryAuths would return an empty map
// (it doesn't error, just silently skips hosts without credentials)
})
}

func TestTryLoadAuthForHost(t *testing.T) {
ctx := context.Background()

t.Run("loads auth from config file", func(t *testing.T) {
conf := &configfile.ConfigFile{
AuthConfigs: map[string]types.AuthConfig{
"registry.example.com": {
Username: "testuser",
Password: "testpass",
Auth: "dGVzdHVzZXI6dGVzdHBhc3M=",
Email: "[email protected]",
},
},
}

auth, err := tryLoadAuthForHost(ctx, conf, "registry.example.com")
require.NoError(t, err)
require.NotNil(t, auth)
assert.Equal(t, "testuser", auth.Username)
assert.Equal(t, "testpass", auth.Password)
assert.Equal(t, "dGVzdHVzZXI6dGVzdHBhc3M=", auth.Auth)
assert.Equal(t, "[email protected]", auth.Email)
assert.Equal(t, "registry.example.com", auth.ServerAddress)
})

t.Run("returns error when no auth found", func(t *testing.T) {
conf := &configfile.ConfigFile{
AuthConfigs: map[string]types.AuthConfig{},
}

auth, err := tryLoadAuthForHost(ctx, conf, "registry.example.com")
require.Error(t, err)
assert.Nil(t, auth)
assert.Contains(t, err.Error(), "no credentials found")
})
}
16 changes: 15 additions & 1 deletion pkg/global/global.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,27 @@
package global

import "os"

const (
DefaultReplicateRegistryHost = "r8.im"
)

var (
Version = "dev"
Commit = ""
BuildTime = "none"
Debug = false
ProfilingEnabled = false
ReplicateRegistryHost = "r8.im"
ReplicateRegistryHost = getDefaultRegistryHost()
ReplicateWebsiteHost = "replicate.com"
LabelNamespace = "run.cog."
CogBuildArtifactsFolder = ".cog"
)

func getDefaultRegistryHost() string {
// Priority: flag will override at runtime, but env var provides default
if host := os.Getenv("COG_REGISTRY_HOST"); host != "" {
return host
}
return DefaultReplicateRegistryHost
}