diff --git a/go.mod b/go.mod index b6e6a0c711..4bc05a867d 100644 --- a/go.mod +++ b/go.mod @@ -29,6 +29,7 @@ require ( github.com/vincent-petithory/dataurl v1.0.0 github.com/xeipuuv/gojsonschema v1.2.0 github.com/xeonx/timeago v1.0.0-rc5 + golang.org/x/crypto v0.37.0 golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 golang.org/x/sync v0.14.0 golang.org/x/sys v0.33.0 @@ -272,7 +273,6 @@ require ( go.uber.org/automaxprocs v1.6.0 // indirect go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.27.0 // indirect - golang.org/x/crypto v0.37.0 // indirect golang.org/x/exp/typeparams v0.0.0-20250210185358-939b2ce775ac // indirect golang.org/x/mod v0.24.0 // indirect golang.org/x/net v0.39.0 // indirect diff --git a/pkg/docker/apt.go b/pkg/docker/apt.go deleted file mode 100644 index 10a69643ff..0000000000 --- a/pkg/docker/apt.go +++ /dev/null @@ -1,81 +0,0 @@ -package docker - -import ( - "context" - "crypto/sha256" - "encoding/hex" - "errors" - "fmt" - "os" - "path/filepath" - "sort" - "strings" - - "github.com/replicate/cog/pkg/docker/command" -) - -const aptTarballPrefix = "apt." -const aptTarballSuffix = ".tar.zst" - -func CreateAptTarball(ctx context.Context, tmpDir string, dockerCommand command.Command, packages ...string) (string, error) { - if len(packages) > 0 { - sort.Strings(packages) - hash := sha256.New() - hash.Write([]byte(strings.Join(packages, " "))) - hexHash := hex.EncodeToString(hash.Sum(nil)) - aptTarFile := aptTarballPrefix + hexHash + aptTarballSuffix - aptTarPath := filepath.Join(tmpDir, aptTarFile) - - if _, err := os.Stat(aptTarPath); errors.Is(err, os.ErrNotExist) { - // Remove previous apt tar files. - err = removeAptTarballs(tmpDir) - if err != nil { - return "", err - } - - // Create the apt tar file - _, err = dockerCommand.CreateAptTarFile(ctx, tmpDir, aptTarFile, packages...) - if err != nil { - return "", err - } - } - - return aptTarFile, nil - } - return "", nil -} - -func CurrentAptTarball(tmpDir string) (string, error) { - files, err := os.ReadDir(tmpDir) - if err != nil { - return "", fmt.Errorf("os read dir error: %w", err) - } - - for _, file := range files { - fileName := file.Name() - if strings.HasPrefix(fileName, aptTarballPrefix) && strings.HasSuffix(fileName, aptTarballSuffix) { - return filepath.Join(tmpDir, fileName), nil - } - } - - return "", nil -} - -func removeAptTarballs(tmpDir string) error { - files, err := os.ReadDir(tmpDir) - if err != nil { - return err - } - - for _, file := range files { - fileName := file.Name() - if strings.HasPrefix(fileName, aptTarballPrefix) && strings.HasSuffix(fileName, aptTarballSuffix) { - err = os.Remove(filepath.Join(tmpDir, fileName)) - if err != nil { - return err - } - } - } - - return nil -} diff --git a/pkg/docker/command/command.go b/pkg/docker/command/command.go index c2a67a86a6..f397fc7485 100644 --- a/pkg/docker/command/command.go +++ b/pkg/docker/command/command.go @@ -15,8 +15,6 @@ type Command interface { Pull(ctx context.Context, ref string, force bool) (*image.InspectResponse, error) Push(ctx context.Context, ref string) error LoadUserInformation(ctx context.Context, registryHost string) (*UserInfo, error) - CreateTarFile(ctx context.Context, ref string, tmpDir string, tarFile string, folder string) (string, error) - CreateAptTarFile(ctx context.Context, tmpDir string, aptTarFile string, packages ...string) (string, error) Inspect(ctx context.Context, ref string) (*image.InspectResponse, error) ImageExists(ctx context.Context, ref string) (bool, error) ContainerLogs(ctx context.Context, containerID string, w io.Writer) error diff --git a/pkg/docker/command/errors.go b/pkg/docker/command/errors.go index 8bcc4ba9eb..66f35a1098 100644 --- a/pkg/docker/command/errors.go +++ b/pkg/docker/command/errors.go @@ -29,3 +29,5 @@ func (e *NotFoundError) Is(target error) bool { func IsNotFoundError(err error) bool { return errors.Is(err, &NotFoundError{}) } + +var ErrAuthorizationFailed = errors.New("authorization failed") diff --git a/pkg/docker/credentials.go b/pkg/docker/credentials.go new file mode 100644 index 0000000000..28b26332b9 --- /dev/null +++ b/pkg/docker/credentials.go @@ -0,0 +1,88 @@ +package docker + +import ( + "context" + "encoding/json" + "fmt" + "io" + "os" + "os/exec" + "strings" + + "github.com/docker/cli/cli/config" + "github.com/docker/cli/cli/config/configfile" + "github.com/docker/cli/cli/config/types" + + "github.com/replicate/cog/pkg/docker/command" + "github.com/replicate/cog/pkg/util/console" +) + +func loadUserInformation(ctx context.Context, registryHost string) (*command.UserInfo, error) { + conf := config.LoadDefaultConfigFile(os.Stderr) + credsStore := conf.CredentialsStore + if credsStore == "" { + authConf, err := loadAuthFromConfig(conf, registryHost) + if err != nil { + return nil, err + } + return &command.UserInfo{ + Token: authConf.Password, + Username: authConf.Username, + }, nil + } + credsHelper, err := loadAuthFromCredentialsStore(ctx, credsStore, registryHost) + if err != nil { + return nil, err + } + return &command.UserInfo{ + Token: credsHelper.Secret, + Username: credsHelper.Username, + }, nil +} + +func loadAuthFromConfig(conf *configfile.ConfigFile, registryHost string) (types.AuthConfig, error) { + return conf.AuthConfigs[registryHost], nil +} + +func loadAuthFromCredentialsStore(ctx context.Context, credsStore string, registryHost string) (*CredentialHelperInput, error) { + var out strings.Builder + binary := dockerCredentialBinary(credsStore) + cmd := exec.CommandContext(ctx, binary, "get") + cmd.Env = os.Environ() + cmd.Stdout = &out + cmd.Stderr = &out + stdin, err := cmd.StdinPipe() + if err != nil { + return nil, err + } + defer stdin.Close() + console.Debug("$ " + strings.Join(cmd.Args, " ")) + err = cmd.Start() + if err != nil { + return nil, err + } + _, err = io.WriteString(stdin, registryHost) + if err != nil { + return nil, err + } + err = stdin.Close() + if err != nil { + return nil, err + } + err = cmd.Wait() + if err != nil { + return nil, fmt.Errorf("exec wait error: %w", err) + } + + var config CredentialHelperInput + err = json.Unmarshal([]byte(out.String()), &config) + if err != nil { + return nil, err + } + + return &config, nil +} + +func dockerCredentialBinary(credsStore string) string { + return "docker-credential-" + credsStore +} diff --git a/pkg/docker/docker_client_test.go b/pkg/docker/docker_client_test.go index d5abbc919a..eb85c45290 100644 --- a/pkg/docker/docker_client_test.go +++ b/pkg/docker/docker_client_test.go @@ -1,12 +1,18 @@ package docker import ( + "bytes" + "net" + "strings" "testing" + "github.com/docker/docker/api/types/container" + "github.com/docker/docker/api/types/registry" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/wait" "github.com/replicate/cog/pkg/docker/command" "github.com/replicate/cog/pkg/docker/dockertest" @@ -18,163 +24,411 @@ func TestDockerClient(t *testing.T) { t.Skip("skipping docker client tests in short mode") } - suite := &DockerClientSuite{ - dockerHelper: dockertest.NewHelperClient(t), - dockerClient: NewDockerCommand(), - } - - t.Run("ImageInspect", suite.runImageInspectTests) - t.Run("Pull", suite.runPullTests) - t.Run("ContainerStop", suite.runContainerStopTests) -} - -type DockerClientSuite struct { - dockerHelper *dockertest.HelperClient - dockerClient command.Command + client := NewDockerCommand() + runDockerClientTests(t, client) } -func (s *DockerClientSuite) assertImageExists(t *testing.T, imageRef string) { - inspect, err := s.dockerClient.Inspect(t.Context(), imageRef) - assert.NoError(t, err, "Failed to inspect image %q", imageRef) - assert.NotNil(t, inspect, "Image should exist") -} +func runDockerClientTests(t *testing.T, dockerClient command.Command) { + dockerHelper := dockertest.NewHelperClient(t) + testRegistry := registry_testhelpers.StartTestRegistry(t) -func (s *DockerClientSuite) assertNoImageExists(t *testing.T, imageRef string) { - inspect, err := s.dockerClient.Inspect(t.Context(), imageRef) - assert.ErrorIs(t, err, &command.NotFoundError{}, "Image should not exist") - assert.Nil(t, inspect, "Image should not exist") -} + dockerHelper.CleanupImages(t) -func (s *DockerClientSuite) runImageInspectTests(t *testing.T) { - t.Run("ExistingLocalImage", func(t *testing.T) { + t.Run("ImageInspect", func(t *testing.T) { t.Parallel() - image := "docker.io/library/busybox:latest" + t.Run("ExistingLocalImage", func(t *testing.T) { + t.Parallel() + + ref := dockertest.NewRef(t) + dockerHelper.ImageFixture(t, "alpine", ref.String()) - s.dockerHelper.MustPullImage(t, image) + expectedImage := dockerHelper.InspectImage(t, ref.String()) + resp, err := dockerClient.Inspect(t.Context(), ref.String()) + require.NoError(t, err, "Failed to inspect image %q", ref.String()) + assert.Equal(t, expectedImage.ID, resp.ID) + }) - expectedImage := s.dockerHelper.InspectImage(t, image) - resp, err := s.dockerClient.Inspect(t.Context(), image) - require.NoError(t, err, "Failed to inspect image %q", image) - assert.Equal(t, expectedImage.ID, resp.ID) + t.Run("MissingLocalImage", func(t *testing.T) { + t.Parallel() + + image := "not-a-valid-image" + _, err := dockerClient.Inspect(t.Context(), image) + assert.ErrorIs(t, err, &command.NotFoundError{}) + assert.ErrorContains(t, err, "image not found") + }) }) - t.Run("MissingLocalImage", func(t *testing.T) { + t.Run("Pull", func(t *testing.T) { t.Parallel() - image := "not-a-valid-image" - _, err := s.dockerClient.Inspect(t.Context(), image) - assert.ErrorIs(t, err, &command.NotFoundError{}) - assert.ErrorContains(t, err, "image not found") + // TODO[md]: add tests for the following permutations: + // - remote reference exists/not exists + // - local reference exists/not exists + // - force pull true/false + + t.Run("RemoteImageExists", func(t *testing.T) { + t.Parallel() + repo := testRegistry.CloneRepoForTest(t, "alpine") + imageRef := repo + ":latest" + + assertNoImageExists(t, dockerClient, imageRef) + + resp, err := dockerClient.Pull(t.Context(), imageRef, false) + require.NoError(t, err, "Failed to pull image %q", imageRef) + dockerHelper.CleanupImage(t, imageRef) + + assertImageExists(t, dockerClient, imageRef) + expectedResp := dockerHelper.InspectImage(t, imageRef) + // TODO[md]: we should check that the responsees are actually equal beyond the IDs. but atm + // the CLI and api are slightly different. The CLI leaves the descriptor field nil while the + // API response is populated. These should be identical on the new client, so we can change to EqualValues + assert.Equal(t, expectedResp.ID, resp.ID, "inspect response should match expected") + }) + + t.Run("RemoteReferenceNotFound", func(t *testing.T) { + t.Parallel() + imageRef := testRegistry.ImageRefForTest(t, "") + + assertNoImageExists(t, dockerClient, imageRef) + + resp, err := dockerClient.Pull(t.Context(), imageRef, false) + // TODO[md]: this might not be the right check. we probably want to wrap the error from the registry + // so we handle other failure cases, like failed auth, unknown tag, and unknown repo + require.Error(t, err, "Failed to pull image %q", imageRef) + assert.ErrorIs(t, err, &command.NotFoundError{Object: "manifest", Ref: imageRef}) + assert.Nil(t, resp, "inspect response should be nil") + }) + + t.Run("InvalidAuth", func(t *testing.T) { + t.Skip("skip auth tests until we're using the docker engine since we can't set auth on the host without side effects") + imageRef := testRegistry.ImageRefForTest(t, "") + + assertNoImageExists(t, dockerClient, imageRef) + + resp, err := dockerClient.Pull(t.Context(), imageRef, false) + // TODO[md]: this might not be the right check. we probably want to wrap the error from the registry + // so we handle other failure cases, like failed auth, unknown tag, and unknown repo + require.Error(t, err, "Failed to pull image %q", imageRef) + assert.ErrorContains(t, err, "failed to resolve reference") + assert.Nil(t, resp, "inspect response should be nil") + }) }) -} - -func (s *DockerClientSuite) runPullTests(t *testing.T) { - testRegistry := registry_testhelpers.StartTestRegistry(t) - // TODO[md]: add tests for the following permutations: - // - remote reference exists/not exists - // - local reference exists/not exists - // - force pull true/false + t.Run("ContainerStop", func(t *testing.T) { + t.Parallel() - t.Run("RemoteImageExists", func(t *testing.T) { - imageRef := testRegistry.ImageRefForTest(t, "") + t.Run("ContainerExistsAndIsRunning", func(t *testing.T) { + t.Parallel() + + container, err := testcontainers.Run( + t.Context(), + testRegistry.ImageRef("alpine:latest"), + testcontainers.WithCmd("sleep", "5000"), + ) + defer dockerHelper.CleanupImages(t) + defer testcontainers.CleanupContainer(t, container) + require.NoError(t, err, "Failed to run container") + + err = dockerClient.ContainerStop(t.Context(), container.ID) + require.NoError(t, err, "Failed to stop container %q", container.ID) + + state, err := container.State(t.Context()) + require.NoError(t, err, "Failed to get container state") + assert.Equal(t, state.Running, false) + }) + + t.Run("ContainerExistsAndIsNotRunning", func(t *testing.T) { + t.Parallel() + + container, err := testcontainers.GenericContainer(t.Context(), + testcontainers.GenericContainerRequest{ + ContainerRequest: testcontainers.ContainerRequest{ + Image: testRegistry.ImageRef("alpine:latest"), + Cmd: []string{"sleep", "5000"}, + }, + Started: false, + }, + ) + defer testcontainers.CleanupContainer(t, container) + containerID := container.GetContainerID() + require.NoError(t, err, "Failed to create container") + + err = dockerClient.ContainerStop(t.Context(), containerID) + require.NoError(t, err, "Failed to stop container %q", containerID) + + state, err := container.State(t.Context()) + require.NoError(t, err, "Failed to get container state") + assert.Equal(t, state.Running, false) + }) + + t.Run("ContainerDoesNotExist", func(t *testing.T) { + t.Parallel() + + err := dockerClient.ContainerStop(t.Context(), "containerid-that-does-not-exist") + require.ErrorIs(t, err, &command.NotFoundError{}) + require.ErrorContains(t, err, "container not found") + }) + }) - s.dockerHelper.LoadImageFixture(t, "alpine", imageRef) - s.dockerHelper.MustPushImage(t, imageRef) - s.dockerHelper.MustDeleteImage(t, imageRef) + t.Run("ContainerInspect", func(t *testing.T) { + t.Parallel() - s.assertNoImageExists(t, imageRef) + t.Run("ContainerExists", func(t *testing.T) { + t.Parallel() - resp, err := s.dockerClient.Pull(t.Context(), imageRef, false) - require.NoError(t, err, "Failed to pull image %q", imageRef) - s.dockerHelper.CleanupImage(t, imageRef) + container, err := testcontainers.Run( + t.Context(), + testRegistry.ImageRef("alpine:latest"), + testcontainers.WithCmd("sleep", "5000"), + ) + defer testcontainers.CleanupContainer(t, container) + require.NoError(t, err, "Failed to run container") - s.assertImageExists(t, imageRef) - expectedResp := s.dockerHelper.InspectImage(t, imageRef) - // TODO[md]: we should check that the responsees are actually equal beyond the IDs. but atm - // the CLI and api are slightly different. The CLI leaves the descriptor field nil while the - // API response is populated. These should be identical on the new client, so we can change to EqualValues - assert.Equal(t, expectedResp.ID, resp.ID, "inspect response should match expected") - }) + expected, err := container.Inspect(t.Context()) + require.NoError(t, err, "Failed to inspect container for expected response") - t.Run("RemoteReferenceNotFound", func(t *testing.T) { - imageRef := testRegistry.ImageRefForTest(t, "") + resp, err := dockerClient.ContainerInspect(t.Context(), container.ID) + require.NoError(t, err, "Failed to inspect container") + require.Equal(t, expected, resp) + }) - s.assertNoImageExists(t, imageRef) + t.Run("ContainerDoesNotExist", func(t *testing.T) { + t.Parallel() - resp, err := s.dockerClient.Pull(t.Context(), imageRef, false) - // TODO[md]: this might not be the right check. we probably want to wrap the error from the registry - // so we handle other failure cases, like failed auth, unknown tag, and unknown repo - require.Error(t, err, "Failed to pull image %q", imageRef) - assert.ErrorIs(t, err, &command.NotFoundError{Object: "manifest", Ref: imageRef}) - assert.Nil(t, resp, "inspect response should be nil") + _, err := dockerClient.ContainerInspect(t.Context(), "containerid-that-does-not-exist") + require.ErrorIs(t, err, &command.NotFoundError{}) + }) }) - t.Run("InvalidAuth", func(t *testing.T) { - t.Skip("skip auth tests until we're using the docker engine since we can't set auth on the host without side effects") - imageRef := testRegistry.ImageRefForTest(t, "") - - s.assertNoImageExists(t, imageRef) + t.Run("ContainerLogs", func(t *testing.T) { + t.Parallel() - resp, err := s.dockerClient.Pull(t.Context(), imageRef, false) - // TODO[md]: this might not be the right check. we probably want to wrap the error from the registry - // so we handle other failure cases, like failed auth, unknown tag, and unknown repo - require.Error(t, err, "Failed to pull image %q", imageRef) - assert.ErrorContains(t, err, "failed to resolve reference") - assert.Nil(t, resp, "inspect response should be nil") + t.Run("ContainerExistsAndIsRunning", func(t *testing.T) { + t.Parallel() + + container, err := testcontainers.Run( + t.Context(), + testRegistry.ImageRef("alpine:latest"), + // print "line $i" N times then exit, where $i is the line number + testcontainers.WithCmd("sh", "-c", "for i in $(seq 1 5); do echo \"$i\"; sleep 1; done"), + // testcontainers.WithConfigModifier(func(config *container.Config) { + // config.Tty = true + // }), + ) + require.NoError(t, err, "Failed to run container") + defer testcontainers.CleanupContainer(t, container) + + var buf bytes.Buffer + err = dockerClient.ContainerLogs(t.Context(), container.ID, &buf) + require.NoError(t, err, "Failed to get container logs") + + assert.Equal(t, "1\n2\n3\n4\n5\n", buf.String()) + }) + + t.Run("ContainerAlreadyStopped", func(t *testing.T) { + t.Parallel() + + container, err := testcontainers.Run( + t.Context(), + testRegistry.ImageRef("alpine:latest"), + testcontainers.WithCmd("sh", "-c", "for i in $(seq 1 3); do echo \"$i\"; sleep 0.1; done"), + testcontainers.WithWaitStrategy(wait.ForExit()), + ) + require.NoError(t, err, "Failed to run container") + defer testcontainers.CleanupContainer(t, container) + + state, err := container.State(t.Context()) + require.NoError(t, err, "Failed to get container state") + assert.Equal(t, state.Running, false) + + var buf bytes.Buffer + err = dockerClient.ContainerLogs(t.Context(), container.ID, &buf) + require.NoError(t, err, "Failed to get container logs") + + assert.Equal(t, "1\n2\n3\n", buf.String()) + }) + + t.Run("TTY and non-TTY streams match", func(t *testing.T) { + t.Parallel() + + runContainer := func(tty bool) string { + container, err := testcontainers.Run( + t.Context(), + testRegistry.ImageRef("alpine:latest"), + // print "line $i" N times then exit, where $i is the line number + testcontainers.WithCmd("sh", "-c", "for i in $(seq 1 5); do echo \"$i\"; sleep 0.1; done"), + testcontainers.WithConfigModifier(func(config *container.Config) { + config.Tty = tty + }), + ) + require.NoError(t, err, "Failed to run container") + defer testcontainers.CleanupContainer(t, container) + + var buf bytes.Buffer + err = dockerClient.ContainerLogs(t.Context(), container.ID, &buf) + require.NoError(t, err, "Failed to get container logs") + return buf.String() + } + + ttyOutput := runContainer(true) + nonTtyOutput := runContainer(false) + + // TTY uses CRLF for line endings, non-TTY uses LF. replace \r\n with \n so they match + ttyOutput = strings.ReplaceAll(ttyOutput, "\r\n", "\n") + + assert.Equal(t, ttyOutput, nonTtyOutput, "TTY and non-TTY streams should match after normalizing line endings") + }) + + t.Run("ContainerDoesNotExist", func(t *testing.T) { + t.Parallel() + + err := dockerClient.ContainerLogs(t.Context(), "containerid-that-does-not-exist", &bytes.Buffer{}) + require.ErrorIs(t, err, &command.NotFoundError{}) + }) }) -} -func (s *DockerClientSuite) runContainerStopTests(t *testing.T) { - t.Run("ContainerExistsAndIsRunning", func(t *testing.T) { + t.Run("Push", func(t *testing.T) { t.Parallel() - container, err := testcontainers.Run( - t.Context(), - "docker.io/library/busybox:latest", - testcontainers.WithCmd("sleep", "5000"), - ) - defer testcontainers.CleanupContainer(t, container) - require.NoError(t, err, "Failed to run container") + t.Run("valid image, valid registry", func(t *testing.T) { + t.Parallel() - err = s.dockerClient.ContainerStop(t.Context(), container.ID) - require.NoError(t, err, "Failed to stop container %q", container.ID) + ref := dockertest.NewRef(t).WithRegistry(testRegistry.RegistryHost()) - state, err := container.State(t.Context()) - require.NoError(t, err, "Failed to get container state") - assert.Equal(t, state.Running, false) - }) + dockerHelper.ImageFixture(t, "alpine", ref.String()) - t.Run("ContainerExistsAndIsNotRunning", func(t *testing.T) { - t.Parallel() + err := dockerClient.Push(t.Context(), ref.String()) + require.NoError(t, err) + assert.NoError(t, testRegistry.ImageExists(t, ref.String())) + }) - container, err := testcontainers.GenericContainer(t.Context(), - testcontainers.GenericContainerRequest{ - ContainerRequest: testcontainers.ContainerRequest{ - Image: "docker.io/library/busybox:latest", - Cmd: []string{"sleep", "5000"}, - }, - Started: false, - }, - ) - defer testcontainers.CleanupContainer(t, container) - containerID := container.GetContainerID() - require.NoError(t, err, "Failed to create container") - - err = s.dockerClient.ContainerStop(t.Context(), containerID) - require.NoError(t, err, "Failed to stop container %q", containerID) - - state, err := container.State(t.Context()) - require.NoError(t, err, "Failed to get container state") - assert.Equal(t, state.Running, false) - }) + t.Run("non-existent registry", func(t *testing.T) { + t.Parallel() - t.Run("ContainerDoesNotExist", func(t *testing.T) { - t.Parallel() + // start a local tcp server that immediately closes connections + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + go func() { + for { + conn, err := listener.Accept() + if err != nil { + return + } + conn.Close() + } + }() + + // Create a reference to the mock registry + ref := dockertest.NewRef(t).WithRegistry(listener.Addr().String()) + dockerHelper.ImageFixture(t, "alpine", ref.String()) - err := s.dockerClient.ContainerStop(t.Context(), "containerid-that-does-not-exist") - require.ErrorIs(t, err, &command.NotFoundError{}) - require.ErrorContains(t, err, "container not found") + // Try to push to the mock registry + err = dockerClient.Push(t.Context(), ref.String()) + require.Error(t, err, "Push should fail with unreachable registry") + + // error message varies between dev and CI host environments, cover them all... + assert.Condition(t, func() bool { + msg := err.Error() + return strings.Contains(msg, "connection refused") || strings.Contains(msg, "EOF") + }, "Error should indicate registry is unreachable") + }) + + t.Run("missing image", func(t *testing.T) { + t.Parallel() + + ref := dockertest.NewRef(t).WithRegistry(testRegistry.RegistryHost()) + + err := dockerClient.Push(t.Context(), ref.String()) + assertNotFoundError(t, err, ref.String(), "tag") + }) + + t.Run("registry with authentication", func(t *testing.T) { + t.Parallel() + + if _, ok := dockerClient.(*DockerCommand); ok { + t.Skip("skipping auth tests for docker command client since we can't set auth on the host without side effects") + } + + authReg := registry_testhelpers.StartTestRegistry(t, registry_testhelpers.WithAuth("testuser", "testpass")) + + t.Run("correct credentials", func(t *testing.T) { + t.Parallel() + + ref := dockertest.NewRef(t).WithRegistry(authReg.RegistryHost()) + dockerHelper.ImageFixture(t, "alpine", ref.String()) + + // create a new client with the correct auth config + authClient, err := NewClient(t.Context(), WithAuthConfig(registry.AuthConfig{ + Username: "testuser", + Password: "testpass", + ServerAddress: authReg.RegistryHost(), + })) + require.NoError(t, err) + + err = authClient.Push(t.Context(), ref.String()) + require.NoError(t, err, "Failed to push image to auth registry") + assert.NoError(t, authReg.ImageExists(t, ref.String())) + }) + + t.Run("missing auth", func(t *testing.T) { + t.Parallel() + + ref := dockertest.NewRef(t).WithRegistry(authReg.RegistryHost()) + dockerHelper.ImageFixture(t, "alpine", ref.String()) + + // use root client which doesn't have auth setup + err := dockerClient.Push(t.Context(), ref.String()) + require.ErrorIs(t, err, command.ErrAuthorizationFailed) + }) + + t.Run("incorrect auth", func(t *testing.T) { + t.Parallel() + + ref := dockertest.NewRef(t).WithRegistry(authReg.RegistryHost()) + dockerHelper.ImageFixture(t, "alpine", ref.String()) + + authClient, err := NewClient(t.Context(), WithAuthConfig(registry.AuthConfig{ + Username: "testuser", + Password: "wrongpass", + ServerAddress: authReg.RegistryHost(), + })) + require.NoError(t, err) + + err = authClient.Push(t.Context(), ref.String()) + require.ErrorIs(t, err, command.ErrAuthorizationFailed) + }) + + t.Run("correct credentials, not authorized", func(t *testing.T) { + t.Skip("skipping until the registry supports repo authorizations") + }) + }) }) } + +func assertImageExists(t *testing.T, dockerClient command.Command, imageRef string) { + t.Helper() + + inspect, err := dockerClient.Inspect(t.Context(), imageRef) + assert.NoError(t, err, "Failed to inspect image %q", imageRef) + assert.NotNil(t, inspect, "Image should exist") +} + +func assertNoImageExists(t *testing.T, dockerClient command.Command, imageRef string) { + t.Helper() + + inspect, err := dockerClient.Inspect(t.Context(), imageRef) + assert.ErrorIs(t, err, &command.NotFoundError{}, "Image should not exist") + assert.Nil(t, inspect, "Image should not exist") +} + +func assertNotFoundError(t *testing.T, err error, ref string, object string) { + t.Helper() + + var notFoundErr *command.NotFoundError + require.ErrorAs(t, err, ¬FoundErr, "should be a not found error") + require.Equal(t, ref, notFoundErr.Ref, "ref should match") + require.Equal(t, object, notFoundErr.Object, "object should match") +} diff --git a/pkg/docker/docker_command.go b/pkg/docker/docker_command.go index 621f78f23c..d81e0f4645 100644 --- a/pkg/docker/docker_command.go +++ b/pkg/docker/docker_command.go @@ -9,14 +9,10 @@ import ( "io" "os" "os/exec" - "path/filepath" "runtime" "strings" "github.com/creack/pty" - "github.com/docker/cli/cli/config" - "github.com/docker/cli/cli/config/configfile" - "github.com/docker/cli/cli/config/types" "github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/image" "github.com/mattn/go-isatty" @@ -77,80 +73,25 @@ func (c *DockerCommand) Pull(ctx context.Context, image string, force bool) (*im func (c *DockerCommand) Push(ctx context.Context, image string) error { console.Debugf("=== DockerCommand.Push %s", image) - return c.exec(ctx, nil, nil, nil, "", []string{"push", image}) -} - -func (c *DockerCommand) LoadUserInformation(ctx context.Context, registryHost string) (*command.UserInfo, error) { - console.Debugf("=== DockerCommand.LoadUserInformation %s", registryHost) - - conf := config.LoadDefaultConfigFile(os.Stderr) - credsStore := conf.CredentialsStore - if credsStore == "" { - authConf, err := loadAuthFromConfig(conf, registryHost) - if err != nil { - return nil, err - } - return &command.UserInfo{ - Token: authConf.Password, - Username: authConf.Username, - }, nil - } - credsHelper, err := loadAuthFromCredentialsStore(ctx, credsStore, registryHost) + err := c.exec(ctx, nil, nil, nil, "", []string{"push", image}) if err != nil { - return nil, err + if isTagNotFoundError(err) { + return &command.NotFoundError{Ref: image, Object: "tag"} + } + if isAuthorizationFailedError(err) { + return command.ErrAuthorizationFailed + } + return err } - return &command.UserInfo{ - Token: credsHelper.Secret, - Username: credsHelper.Username, - }, nil -} -func (c *DockerCommand) CreateTarFile(ctx context.Context, image string, tmpDir string, tarFile string, folder string) (string, error) { - console.Debugf("=== DockerCommand.CreateTarFile %s %s %s %s", image, tmpDir, tarFile, folder) - - args := []string{ - "run", - "--rm", - // force platform to linux/amd64 so darwin/arm64 outputs work in prod - "--platform", "linux/amd64", - "--volume", - tmpDir + ":/buildtmp", - image, - "/opt/r8/monobase/tar.sh", - "/buildtmp/" + tarFile, - "/", - folder, - } - if err := c.exec(ctx, nil, nil, nil, "", args); err != nil { - return "", err - } - return filepath.Join(tmpDir, tarFile), nil + return nil } -func (c *DockerCommand) CreateAptTarFile(ctx context.Context, tmpDir string, aptTarFile string, packages ...string) (string, error) { - console.Debugf("=== DockerCommand.CreateAptTarFile %s %s", aptTarFile, packages) - - // This uses a hardcoded monobase image to produce an apt tar file. - // The reason being that this apt tar file is created outside the docker file, and it is created by - // running the apt.sh script on the monobase with the packages we intend to install, which produces - // a tar file that can be untarred into a docker build to achieve the equivalent of an apt-get install. - args := []string{ - "run", - "--rm", - // force platform to linux/amd64 so darwin/arm64 outputs work in prod - "--platform", "linux/amd64", - "--volume", - tmpDir + ":/buildtmp", - "r8.im/monobase:latest", - "/opt/r8/monobase/apt.sh", - "/buildtmp/" + aptTarFile, - } - args = append(args, packages...) - if err := c.exec(ctx, nil, nil, nil, "", args); err != nil { - return "", err - } +// TODO[md]: this doesn't need to be on the interface, move to auth handler +func (c *DockerCommand) LoadUserInformation(ctx context.Context, registryHost string) (*command.UserInfo, error) { + console.Debugf("=== DockerCommand.LoadUserInformation %s", registryHost) - return aptTarFile, nil + return loadUserInformation(ctx, registryHost) } func (c *DockerCommand) Inspect(ctx context.Context, ref string) (*image.InspectResponse, error) { @@ -162,7 +103,7 @@ func (c *DockerCommand) Inspect(ctx context.Context, ref string) (*image.Inspect } output, err := c.execCaptured(ctx, nil, "", args) if err != nil { - if strings.Contains(err.Error(), "No such image") { + if isImageNotFoundError(err) { return nil, &command.NotFoundError{Object: "image", Ref: ref} } return nil, err @@ -207,7 +148,14 @@ func (c *DockerCommand) ContainerLogs(ctx context.Context, containerID string, w "--follow", } - return c.exec(ctx, nil, w, nil, "", args) + err := c.exec(ctx, nil, w, nil, "", args) + if err != nil { + if isContainerNotFoundError(err) { + return &command.NotFoundError{Ref: containerID, Object: "container"} + } + return err + } + return err } func (c *DockerCommand) ContainerInspect(ctx context.Context, id string) (*container.InspectResponse, error) { @@ -221,7 +169,7 @@ func (c *DockerCommand) ContainerInspect(ctx context.Context, id string) (*conta output, err := c.execCaptured(ctx, nil, "", args) if err != nil { - if strings.Contains(err.Error(), "No such container") { + if isContainerNotFoundError(err) { return nil, &command.NotFoundError{Object: "container", Ref: id} } return nil, err @@ -249,7 +197,7 @@ func (c *DockerCommand) ContainerStop(ctx context.Context, containerID string) e } if err := c.exec(ctx, nil, io.Discard, nil, "", args); err != nil { - if strings.Contains(err.Error(), "No such container") { + if isContainerNotFoundError(err) { err = &command.NotFoundError{Object: "container", Ref: containerID} } return fmt.Errorf("failed to stop container %q: %w", containerID, err) @@ -417,7 +365,7 @@ func (c *DockerCommand) containerRun(ctx context.Context, options command.RunOpt err := c.exec(ctx, options.Stdin, options.Stdout, options.Stderr, "", args) if err != nil { - if strings.Contains(err.Error(), "could not select device driver") || strings.Contains(err.Error(), "nvidia-container-cli: initialization error") { + if isMissingDeviceDriverError(err) { return ErrMissingDeviceDriver } return err @@ -490,6 +438,12 @@ func (c *DockerCommand) exec(ctx context.Context, in io.Reader, outw, errw io.Wr if errors.Is(err, context.Canceled) { return err } + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { + if !exitErr.Exited() && strings.Contains(exitErr.Error(), "signal: killed") { + return context.DeadlineExceeded + } + } return fmt.Errorf("command failed: %s: %w", stderrBuf.String(), err) } return nil @@ -503,50 +457,3 @@ func (c *DockerCommand) execCaptured(ctx context.Context, in io.Reader, dir stri } return out.String(), nil } - -func loadAuthFromConfig(conf *configfile.ConfigFile, registryHost string) (types.AuthConfig, error) { - return conf.AuthConfigs[registryHost], nil -} - -func loadAuthFromCredentialsStore(ctx context.Context, credsStore string, registryHost string) (*CredentialHelperInput, error) { - var out strings.Builder - binary := DockerCredentialBinary(credsStore) - cmd := exec.CommandContext(ctx, binary, "get") - cmd.Env = os.Environ() - cmd.Stdout = &out - cmd.Stderr = &out - stdin, err := cmd.StdinPipe() - if err != nil { - return nil, err - } - defer stdin.Close() - console.Debug("$ " + strings.Join(cmd.Args, " ")) - err = cmd.Start() - if err != nil { - return nil, err - } - _, err = io.WriteString(stdin, registryHost) - if err != nil { - return nil, err - } - err = stdin.Close() - if err != nil { - return nil, err - } - err = cmd.Wait() - if err != nil { - return nil, fmt.Errorf("exec wait error: %w", err) - } - - var config CredentialHelperInput - err = json.Unmarshal([]byte(out.String()), &config) - if err != nil { - return nil, err - } - - return &config, nil -} - -func DockerCredentialBinary(credsStore string) string { - return "docker-credential-" + credsStore -} diff --git a/pkg/docker/dockertest/helper_client.go b/pkg/docker/dockertest/helper_client.go index 5e24c9e756..cd6f7c49ab 100644 --- a/pkg/docker/dockertest/helper_client.go +++ b/pkg/docker/dockertest/helper_client.go @@ -10,6 +10,7 @@ import ( "path/filepath" "runtime" "slices" + "sync" "testing" "github.com/docker/docker/api/types/container" @@ -41,17 +42,33 @@ func NewHelperClient(t testing.TB) *HelperClient { t.Skip("Docker daemon is not running") } + helper := &HelperClient{ + Client: cli, + fixtures: make(map[string]*imageFixture), + mu: &sync.Mutex{}, + } + t.Cleanup(func() { + for _, img := range helper.fixtures { + _, err := helper.Client.ImageRemove(context.Background(), img.imageID, image.RemoveOptions{Force: true, PruneChildren: true}) + if err != nil { + t.Logf("Warning: Failed to remove image %q: %v", img.imageID, err) + } + } + if err := cli.Close(); err != nil { t.Fatalf("Failed to close Docker client: %v", err) } }) - return &HelperClient{Client: cli} + return helper } type HelperClient struct { Client *client.Client + + mu *sync.Mutex + fixtures map[string]*imageFixture } func (c *HelperClient) Close() error { @@ -210,10 +227,13 @@ func (c *HelperClient) CleanupImage(t testing.TB, imageRef string) { t.Helper() t.Cleanup(func() { - _, _ = c.Client.ImageRemove(context.Background(), imageRef, image.RemoveOptions{ + _, err := c.Client.ImageRemove(context.Background(), imageRef, image.RemoveOptions{ Force: true, PruneChildren: true, }) + if err != nil { + t.Logf("Warning: Failed to remove image %q: %v", imageRef, err) + } }) } @@ -255,9 +275,32 @@ func (c *HelperClient) InspectContainer(t testing.TB, containerID string) *conta return &inspect } -func (c *HelperClient) LoadImageFixture(t testing.TB, name string, tag string) { +func (c *HelperClient) ImageFixture(t testing.TB, name string, tag string) { + t.Helper() + fixture := c.loadImageFixture(t, name) + + t.Logf("Tagging image fixture %q with %q", fixture.ref, tag) + if err := c.Client.ImageTag(t.Context(), fixture.imageID, tag); err != nil { + require.NoError(t, err, "Failed to tag image %q with %q: %v", fixture.ref, tag, err) + } + // remove the image when the test is done + t.Cleanup(func() { + _, _ = c.Client.ImageRemove(context.Background(), tag, image.RemoveOptions{Force: true}) + }) +} + +func (c *HelperClient) loadImageFixture(t testing.TB, name string) *imageFixture { t.Helper() + c.mu.Lock() + defer c.mu.Unlock() + + ref := fmt.Sprintf("cog-test-fixture:%s", name) + + if fixture, ok := c.fixtures[ref]; ok { + return fixture + } + // Get the path of the current file _, filename, _, ok := runtime.Caller(0) if !ok { @@ -270,7 +313,6 @@ func (c *HelperClient) LoadImageFixture(t testing.TB, name string, tag string) { // Construct the path to the fixture fixturePath := filepath.Join(dir, "testdata", name+".tar") - ref := fmt.Sprintf("cog-test-fixture:%s", name) t.Logf("Loading image fixture %q from %s", ref, fixturePath) f, err := os.Open(fixturePath) @@ -280,22 +322,23 @@ func (c *HelperClient) LoadImageFixture(t testing.TB, name string, tag string) { l, err := c.Client.ImageLoad(t.Context(), f) require.NoError(t, err, "Failed to load fixture %q", name) defer l.Body.Close() - _, err = io.Copy(os.Stdout, l.Body) + _, err = io.Copy(os.Stderr, l.Body) require.NoError(t, err, "Failed to copy fixture %q", name) - // remove the image when the test is done - t.Cleanup(func() { - _, _ = c.Client.ImageRemove(context.Background(), ref, image.RemoveOptions{}) - }) + inspect, err := c.Client.ImageInspect(t.Context(), ref) + require.NoError(t, err, "Failed to inspect image %q", ref) - if tag != "" { - t.Logf("Tagging image fixture %q with %q", ref, tag) - if err := c.Client.ImageTag(t.Context(), ref, tag); err != nil { - require.NoError(t, err, "Failed to tag image %q with %q: %v", ref, tag, err) - } - // remove the image when the test is done - t.Cleanup(func() { - _, _ = c.Client.ImageRemove(context.Background(), tag, image.RemoveOptions{}) - }) + fixture := &imageFixture{ + ref: ref, + imageID: inspect.ID, } + + c.fixtures[ref] = fixture + + return fixture +} + +type imageFixture struct { + imageID string + ref string } diff --git a/pkg/docker/dockertest/mock_command.go b/pkg/docker/dockertest/mock_command.go index e27c715ed2..5b399d9aab 100644 --- a/pkg/docker/dockertest/mock_command.go +++ b/pkg/docker/dockertest/mock_command.go @@ -5,6 +5,7 @@ import ( "io" "os" "path/filepath" + "strings" "github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/image" @@ -99,6 +100,15 @@ func (c *MockCommand) ImageBuild(ctx context.Context, options command.ImageBuild } func (c *MockCommand) Run(ctx context.Context, options command.RunOptions) error { + // hack to handle generating tar files for monobase + if options.Args[0] == "/opt/r8/monobase/tar.sh" || options.Args[0] == "/opt/r8/monobase/apt.sh" { + tmpDir := options.Volumes[0].Source + tarfile := strings.TrimPrefix(options.Args[1], "/buildtmp/") + + outPath := filepath.Join(tmpDir, tarfile) + return os.WriteFile(outPath, []byte("hello\ngo\n"), 0o644) + } + panic("not implemented") } diff --git a/pkg/docker/dockertest/ref.go b/pkg/docker/dockertest/ref.go new file mode 100644 index 0000000000..4b7d0fe9e3 --- /dev/null +++ b/pkg/docker/dockertest/ref.go @@ -0,0 +1,82 @@ +package dockertest + +import ( + "strings" + "testing" + + "github.com/google/go-containerregistry/pkg/name" + "github.com/stretchr/testify/require" +) + +type Ref struct { + t *testing.T + ref name.Reference +} + +func NewRef(t *testing.T) Ref { + t.Helper() + + repoName := strings.ToLower(t.Name()) + // Replace any characters that aren't valid in a docker image repo name with underscore + // Valid characters are: a-z, 0-9, ., _, -, / + repoName = strings.Map(func(r rune) rune { + if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '.' || r == '_' || r == '-' || r == '/' { + return r + } + return '_' + }, repoName) + + ref, err := name.ParseReference(repoName, name.WithDefaultRegistry("")) + require.NoError(t, err, "Failed to create reference for test") + + return Ref{t: t, ref: ref} +} + +func (r Ref) WithTag(tagName string) Ref { + tagRef := r.ref.Context().Tag(tagName) + return Ref{t: r.t, ref: tagRef} +} + +func (r Ref) WithDigest(digest string) Ref { + digestRef := r.ref.Context().Digest(digest) + return Ref{t: r.t, ref: digestRef} +} + +func (r Ref) WithRegistry(registry string) Ref { + reg, err := name.NewRegistry(registry) + require.NoError(r.t, err, "Failed to create registry for test") + + repo := r.ref.Context() + repo.Registry = reg + var newRef name.Reference + switch r.ref.(type) { + case name.Tag: + newRef = repo.Tag(r.ref.Identifier()) + case name.Digest: + newRef = repo.Digest(r.ref.Identifier()) + default: + require.Fail(r.t, "Unsupported reference type") + } + + return Ref{t: r.t, ref: newRef} +} + +func (r Ref) WithoutRegistry() Ref { + repo := r.ref.Context() + repo.Registry = name.Registry{} + var newRef name.Reference + switch r.ref.(type) { + case name.Tag: + newRef = repo.Tag(r.ref.Identifier()) + case name.Digest: + newRef = repo.Digest(r.ref.Identifier()) + default: + require.Fail(r.t, "Unsupported reference type") + } + + return Ref{t: r.t, ref: newRef} +} + +func (r Ref) String() string { + return r.ref.Name() +} diff --git a/pkg/docker/dockertest/ref_test.go b/pkg/docker/dockertest/ref_test.go new file mode 100644 index 0000000000..754b12dec8 --- /dev/null +++ b/pkg/docker/dockertest/ref_test.go @@ -0,0 +1,24 @@ +package dockertest + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRef(t *testing.T) { + ref := NewRef(t) + assert.Equal(t, "testref:latest", ref.String()) + + ref = ref.WithTag("v2") + assert.Equal(t, "testref:v2", ref.String()) + + ref = ref.WithRegistry("r8.im") + assert.Equal(t, "r8.im/testref:v2", ref.String()) + + ref = ref.WithoutRegistry() + assert.Equal(t, "testref:v2", ref.String()) + + ref = ref.WithDigest("sha256:71859b0c62df47efaeae4f93698b56a8dddafbf041778fd668bbd1ab45a864f8") + assert.Equal(t, "testref@sha256:71859b0c62df47efaeae4f93698b56a8dddafbf041778fd668bbd1ab45a864f8", ref.String()) +} diff --git a/pkg/docker/errors.go b/pkg/docker/errors.go new file mode 100644 index 0000000000..feac88a465 --- /dev/null +++ b/pkg/docker/errors.go @@ -0,0 +1,49 @@ +package docker + +import "strings" + +// Error messages vary between different backends (dockerd, containerd, podman, orbstack, etc) or even versions of docker. +// These helpers normalize the check so callers can handle situations without worrying about the underlying implementation. +// Yes, it's gross, but whattaya gonna do + +func isTagNotFoundError(err error) bool { + msg := err.Error() + return strings.Contains(msg, "tag does not exist") || + strings.Contains(msg, "An image does not exist locally with the tag") +} + +func isImageNotFoundError(err error) bool { + msg := err.Error() + return strings.Contains(msg, "image does not exist") || + strings.Contains(msg, "No such image") +} + +func isContainerNotFoundError(err error) bool { + msg := err.Error() + return strings.Contains(msg, "container does not exist") || + strings.Contains(msg, "No such container") +} + +func isAuthorizationFailedError(err error) bool { + msg := err.Error() + + // registry requires auth and none were provided + if strings.Contains(msg, "no basic auth credentials") { + return true + } + + // registry rejected the provided auth + if strings.Contains(msg, "authorization failed") || + strings.Contains(msg, "401 Unauthorized") || + strings.Contains(msg, "unauthorized: authentication required") { + return true + } + + return false +} + +func isMissingDeviceDriverError(err error) bool { + msg := err.Error() + return strings.Contains(msg, "could not select device driver") || + strings.Contains(msg, "nvidia-container-cli: initialization error") +} diff --git a/pkg/docker/fast_push.go b/pkg/docker/fast_push.go index f67bd92766..b002c57d27 100644 --- a/pkg/docker/fast_push.go +++ b/pkg/docker/fast_push.go @@ -147,11 +147,11 @@ func FastPush(ctx context.Context, image string, projectDir string, command comm } func createPythonPackagesTarFile(ctx context.Context, image string, tmpDir string, command command.Command) (string, error) { - return command.CreateTarFile(ctx, image, tmpDir, requirementsTarFile, "root/.venv") + return CreateTarFile(ctx, command, image, tmpDir, requirementsTarFile, "root/.venv") } func createSrcTarFile(ctx context.Context, image string, tmpDir string, command command.Command) (string, error) { - return command.CreateTarFile(ctx, image, tmpDir, "src.tar.zst", "src") + return CreateTarFile(ctx, command, image, tmpDir, "src.tar.zst", "src") } func createWeightsFilesFromWeightsManifest(weights []weights.Weight) []web.File { diff --git a/pkg/docker/login.go b/pkg/docker/login.go index e17f2e123c..c9f5b18706 100644 --- a/pkg/docker/login.go +++ b/pkg/docker/login.go @@ -37,7 +37,7 @@ func saveAuthToConfig(conf *configfile.ConfigFile, registryHost string, username } func saveAuthToCredentialsStore(ctx context.Context, credsStore string, registryHost string, username string, token string) error { - binary := DockerCredentialBinary(credsStore) + binary := dockerCredentialBinary(credsStore) input := CredentialHelperInput{ Username: username, Secret: token, diff --git a/pkg/docker/monobase.go b/pkg/docker/monobase.go new file mode 100644 index 0000000000..f604a65359 --- /dev/null +++ b/pkg/docker/monobase.go @@ -0,0 +1,141 @@ +package docker + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "os" + "path" + "path/filepath" + "sort" + "strings" + + "github.com/replicate/cog/pkg/docker/command" + "github.com/replicate/cog/pkg/util/console" +) + +const aptTarballPrefix = "apt." +const aptTarballSuffix = ".tar.zst" + +func CreateAptTarball(ctx context.Context, tmpDir string, dockerClient command.Command, packages ...string) (string, error) { + if len(packages) > 0 { + sort.Strings(packages) + hash := sha256.New() + hash.Write([]byte(strings.Join(packages, " "))) + hexHash := hex.EncodeToString(hash.Sum(nil)) + aptTarFile := aptTarballPrefix + hexHash + aptTarballSuffix + aptTarPath := filepath.Join(tmpDir, aptTarFile) + + if _, err := os.Stat(aptTarPath); errors.Is(err, os.ErrNotExist) { + // Remove previous apt tar files. + err = removeAptTarballs(tmpDir) + if err != nil { + return "", err + } + + // Create the apt tar file + _, err = CreateAptTarFile(ctx, dockerClient, tmpDir, aptTarFile, packages...) + if err != nil { + return "", err + } + } + + return aptTarFile, nil + } + return "", nil +} + +func CurrentAptTarball(tmpDir string) (string, error) { + files, err := os.ReadDir(tmpDir) + if err != nil { + return "", fmt.Errorf("os read dir error: %w", err) + } + + for _, file := range files { + fileName := file.Name() + if strings.HasPrefix(fileName, aptTarballPrefix) && strings.HasSuffix(fileName, aptTarballSuffix) { + return filepath.Join(tmpDir, fileName), nil + } + } + + return "", nil +} + +func removeAptTarballs(tmpDir string) error { + files, err := os.ReadDir(tmpDir) + if err != nil { + return err + } + + for _, file := range files { + fileName := file.Name() + if strings.HasPrefix(fileName, aptTarballPrefix) && strings.HasSuffix(fileName, aptTarballSuffix) { + err = os.Remove(filepath.Join(tmpDir, fileName)) + if err != nil { + return err + } + } + } + + return nil +} + +func CreateTarFile(ctx context.Context, dockerClient command.Command, image string, tmpDir string, tarFile string, folder string) (string, error) { + console.Debugf("=== CreateTarFile %s %s %s %s", image, tmpDir, tarFile, folder) + + opts := command.RunOptions{ + Image: image, + Args: []string{ + "/opt/r8/monobase/tar.sh", + path.Join("/buildtmp", tarFile), + "/", + folder, + }, + Volumes: []command.Volume{ + { + Source: tmpDir, + Destination: "/buildtmp", + }, + }, + } + + if err := dockerClient.Run(ctx, opts); err != nil { + return "", err + } + + return filepath.Join(tmpDir, tarFile), nil +} + +func CreateAptTarFile(ctx context.Context, dockerClient command.Command, tmpDir string, aptTarFile string, packages ...string) (string, error) { + console.Debugf("=== CreateAptTarFile %s %s", aptTarFile, packages) + + // This uses a hardcoded monobase image to produce an apt tar file. + // The reason being that this apt tar file is created outside the docker file, and it is created by + // running the apt.sh script on the monobase with the packages we intend to install, which produces + // a tar file that can be untarred into a docker build to achieve the equivalent of an apt-get install. + + opts := command.RunOptions{ + Image: "r8.im/monobase:latest", + Args: append( + []string{ + "/opt/r8/monobase/apt.sh", + path.Join("/buildtmp", aptTarFile), + }, + packages..., + ), + Volumes: []command.Volume{ + { + Source: tmpDir, + Destination: "/buildtmp", + }, + }, + } + + if err := dockerClient.Run(ctx, opts); err != nil { + return "", err + } + + return aptTarFile, nil +} diff --git a/pkg/docker/apt_test.go b/pkg/docker/monobase_test.go similarity index 100% rename from pkg/docker/apt_test.go rename to pkg/docker/monobase_test.go diff --git a/pkg/registry_testhelpers/registry_container.go b/pkg/registry_testhelpers/registry_container.go index c64e0e2094..bc0cfb3521 100644 --- a/pkg/registry_testhelpers/registry_container.go +++ b/pkg/registry_testhelpers/registry_container.go @@ -12,10 +12,17 @@ import ( "github.com/docker/docker/api/types/container" "github.com/docker/go-connections/nat" + "github.com/google/go-containerregistry/pkg/authn" + "github.com/google/go-containerregistry/pkg/crane" + "github.com/google/go-containerregistry/pkg/name" + "github.com/google/go-containerregistry/pkg/v1/remote" "github.com/stretchr/testify/require" "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/modules/registry" "github.com/testcontainers/testcontainers-go/wait" + "golang.org/x/crypto/bcrypt" + + dockerregistry "github.com/docker/docker/api/types/registry" "github.com/replicate/cog/pkg/util" ) @@ -25,23 +32,26 @@ import ( // that can be used to inspect the registry and generate absolute image references. It will // automatically be cleaned when the test finishes. // This is safe to run concurrently across multiple tests. -func StartTestRegistry(t *testing.T) *RegistryContainer { +func StartTestRegistry(t *testing.T, opts ...Option) *RegistryContainer { t.Helper() + options := &options{} + for _, opt := range opts { + opt(options) + } + _, filename, _, _ := runtime.Caller(0) testdataDir := filepath.Join(filepath.Dir(filename), "testdata", "docker") - registryContainer, err := registry.Run( - t.Context(), - "registry:3", + containerCustomizers := []testcontainers.ContainerCustomizer{ testcontainers.WithFiles(testcontainers.ContainerFile{ HostFilePath: testdataDir, ContainerFilePath: "/var/lib/registry/", FileMode: 0o755, }), testcontainers.WithWaitStrategy( - wait.ForHTTP("/v2/").WithPort("5000/tcp"). - WithStartupTimeout(10*time.Second), + wait.ForHTTP("/").WithPort("5000/tcp"). + WithStartupTimeout(10 * time.Second), ), testcontainers.WithHostConfigModifier(func(hostConfig *container.HostConfig) { // docker only considers localhost:1 through localhost:9999 as insecure. testcontainers @@ -53,15 +63,33 @@ func StartTestRegistry(t *testing.T) *RegistryContainer { nat.Port("5000/tcp"): {{HostIP: "0.0.0.0", HostPort: strconv.Itoa(port)}}, } }), + } + + if options.auth != nil { + htpasswd, err := generateHtpasswd(options.auth.Username, options.auth.Password) + require.NoError(t, err) + containerCustomizers = append(containerCustomizers, + registry.WithHtpasswd(htpasswd), + ) + } + + registryContainer, err := registry.Run( + t.Context(), + "registry:3", + containerCustomizers..., ) defer testcontainers.CleanupContainer(t, registryContainer) require.NoError(t, err, "Failed to start registry container") - return &RegistryContainer{Container: registryContainer} + return &RegistryContainer{ + Container: registryContainer, + options: options, + } } type RegistryContainer struct { Container *registry.RegistryContainer + options *options } func (c *RegistryContainer) ImageRef(ref string) string { @@ -75,3 +103,59 @@ func (c *RegistryContainer) ImageRefForTest(t *testing.T, label string) string { repo := strings.ToLower(t.Name()) return c.ImageRef(fmt.Sprintf("%s:%s", repo, label)) } + +func (c *RegistryContainer) CloneRepo(t *testing.T, existingRepo, newRepo string) string { + existingRepo = c.ImageRef(existingRepo) + newRepo = c.ImageRef(newRepo) + + err := crane.CopyRepository(existingRepo, newRepo) + require.NoError(t, err, "Failed to clone repo %q to %q", existingRepo, newRepo) + return newRepo +} + +func (c *RegistryContainer) CloneRepoForTest(t *testing.T, repo string) string { + return c.CloneRepo(t, repo, strings.ToLower(t.Name())) +} + +func (c *RegistryContainer) ImageExists(t *testing.T, ref string) error { + parsedRef, err := name.ParseReference(ref, name.WithDefaultRegistry(c.RegistryHost())) + require.NoError(t, err) + + var opts []remote.Option + + if c.options.auth != nil { + opts = append(opts, remote.WithAuth(authn.FromConfig(authn.AuthConfig{ + Username: c.options.auth.Username, + Password: c.options.auth.Password, + }))) + } + _, err = remote.Head(parsedRef, opts...) + return err +} + +func (c *RegistryContainer) RegistryHost() string { + return c.Container.RegistryName +} + +type Option func(*options) + +func WithAuth(username, password string) func(*options) { + return func(o *options) { + o.auth = &dockerregistry.AuthConfig{ + Username: username, + Password: password, + } + } +} + +type options struct { + auth *dockerregistry.AuthConfig +} + +func generateHtpasswd(username, password string) (string, error) { + hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return "", err + } + return fmt.Sprintf("%s:%s", username, string(hash)), nil +} diff --git a/test-integration/test_integration/fixtures/secrets-project/cog.yaml b/test-integration/test_integration/fixtures/secrets-project/cog.yaml new file mode 100644 index 0000000000..01d1b26606 --- /dev/null +++ b/test-integration/test_integration/fixtures/secrets-project/cog.yaml @@ -0,0 +1,24 @@ +build: + run: + # assert that the file secret of file-secret.txt on the host is written to the target file and has the expected value + - command: >- + ID="file-secret"; + EXPECTED_VALUE="file_secret_value"; + EXPECTED_PATH="/etc/file_secret.txt"; + [ "$(cat "$EXPECTED_PATH")" = "$EXPECTED_VALUE" ] || ( echo "Assertion failed \"$EXPECTED_PATH\" \"$(cat $EXPECTED_PATH)\" != \"$EXPECTED_VALUE\""; exit 1; ) + mounts: + - type: secret + id: file-secret + target: /etc/file_secret.txt + # assert that the env secret of $ENV_SECRET on the host is written to the target file and has the expected value + - command: >- + ID="env-secret"; + EXPECTED_VALUE="env_secret_value"; + EXPECTED_PATH="/var/env-secret.txt"; + [ "$(cat "$EXPECTED_PATH")" = "$EXPECTED_VALUE" ] || ( echo "Assertion failed \"$EXPECTED_PATH\" \"$(cat $EXPECTED_PATH)\" != \"$EXPECTED_VALUE\""; exit 1; ) + mounts: + - type: secret + id: env-secret + target: /var/env-secret.txt + +predict: "predict.py:Predictor" diff --git a/test-integration/test_integration/fixtures/secrets-project/file-secret.txt b/test-integration/test_integration/fixtures/secrets-project/file-secret.txt new file mode 100644 index 0000000000..dd7af47dbc --- /dev/null +++ b/test-integration/test_integration/fixtures/secrets-project/file-secret.txt @@ -0,0 +1 @@ +file_secret_value diff --git a/test-integration/test_integration/fixtures/secrets-project/predict.py b/test-integration/test_integration/fixtures/secrets-project/predict.py new file mode 100644 index 0000000000..95a24e7178 --- /dev/null +++ b/test-integration/test_integration/fixtures/secrets-project/predict.py @@ -0,0 +1,6 @@ +from cog import BasePredictor + + +class Predictor(BasePredictor): + def predict(self, num: int) -> int: + return num * 2 diff --git a/test-integration/test_integration/test_build.py b/test-integration/test_integration/test_build.py index 30ce90d201..8c1ac12750 100644 --- a/test-integration/test_integration/test_build.py +++ b/test-integration/test_integration/test_build.py @@ -534,3 +534,24 @@ def test_install_requires_packaging(docker_image, cog_binary): ) print(build_process.stderr.decode()) assert build_process.returncode == 0 + + +def test_secrets(tmpdir_factory, docker_image, cog_binary): + project_dir = Path(__file__).parent / "fixtures/secrets-project" + + build_process = subprocess.run( + [ + cog_binary, + "build", + "-t", + docker_image, + "--secret", + "id=file-secret,src=file-secret.txt", + "--secret", + "id=env-secret,env=ENV_SECRET", + ], + cwd=project_dir, + capture_output=True, + env={**os.environ, "ENV_SECRET": "env_secret_value"}, + ) + assert build_process.returncode == 0