diff --git a/pkg/cli/baseimage.go b/pkg/cli/baseimage.go index 29abf9cc35..6d6d31de16 100644 --- a/pkg/cli/baseimage.go +++ b/pkg/cli/baseimage.go @@ -33,7 +33,7 @@ func NewBaseImageRootCommand() (*cobra.Command, error) { console.SetLevel(console.DebugLevel) } cmd.SilenceUsage = true - if err := update.DisplayAndCheckForRelease(); err != nil { + if err := update.DisplayAndCheckForRelease(cmd.Context()); err != nil { console.Debugf("%s", err) } }, @@ -87,7 +87,7 @@ func newBaseImageDockerfileCommand() *cobra.Command { if err != nil { return err } - dockerfile, err := generator.GenerateDockerfile() + dockerfile, err := generator.GenerateDockerfile(cmd.Context()) if err != nil { return err } @@ -108,11 +108,13 @@ func newBaseImageBuildCommand() *cobra.Command { Use: "build", Short: "Build Cog base image", RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + generator, err := baseImageGeneratorFromFlags() if err != nil { return err } - dockerfileContents, err := generator.GenerateDockerfile() + dockerfileContents, err := generator.GenerateDockerfile(ctx) if err != nil { return err } @@ -123,7 +125,7 @@ func newBaseImageBuildCommand() *cobra.Command { } baseImageName := dockerfile.BaseImageName(baseImageCUDAVersion, baseImagePythonVersion, baseImageTorchVersion) - err = docker.Build(cwd, dockerfileContents, baseImageName, []string{}, buildNoCache, buildProgressOutput, config.BuildSourceEpochTimestamp, dockercontext.StandardBuildDirectory, nil) + err = docker.Build(ctx, cwd, dockerfileContents, baseImageName, []string{}, buildNoCache, buildProgressOutput, config.BuildSourceEpochTimestamp, dockercontext.StandardBuildDirectory, nil) if err != nil { return err } diff --git a/pkg/cli/build.go b/pkg/cli/build.go index 34277ffda8..45b530e954 100644 --- a/pkg/cli/build.go +++ b/pkg/cli/build.go @@ -55,6 +55,8 @@ func newBuildCommand() *cobra.Command { } func buildCommand(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + cfg, projectDir, err := config.GetConfig(projectDirFlag) if err != nil { return err @@ -76,7 +78,7 @@ func buildCommand(cmd *cobra.Command, args []string) error { return err } - if err := image.Build(cfg, projectDir, imageName, buildSecrets, buildNoCache, buildSeparateWeights, buildUseCudaBaseImage, buildProgressOutput, buildSchemaFile, buildDockerfileFile, DetermineUseCogBaseImage(cmd), buildStrip, buildPrecompile, buildFast, nil, buildLocalImage); err != nil { + if err := image.Build(ctx, cfg, projectDir, imageName, buildSecrets, buildNoCache, buildSeparateWeights, buildUseCudaBaseImage, buildProgressOutput, buildSchemaFile, buildDockerfileFile, DetermineUseCogBaseImage(cmd), buildStrip, buildPrecompile, buildFast, nil, buildLocalImage); err != nil { return err } diff --git a/pkg/cli/debug.go b/pkg/cli/debug.go index f26c1c36b6..76c96b7935 100644 --- a/pkg/cli/debug.go +++ b/pkg/cli/debug.go @@ -35,6 +35,8 @@ func newDebugCommand() *cobra.Command { } func cmdDockerfile(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + cfg, projectDir, err := config.GetConfig(projectDirFlag) if err != nil { return err @@ -62,7 +64,7 @@ func cmdDockerfile(cmd *cobra.Command, args []string) error { imageName = config.DockerImageName(projectDir) } - weightsDockerfile, RunnerDockerfile, dockerignore, err := generator.GenerateModelBaseWithSeparateWeights(imageName) + weightsDockerfile, RunnerDockerfile, dockerignore, err := generator.GenerateModelBaseWithSeparateWeights(ctx, imageName) if err != nil { return err } @@ -71,7 +73,7 @@ func cmdDockerfile(cmd *cobra.Command, args []string) error { console.Output(fmt.Sprintf("=== Runner Dockerfile contents:\n%s\n===\n", RunnerDockerfile)) console.Output(fmt.Sprintf("=== DockerIgnore contents:\n%s===\n", dockerignore)) } else { - dockerfile, err := generator.GenerateDockerfileWithoutSeparateWeights() + dockerfile, err := generator.GenerateDockerfileWithoutSeparateWeights(ctx) if err != nil { return err } diff --git a/pkg/cli/login.go b/pkg/cli/login.go index f952d705bb..fa1c6b47cf 100644 --- a/pkg/cli/login.go +++ b/pkg/cli/login.go @@ -41,6 +41,8 @@ func newLoginCommand() *cobra.Command { } func login(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + registryHost, err := cmd.Flags().GetString("registry") if err != nil { return err @@ -73,7 +75,7 @@ func login(cmd *cobra.Command, args []string) error { return err } - if err := docker.SaveLoginToken(registryHost, username, token); err != nil { + if err := docker.SaveLoginToken(ctx, registryHost, username, token); err != nil { return err } diff --git a/pkg/cli/predict.go b/pkg/cli/predict.go index 0359e441f4..f5e6a45b39 100644 --- a/pkg/cli/predict.go +++ b/pkg/cli/predict.go @@ -2,6 +2,7 @@ package cli import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -66,6 +67,8 @@ the prediction on that.`, } func cmdPredict(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + imageName := "" volumes := []docker.Volume{} gpus := gpusFlag @@ -85,7 +88,7 @@ func cmdPredict(cmd *cobra.Command, args []string) error { if buildFast { imageName = config.DockerImageName(projectDir) } else { - if imageName, err = image.BuildBase(cfg, projectDir, buildUseCudaBaseImage, DetermineUseCogBaseImage(cmd), buildProgressOutput); err != nil { + if imageName, err = image.BuildBase(ctx, cfg, projectDir, buildUseCudaBaseImage, DetermineUseCogBaseImage(cmd), buildProgressOutput); err != nil { return err } @@ -109,17 +112,17 @@ func cmdPredict(cmd *cobra.Command, args []string) error { return fmt.Errorf("Invalid image name '%s'. Did you forget `-i`?", imageName) } - exists, err := docker.ImageExists(imageName) + exists, err := docker.ImageExists(ctx, imageName) if err != nil { return fmt.Errorf("Failed to determine if %s exists: %w", imageName, err) } if !exists { console.Infof("Pulling image: %s", imageName) - if err := docker.Pull(imageName); err != nil { + if err := docker.Pull(ctx, imageName); err != nil { return fmt.Errorf("Failed to pull %s: %w", imageName, err) } } - conf, err := image.GetConfig(imageName) + conf, err := image.GetConfig(ctx, imageName) if err != nil { return err } @@ -135,7 +138,7 @@ func cmdPredict(cmd *cobra.Command, args []string) error { console.Infof("Starting Docker image %s and running setup()...", imageName) dockerCommand := docker.NewDockerCommand() - predictor, err := predict.NewPredictor(docker.RunOptions{ + predictor, err := predict.NewPredictor(ctx, docker.RunOptions{ GPUs: gpus, Image: imageName, Volumes: volumes, @@ -152,20 +155,20 @@ func cmdPredict(cmd *cobra.Command, args []string) error { <-captureSignal console.Info("Stopping container...") - if err := predictor.Stop(); err != nil { + if err := predictor.Stop(ctx); err != nil { console.Warnf("Failed to stop container: %s", err) } }() timeout := time.Duration(setupTimeout) * time.Second - if err := predictor.Start(os.Stderr, timeout); err != nil { + if err := predictor.Start(ctx, os.Stderr, timeout); err != nil { // Only retry if we're using a GPU but but the user didn't explicitly select a GPU with --gpus // If the user specified the wrong GPU, they are explicitly selecting a GPU and they'll want to hear about it if gpus == "all" && errors.Is(err, docker.ErrMissingDeviceDriver) { console.Info("Missing device driver, re-trying without GPU") - _ = predictor.Stop() - predictor, err = predict.NewPredictor(docker.RunOptions{ + _ = predictor.Stop(ctx) + predictor, err = predict.NewPredictor(ctx, docker.RunOptions{ Image: imageName, Volumes: volumes, Env: envFlags, @@ -174,7 +177,7 @@ func cmdPredict(cmd *cobra.Command, args []string) error { return err } - if err := predictor.Start(os.Stderr, timeout); err != nil { + if err := predictor.Start(ctx, os.Stderr, timeout); err != nil { return err } } else { @@ -185,7 +188,8 @@ func cmdPredict(cmd *cobra.Command, args []string) error { // FIXME: will not run on signal defer func() { console.Debugf("Stopping container...") - if err := predictor.Stop(); err != nil { + // use background context to ensure stop signal is still sent after root context is canceled + if err := predictor.Stop(context.Background()); err != nil { console.Warnf("Failed to stop container: %s", err) } }() diff --git a/pkg/cli/push.go b/pkg/cli/push.go index 15d9364047..c2aba5afea 100644 --- a/pkg/cli/push.go +++ b/pkg/cli/push.go @@ -42,6 +42,8 @@ func newPushCommand() *cobra.Command { } func push(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + cfg, projectDir, err := config.GetConfig(projectDirFlag) if err != nil { return err @@ -61,7 +63,7 @@ func push(cmd *cobra.Command, args []string) error { replicatePrefix := fmt.Sprintf("%s/", global.ReplicateRegistryHost) if strings.HasPrefix(imageName, replicatePrefix) { - if err := docker.ManifestInspect(imageName); err != nil && strings.Contains(err.Error(), `"code":"NAME_UNKNOWN"`) { + if err := docker.ManifestInspect(ctx, imageName); err != nil && strings.Contains(err.Error(), `"code":"NAME_UNKNOWN"`) { return fmt.Errorf("Unable to find Replicate existing model for %s. Go to replicate.com and create a new model before pushing.", imageName) } } else { @@ -81,7 +83,7 @@ func push(cmd *cobra.Command, args []string) error { startBuildTime := time.Now() - if err := image.Build(cfg, projectDir, imageName, buildSecrets, buildNoCache, buildSeparateWeights, buildUseCudaBaseImage, buildProgressOutput, buildSchemaFile, buildDockerfileFile, DetermineUseCogBaseImage(cmd), buildStrip, buildPrecompile, buildFast, annotations, buildLocalImage); err != nil { + if err := image.Build(ctx, cfg, projectDir, imageName, buildSecrets, buildNoCache, buildSeparateWeights, buildUseCudaBaseImage, buildProgressOutput, buildSchemaFile, buildDockerfileFile, DetermineUseCogBaseImage(cmd), buildStrip, buildPrecompile, buildFast, annotations, buildLocalImage); err != nil { return err } @@ -93,7 +95,7 @@ func push(cmd *cobra.Command, args []string) error { } command := docker.NewDockerCommand() - err = docker.Push(imageName, buildFast, projectDir, command, docker.BuildInfo{ + err = docker.Push(ctx, imageName, buildFast, projectDir, command, docker.BuildInfo{ BuildTime: buildDuration, BuildID: buildID.String(), }) diff --git a/pkg/cli/root.go b/pkg/cli/root.go index 11e67edfa0..1f55d39cde 100644 --- a/pkg/cli/root.go +++ b/pkg/cli/root.go @@ -29,7 +29,7 @@ https://github.com/replicate/cog`, console.SetLevel(console.DebugLevel) } cmd.SilenceUsage = true - if err := update.DisplayAndCheckForRelease(); err != nil { + if err := update.DisplayAndCheckForRelease(cmd.Context()); err != nil { console.Debugf("%s", err) } }, diff --git a/pkg/cli/run.go b/pkg/cli/run.go index d1fed0e4a1..3636e27c53 100644 --- a/pkg/cli/run.go +++ b/pkg/cli/run.go @@ -52,11 +52,13 @@ func newRunCommand() *cobra.Command { } func run(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + cfg, projectDir, err := config.GetConfig(projectDirFlag) if err != nil { return err } - imageName, err := image.BuildBase(cfg, projectDir, buildUseCudaBaseImage, DetermineUseCogBaseImage(cmd), buildProgressOutput) + imageName, err := image.BuildBase(ctx, cfg, projectDir, buildUseCudaBaseImage, DetermineUseCogBaseImage(cmd), buildProgressOutput) if err != nil { return err } @@ -78,7 +80,7 @@ func run(cmd *cobra.Command, args []string) error { Volumes: []docker.Volume{{Source: projectDir, Destination: "/src"}}, Workdir: "/src", } - runOptions, err = docker.FillInWeightsManifestVolumes(dockerCommand, runOptions) + runOptions, err = docker.FillInWeightsManifestVolumes(ctx, dockerCommand, runOptions) if err != nil { return err } @@ -103,14 +105,14 @@ func run(cmd *cobra.Command, args []string) error { console.Info("Fast run enabled.") } - err = docker.Run(runOptions) + err = docker.Run(ctx, runOptions) // Only retry if we're using a GPU but but the user didn't explicitly select a GPU with --gpus // If the user specified the wrong GPU, they are explicitly selecting a GPU and they'll want to hear about it if runOptions.GPUs == "all" && err == docker.ErrMissingDeviceDriver { console.Info("Missing device driver, re-trying without GPU") runOptions.GPUs = "" - err = docker.Run(runOptions) + err = docker.Run(ctx, runOptions) } return err diff --git a/pkg/cli/serve.go b/pkg/cli/serve.go index da0a963a4b..662d8e7e31 100644 --- a/pkg/cli/serve.go +++ b/pkg/cli/serve.go @@ -41,12 +41,14 @@ Generate and run an HTTP server based on the declared model inputs and outputs.` } func cmdServe(cmd *cobra.Command, arg []string) error { + ctx := cmd.Context() + cfg, projectDir, err := config.GetConfig(projectDirFlag) if err != nil { return err } - imageName, err := image.BuildBase(cfg, projectDir, buildUseCudaBaseImage, DetermineUseCogBaseImage(cmd), buildProgressOutput) + imageName, err := image.BuildBase(ctx, cfg, projectDir, buildUseCudaBaseImage, DetermineUseCogBaseImage(cmd), buildProgressOutput) if err != nil { return err } @@ -78,7 +80,7 @@ func cmdServe(cmd *cobra.Command, arg []string) error { Volumes: []docker.Volume{{Source: projectDir, Destination: "/src"}}, Workdir: "/src", } - runOptions, err = docker.FillInWeightsManifestVolumes(dockerCommand, runOptions) + runOptions, err = docker.FillInWeightsManifestVolumes(ctx, dockerCommand, runOptions) if err != nil { return err } @@ -95,14 +97,14 @@ func cmdServe(cmd *cobra.Command, arg []string) error { console.Infof("Serving at http://127.0.0.1:%[1]v", port) console.Info("") - err = docker.Run(runOptions) + err = docker.Run(ctx, runOptions) // Only retry if we're using a GPU but but the user didn't explicitly select a GPU with --gpus // If the user specified the wrong GPU, they are explicitly selecting a GPU and they'll want to hear about it if runOptions.GPUs == "all" && err == docker.ErrMissingDeviceDriver { console.Info("Missing device driver, re-trying without GPU") runOptions.GPUs = "" - err = docker.Run(runOptions) + err = docker.Run(ctx, runOptions) } return err diff --git a/pkg/cli/train.go b/pkg/cli/train.go index ef7f14f368..cea54b0684 100644 --- a/pkg/cli/train.go +++ b/pkg/cli/train.go @@ -1,6 +1,7 @@ package cli import ( + "context" "fmt" "os" "os/signal" @@ -52,6 +53,8 @@ Otherwise, it will build the model in the current directory and train it.`, } func cmdTrain(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + imageName := "" volumes := []docker.Volume{} gpus := gpusFlag @@ -68,7 +71,7 @@ func cmdTrain(cmd *cobra.Command, args []string) error { buildFast = cfg.Build.Fast } - if imageName, err = image.BuildBase(cfg, projectDir, buildUseCudaBaseImage, DetermineUseCogBaseImage(cmd), buildProgressOutput); err != nil { + if imageName, err = image.BuildBase(ctx, cfg, projectDir, buildUseCudaBaseImage, DetermineUseCogBaseImage(cmd), buildProgressOutput); err != nil { return err } @@ -85,17 +88,17 @@ func cmdTrain(cmd *cobra.Command, args []string) error { // Use existing image imageName = args[0] - exists, err := docker.ImageExists(imageName) + exists, err := docker.ImageExists(ctx, imageName) if err != nil { return fmt.Errorf("Failed to determine if %s exists: %w", imageName, err) } if !exists { console.Infof("Pulling image: %s", imageName) - if err := docker.Pull(imageName); err != nil { + if err := docker.Pull(ctx, imageName); err != nil { return fmt.Errorf("Failed to pull %s: %w", imageName, err) } } - conf, err := image.GetConfig(imageName) + conf, err := image.GetConfig(ctx, imageName) if err != nil { return err } @@ -111,7 +114,7 @@ func cmdTrain(cmd *cobra.Command, args []string) error { console.Infof("Starting Docker image %s...", imageName) dockerCommand := docker.NewDockerCommand() - predictor, err := predict.NewPredictor(docker.RunOptions{ + predictor, err := predict.NewPredictor(ctx, docker.RunOptions{ GPUs: gpus, Image: imageName, Volumes: volumes, @@ -129,19 +132,20 @@ func cmdTrain(cmd *cobra.Command, args []string) error { <-captureSignal console.Info("Stopping container...") - if err := predictor.Stop(); err != nil { + if err := predictor.Stop(ctx); err != nil { console.Warnf("Failed to stop container: %s", err) } }() - if err := predictor.Start(os.Stderr, time.Duration(setupTimeout)*time.Second); err != nil { + if err := predictor.Start(ctx, os.Stderr, time.Duration(setupTimeout)*time.Second); err != nil { return err } // FIXME: will not run on signal defer func() { console.Debugf("Stopping container...") - if err := predictor.Stop(); err != nil { + // use background context to ensure stop signal is still sent after root context is canceled + if err := predictor.Stop(context.Background()); err != nil { console.Warnf("Failed to stop container: %s", err) } }() diff --git a/pkg/docker/apt.go b/pkg/docker/apt.go index 01955d16f9..10a69643ff 100644 --- a/pkg/docker/apt.go +++ b/pkg/docker/apt.go @@ -1,6 +1,7 @@ package docker import ( + "context" "crypto/sha256" "encoding/hex" "errors" @@ -16,7 +17,7 @@ import ( const aptTarballPrefix = "apt." const aptTarballSuffix = ".tar.zst" -func CreateAptTarball(tmpDir string, dockerCommand command.Command, packages ...string) (string, error) { +func CreateAptTarball(ctx context.Context, tmpDir string, dockerCommand command.Command, packages ...string) (string, error) { if len(packages) > 0 { sort.Strings(packages) hash := sha256.New() @@ -33,7 +34,7 @@ func CreateAptTarball(tmpDir string, dockerCommand command.Command, packages ... } // Create the apt tar file - _, err = dockerCommand.CreateAptTarFile(tmpDir, aptTarFile, packages...) + _, err = dockerCommand.CreateAptTarFile(ctx, tmpDir, aptTarFile, packages...) if err != nil { return "", err } diff --git a/pkg/docker/apt_test.go b/pkg/docker/apt_test.go index a806b82808..71dc17247e 100644 --- a/pkg/docker/apt_test.go +++ b/pkg/docker/apt_test.go @@ -12,7 +12,7 @@ import ( func TestCreateAptTarball(t *testing.T) { dir := t.TempDir() command := dockertest.NewMockCommand() - tarball, err := CreateAptTarball(dir, command, []string{}...) + tarball, err := CreateAptTarball(t.Context(), dir, command, []string{}...) require.NoError(t, err) require.Equal(t, "", tarball) } @@ -20,7 +20,7 @@ func TestCreateAptTarball(t *testing.T) { func TestCreateAptTarballWithPackages(t *testing.T) { dir := t.TempDir() command := dockertest.NewMockCommand() - tarball, err := CreateAptTarball(dir, command, []string{"git"}...) + tarball, err := CreateAptTarball(t.Context(), dir, command, []string{"git"}...) require.NoError(t, err) require.True(t, strings.HasPrefix(tarball, "apt.")) } diff --git a/pkg/docker/build.go b/pkg/docker/build.go index 93e4ce7921..2ccce954f1 100644 --- a/pkg/docker/build.go +++ b/pkg/docker/build.go @@ -1,6 +1,7 @@ package docker import ( + "context" "fmt" "os" "os/exec" @@ -13,7 +14,7 @@ import ( "github.com/replicate/cog/pkg/util/console" ) -func Build(dir, dockerfileContents, imageName string, secrets []string, noCache bool, progressOutput string, epoch int64, contextDir string, buildContexts map[string]string) error { +func Build(ctx context.Context, dir, dockerfileContents, imageName string, secrets []string, noCache bool, progressOutput string, epoch int64, contextDir string, buildContexts map[string]string) error { args := []string{ "buildx", "build", // disable provenance attestations since we don't want them cluttering the registry @@ -65,7 +66,7 @@ func Build(dir, dockerfileContents, imageName string, secrets []string, noCache contextDir, ) - cmd := exec.Command("docker", args...) + cmd := exec.CommandContext(ctx, "docker", args...) cmd.Dir = dir cmd.Stdout = os.Stderr // redirect stdout to stderr - build output is all messaging cmd.Stderr = os.Stderr @@ -75,7 +76,7 @@ func Build(dir, dockerfileContents, imageName string, secrets []string, noCache return cmd.Run() } -func BuildAddLabelsAndSchemaToImage(image string, labels map[string]string, bundledSchemaFile string, bundledSchemaPy string) error { +func BuildAddLabelsAndSchemaToImage(ctx context.Context, image string, labels map[string]string, bundledSchemaFile string, bundledSchemaPy string) error { args := []string{ "buildx", "build", // disable provenance attestations since we don't want them cluttering the registry @@ -98,7 +99,7 @@ func BuildAddLabelsAndSchemaToImage(image string, labels map[string]string, bund } // We're not using context, but Docker requires we pass a context args = append(args, ".") - cmd := exec.Command("docker", args...) + cmd := exec.CommandContext(ctx, "docker", args...) dockerfile := "FROM " + image + "\n" dockerfile += "COPY " + bundledSchemaFile + " .cog\n" diff --git a/pkg/docker/command/command.go b/pkg/docker/command/command.go index 6ee4b4b2ab..b4c54d4685 100644 --- a/pkg/docker/command/command.go +++ b/pkg/docker/command/command.go @@ -1,10 +1,12 @@ package command +import "context" + type Command interface { - Pull(string) error - Push(string) error - LoadUserInformation(string) (*UserInfo, error) - CreateTarFile(string, string, string, string) (string, error) - CreateAptTarFile(string, string, ...string) (string, error) - Inspect(string) (*Manifest, error) + Pull(ctx context.Context, ref string) error + Push(ctx context.Context, ref string) error + LoadUserInformation(ctx context.Context, registryHost string) (*UserInfo, error) + CreateTarFile(ctx context.Context, ref string, tmpDir string, tarFile string, folder string) (string, error) + CreateAptTarFile(ctx context.Context, tmpDir string, aptTarFile string, packages ...string) (string, error) + Inspect(ctx context.Context, ref string) (*Manifest, error) } diff --git a/pkg/docker/container_inspect.go b/pkg/docker/container_inspect.go index cfbd22639e..42be23a05d 100644 --- a/pkg/docker/container_inspect.go +++ b/pkg/docker/container_inspect.go @@ -1,6 +1,7 @@ package docker import ( + "context" "encoding/json" "fmt" "os" @@ -9,8 +10,8 @@ import ( "github.com/docker/docker/api/types" ) -func ContainerInspect(id string) (*types.ContainerJSON, error) { - cmd := exec.Command("docker", "container", "inspect", id) +func ContainerInspect(ctx context.Context, id string) (*types.ContainerJSON, error) { + cmd := exec.CommandContext(ctx, "docker", "container", "inspect", id) cmd.Env = os.Environ() out, err := cmd.Output() diff --git a/pkg/docker/docker_command.go b/pkg/docker/docker_command.go index 429768b240..5d48bf243c 100644 --- a/pkg/docker/docker_command.go +++ b/pkg/docker/docker_command.go @@ -2,6 +2,7 @@ package docker import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -32,17 +33,17 @@ func NewDockerCommand() *DockerCommand { return &DockerCommand{} } -func (c *DockerCommand) Pull(image string) error { - _, err := c.exec("pull", false, image, "--platform", "linux/amd64") +func (c *DockerCommand) Pull(ctx context.Context, image string) error { + _, err := c.exec(ctx, "pull", false, image, "--platform", "linux/amd64") return err } -func (c *DockerCommand) Push(image string) error { - _, err := c.exec("push", false, image) +func (c *DockerCommand) Push(ctx context.Context, image string) error { + _, err := c.exec(ctx, "push", false, image) return err } -func (c *DockerCommand) LoadUserInformation(registryHost string) (*command.UserInfo, error) { +func (c *DockerCommand) LoadUserInformation(ctx context.Context, registryHost string) (*command.UserInfo, error) { conf := config.LoadDefaultConfigFile(os.Stderr) credsStore := conf.CredentialsStore if credsStore == "" { @@ -55,7 +56,7 @@ func (c *DockerCommand) LoadUserInformation(registryHost string) (*command.UserI Username: authConf.Username, }, nil } - credsHelper, err := loadAuthFromCredentialsStore(credsStore, registryHost) + credsHelper, err := loadAuthFromCredentialsStore(ctx, credsStore, registryHost) if err != nil { return nil, err } @@ -65,7 +66,7 @@ func (c *DockerCommand) LoadUserInformation(registryHost string) (*command.UserI }, nil } -func (c *DockerCommand) CreateTarFile(image string, tmpDir string, tarFile string, folder string) (string, error) { +func (c *DockerCommand) CreateTarFile(ctx context.Context, image string, tmpDir string, tarFile string, folder string) (string, error) { args := []string{ "--rm", "--volume", @@ -76,14 +77,14 @@ func (c *DockerCommand) CreateTarFile(image string, tmpDir string, tarFile strin "/", folder, } - _, err := c.exec("run", false, args...) + _, err := c.exec(ctx, "run", false, args...) if err != nil { return "", err } return filepath.Join(tmpDir, tarFile), nil } -func (c *DockerCommand) CreateAptTarFile(tmpDir string, aptTarFile string, packages ...string) (string, error) { +func (c *DockerCommand) CreateAptTarFile(ctx context.Context, tmpDir string, aptTarFile string, packages ...string) (string, error) { // This uses a hardcoded monobase image to produce an apt tar file. // The reason being that this apt tar file is created outside the docker file, and it is created by // running the apt.sh script on the monobase with the packages we intend to install, which produces @@ -97,7 +98,7 @@ func (c *DockerCommand) CreateAptTarFile(tmpDir string, aptTarFile string, packa "/buildtmp/" + aptTarFile, } args = append(args, packages...) - _, err := c.exec("run", false, args...) + _, err := c.exec(ctx, "run", false, args...) if err != nil { return "", err } @@ -105,12 +106,12 @@ func (c *DockerCommand) CreateAptTarFile(tmpDir string, aptTarFile string, packa return aptTarFile, nil } -func (c *DockerCommand) Inspect(image string) (*command.Manifest, error) { +func (c *DockerCommand) Inspect(ctx context.Context, image string) (*command.Manifest, error) { args := []string{ "inspect", image, } - manifestData, err := c.exec("image", true, args...) + manifestData, err := c.exec(ctx, "image", true, args...) if err != nil { return nil, err } @@ -128,14 +129,14 @@ func (c *DockerCommand) Inspect(image string) (*command.Manifest, error) { return &manifests[0], nil // Docker inspect returns us a list of manifests } -func (c *DockerCommand) exec(name string, capture bool, args ...string) (string, error) { +func (c *DockerCommand) exec(ctx context.Context, name string, capture bool, args ...string) (string, error) { cmdArgs := []string{name} if slices.ContainsString(commandsRequiringPlatform, name) && util.IsAppleSiliconMac(runtime.GOOS, runtime.GOARCH) { cmdArgs = append(cmdArgs, "--platform", "linux/amd64") } cmdArgs = append(cmdArgs, args...) dockerCmd := DockerCommandFromEnvironment() - cmd := exec.Command(dockerCmd, cmdArgs...) + cmd := exec.CommandContext(ctx, dockerCmd, cmdArgs...) var out strings.Builder if !capture { cmd.Stdout = os.Stdout @@ -157,10 +158,10 @@ func loadAuthFromConfig(conf *configfile.ConfigFile, registryHost string) (types return conf.AuthConfigs[registryHost], nil } -func loadAuthFromCredentialsStore(credsStore string, registryHost string) (*CredentialHelperInput, error) { +func loadAuthFromCredentialsStore(ctx context.Context, credsStore string, registryHost string) (*CredentialHelperInput, error) { var out strings.Builder binary := DockerCredentialBinary(credsStore) - cmd := exec.Command(binary, "get") + cmd := exec.CommandContext(ctx, binary, "get") cmd.Env = os.Environ() cmd.Stdout = &out cmd.Stderr = &out diff --git a/pkg/docker/docker_command_test.go b/pkg/docker/docker_command_test.go index 12d124c591..1c287a2358 100644 --- a/pkg/docker/docker_command_test.go +++ b/pkg/docker/docker_command_test.go @@ -10,6 +10,6 @@ func TestDockerPush(t *testing.T) { t.Setenv(DockerCommandEnvVarName, "echo") command := NewDockerCommand() - err := command.Push("test") + err := command.Push(t.Context(), "test") require.NoError(t, err) } diff --git a/pkg/docker/dockertest/mock_command.go b/pkg/docker/dockertest/mock_command.go index 9139dc7bed..f1aa68e9af 100644 --- a/pkg/docker/dockertest/mock_command.go +++ b/pkg/docker/dockertest/mock_command.go @@ -1,6 +1,7 @@ package dockertest import ( + "context" "os" "path/filepath" @@ -17,15 +18,15 @@ func NewMockCommand() *MockCommand { return &MockCommand{} } -func (c *MockCommand) Pull(image string) error { +func (c *MockCommand) Pull(ctx context.Context, image string) error { return nil } -func (c *MockCommand) Push(image string) error { +func (c *MockCommand) Push(ctx context.Context, image string) error { return PushError } -func (c *MockCommand) LoadUserInformation(registryHost string) (*command.UserInfo, error) { +func (c *MockCommand) LoadUserInformation(ctx context.Context, registryHost string) (*command.UserInfo, error) { userInfo := command.UserInfo{ Token: "", Username: "", @@ -33,7 +34,7 @@ func (c *MockCommand) LoadUserInformation(registryHost string) (*command.UserInf return &userInfo, nil } -func (c *MockCommand) CreateTarFile(image string, tmpDir string, tarFile string, folder string) (string, error) { +func (c *MockCommand) CreateTarFile(ctx context.Context, image string, tmpDir string, tarFile string, folder string) (string, error) { path := filepath.Join(tmpDir, tarFile) d1 := []byte("hello\ngo\n") err := os.WriteFile(path, d1, 0o644) @@ -43,7 +44,7 @@ func (c *MockCommand) CreateTarFile(image string, tmpDir string, tarFile string, return path, nil } -func (c *MockCommand) CreateAptTarFile(tmpDir string, aptTarFile string, packages ...string) (string, error) { +func (c *MockCommand) CreateAptTarFile(ctx context.Context, tmpDir string, aptTarFile string, packages ...string) (string, error) { path := filepath.Join(tmpDir, aptTarFile) d1 := []byte("hello\ngo\n") err := os.WriteFile(path, d1, 0o644) @@ -53,7 +54,7 @@ func (c *MockCommand) CreateAptTarFile(tmpDir string, aptTarFile string, package return path, nil } -func (c *MockCommand) Inspect(image string) (*command.Manifest, error) { +func (c *MockCommand) Inspect(ctx context.Context, image string) (*command.Manifest, error) { manifest := command.Manifest{ Config: command.Config{ Labels: map[string]string{ diff --git a/pkg/docker/fast_push.go b/pkg/docker/fast_push.go index 5fe2a563cd..f67bd92766 100644 --- a/pkg/docker/fast_push.go +++ b/pkg/docker/fast_push.go @@ -80,7 +80,7 @@ func FastPush(ctx context.Context, image string, projectDir string, command comm tmpTarballsDir := filepath.Join(projectDir, TarballsDir) // Upload python packages. if requirementsFile != "" { - pythonTar, err := createPythonPackagesTarFile(image, tmpTarballsDir, command) + pythonTar, err := createPythonPackagesTarFile(ctx, image, tmpTarballsDir, command) if err != nil { return err } @@ -111,7 +111,7 @@ func FastPush(ctx context.Context, image string, projectDir string, command comm } // Upload user /src. - srcTar, err := createSrcTarFile(image, tmpTarballsDir, command) + srcTar, err := createSrcTarFile(ctx, image, tmpTarballsDir, command) if err != nil { return fmt.Errorf("create src tarfile: %w", err) } @@ -146,12 +146,12 @@ func FastPush(ctx context.Context, image string, projectDir string, command comm return webClient.PostNewVersion(ctx, image, weightFiles, files, challenges) } -func createPythonPackagesTarFile(image string, tmpDir string, command command.Command) (string, error) { - return command.CreateTarFile(image, tmpDir, requirementsTarFile, "root/.venv") +func createPythonPackagesTarFile(ctx context.Context, image string, tmpDir string, command command.Command) (string, error) { + return command.CreateTarFile(ctx, image, tmpDir, requirementsTarFile, "root/.venv") } -func createSrcTarFile(image string, tmpDir string, command command.Command) (string, error) { - return command.CreateTarFile(image, tmpDir, "src.tar.zst", "src") +func createSrcTarFile(ctx context.Context, image string, tmpDir string, command command.Command) (string, error) { + return command.CreateTarFile(ctx, image, tmpDir, "src.tar.zst", "src") } func createWeightsFilesFromWeightsManifest(weights []weights.Weight) []web.File { diff --git a/pkg/docker/fast_push_test.go b/pkg/docker/fast_push_test.go index 8c3601e2fc..74467b1394 100644 --- a/pkg/docker/fast_push_test.go +++ b/pkg/docker/fast_push_test.go @@ -1,7 +1,6 @@ package docker import ( - "context" "encoding/json" "net/http" "net/http/httptest" @@ -70,13 +69,13 @@ func TestFastPush(t *testing.T) { // Setup mock command command := dockertest.NewMockCommand() - client, err := r8HTTP.ProvideHTTPClient(command) + client, err := r8HTTP.ProvideHTTPClient(t.Context(), command) require.NoError(t, err) webClient := web.NewClient(command, client) monobeamClient := monobeam.NewClient(client) // Run fast push - err = FastPush(context.Background(), "r8.im/username/modelname", dir, command, webClient, monobeamClient) + err = FastPush(t.Context(), "r8.im/username/modelname", dir, command, webClient, monobeamClient) require.NoError(t, err) } @@ -143,12 +142,12 @@ func TestFastPushWithWeight(t *testing.T) { // Setup mock command command := dockertest.NewMockCommand() - client, err := r8HTTP.ProvideHTTPClient(command) + client, err := r8HTTP.ProvideHTTPClient(t.Context(), command) require.NoError(t, err) webClient := web.NewClient(command, client) monobeamClient := monobeam.NewClient(client) // Run fast push - err = FastPush(context.Background(), "r8.im/username/modelname", dir, command, webClient, monobeamClient) + err = FastPush(t.Context(), "r8.im/username/modelname", dir, command, webClient, monobeamClient) require.NoError(t, err) } diff --git a/pkg/docker/image_exists.go b/pkg/docker/image_exists.go index 45c35426b6..01ce8e8382 100644 --- a/pkg/docker/image_exists.go +++ b/pkg/docker/image_exists.go @@ -1,7 +1,9 @@ package docker -func ImageExists(id string) (bool, error) { - _, err := ImageInspect(id) +import "context" + +func ImageExists(ctx context.Context, id string) (bool, error) { + _, err := ImageInspect(ctx, id) if err == ErrNoSuchImage { return false, nil } diff --git a/pkg/docker/image_inspect.go b/pkg/docker/image_inspect.go index e798dab00d..034ac1d2f7 100644 --- a/pkg/docker/image_inspect.go +++ b/pkg/docker/image_inspect.go @@ -1,6 +1,7 @@ package docker import ( + "context" "encoding/json" "errors" "os" @@ -14,8 +15,8 @@ import ( var ErrNoSuchImage = errors.New("No image returned") -func ImageInspect(id string) (*types.ImageInspect, error) { - cmd := exec.Command("docker", "image", "inspect", id) +func ImageInspect(ctx context.Context, id string) (*types.ImageInspect, error) { + cmd := exec.CommandContext(ctx, "docker", "image", "inspect", id) cmd.Env = os.Environ() console.Debug("$ " + strings.Join(cmd.Args, " ")) out, err := cmd.Output() diff --git a/pkg/docker/login.go b/pkg/docker/login.go index 44506e107d..e17f2e123c 100644 --- a/pkg/docker/login.go +++ b/pkg/docker/login.go @@ -1,6 +1,7 @@ package docker import ( + "context" "encoding/json" "fmt" "os" @@ -14,13 +15,13 @@ import ( "github.com/replicate/cog/pkg/util/console" ) -func SaveLoginToken(registryHost string, username string, token string) error { +func SaveLoginToken(ctx context.Context, registryHost string, username string, token string) error { conf := config.LoadDefaultConfigFile(os.Stderr) credsStore := conf.CredentialsStore if credsStore == "" { return saveAuthToConfig(conf, registryHost, username, token) } - return saveAuthToCredentialsStore(credsStore, registryHost, username, token) + return saveAuthToCredentialsStore(ctx, credsStore, registryHost, username, token) } func saveAuthToConfig(conf *configfile.ConfigFile, registryHost string, username string, token string) error { @@ -35,14 +36,14 @@ func saveAuthToConfig(conf *configfile.ConfigFile, registryHost string, username return nil } -func saveAuthToCredentialsStore(credsStore string, registryHost string, username string, token string) error { +func saveAuthToCredentialsStore(ctx context.Context, credsStore string, registryHost string, username string, token string) error { binary := DockerCredentialBinary(credsStore) input := CredentialHelperInput{ Username: username, Secret: token, ServerURL: registryHost, } - cmd := exec.Command(binary, "store") + cmd := exec.CommandContext(ctx, binary, "store") cmd.Env = os.Environ() cmd.Stderr = os.Stderr stdin, err := cmd.StdinPipe() diff --git a/pkg/docker/logs.go b/pkg/docker/logs.go index 21c06c47aa..d348b2cd64 100644 --- a/pkg/docker/logs.go +++ b/pkg/docker/logs.go @@ -1,13 +1,14 @@ package docker import ( + "context" "io" "os" "os/exec" ) -func ContainerLogsFollow(containerID string, out io.Writer) error { - cmd := exec.Command("docker", "container", "logs", "--follow", containerID) +func ContainerLogsFollow(ctx context.Context, containerID string, out io.Writer) error { + cmd := exec.CommandContext(ctx, "docker", "container", "logs", "--follow", containerID) cmd.Env = os.Environ() cmd.Stdout = out cmd.Stderr = out diff --git a/pkg/docker/manifest_inspect.go b/pkg/docker/manifest_inspect.go index 7f17763064..55fad0f77f 100644 --- a/pkg/docker/manifest_inspect.go +++ b/pkg/docker/manifest_inspect.go @@ -1,14 +1,15 @@ package docker import ( + "context" "os/exec" "strings" "github.com/replicate/cog/pkg/util/console" ) -func ManifestInspect(image string) error { - cmd := exec.Command("docker", "manifest", "inspect", image) +func ManifestInspect(ctx context.Context, image string) error { + cmd := exec.CommandContext(ctx, "docker", "manifest", "inspect", image) var out strings.Builder cmd.Stdout = &out cmd.Stderr = &out diff --git a/pkg/docker/pull.go b/pkg/docker/pull.go index 1cd63cc52e..e550c694e3 100644 --- a/pkg/docker/pull.go +++ b/pkg/docker/pull.go @@ -1,6 +1,7 @@ package docker import ( + "context" "os" "os/exec" "strings" @@ -8,8 +9,8 @@ import ( "github.com/replicate/cog/pkg/util/console" ) -func Pull(image string) error { - cmd := exec.Command("docker", "pull", image) +func Pull(ctx context.Context, image string) error { + cmd := exec.CommandContext(ctx, "docker", "pull", image) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr diff --git a/pkg/docker/push.go b/pkg/docker/push.go index 7fcd93fc63..0024a989bc 100644 --- a/pkg/docker/push.go +++ b/pkg/docker/push.go @@ -16,9 +16,8 @@ type BuildInfo struct { BuildID string } -func Push(image string, fast bool, projectDir string, command command.Command, buildInfo BuildInfo) error { - ctx := context.Background() - client, err := http.ProvideHTTPClient(command) +func Push(ctx context.Context, image string, fast bool, projectDir string, command command.Command, buildInfo BuildInfo) error { + client, err := http.ProvideHTTPClient(ctx, command) if err != nil { return err } @@ -36,5 +35,5 @@ func Push(image string, fast bool, projectDir string, command command.Command, b } return FastPush(ctx, image, projectDir, command, webClient, monobeamClient) } - return StandardPush(image, command) + return StandardPush(ctx, image, command) } diff --git a/pkg/docker/push_test.go b/pkg/docker/push_test.go index 3d14051921..d9dfbbb385 100644 --- a/pkg/docker/push_test.go +++ b/pkg/docker/push_test.go @@ -71,7 +71,7 @@ func TestPush(t *testing.T) { command := dockertest.NewMockCommand() // Run fast push - err = Push("r8.im/username/modelname", true, dir, command, BuildInfo{}) + err = Push(t.Context(), "r8.im/username/modelname", true, dir, command, BuildInfo{}) require.NoError(t, err) } @@ -140,6 +140,6 @@ func TestPushWithWeight(t *testing.T) { command := dockertest.NewMockCommand() // Run fast push - err = Push("r8.im/username/modelname", true, dir, command, BuildInfo{}) + err = Push(t.Context(), "r8.im/username/modelname", true, dir, command, BuildInfo{}) require.NoError(t, err) } diff --git a/pkg/docker/run.go b/pkg/docker/run.go index aa945d7c62..f475f4b5c1 100644 --- a/pkg/docker/run.go +++ b/pkg/docker/run.go @@ -3,6 +3,7 @@ package docker import ( "bufio" "bytes" + "context" "encoding/json" "errors" "fmt" @@ -108,11 +109,11 @@ func generateEnv(options internalRunOptions) []string { return env } -func Run(options RunOptions) error { - return RunWithIO(options, os.Stdin, os.Stdout, os.Stderr) +func Run(ctx context.Context, options RunOptions) error { + return RunWithIO(ctx, options, os.Stdin, os.Stdout, os.Stderr) } -func RunWithIO(options RunOptions, stdin io.Reader, stdout, stderr io.Writer) error { +func RunWithIO(ctx context.Context, options RunOptions, stdin io.Reader, stdout, stderr io.Writer) error { internalOptions := internalRunOptions{RunOptions: options} if stdin != nil { internalOptions.Interactive = true @@ -124,7 +125,7 @@ func RunWithIO(options RunOptions, stdin io.Reader, stdout, stderr io.Writer) er stderrMultiWriter := io.MultiWriter(stderr, stderrCopy) dockerArgs := generateDockerArgs(internalOptions) - cmd := exec.Command("docker", dockerArgs...) + cmd := exec.CommandContext(ctx, "docker", dockerArgs...) cmd.Env = generateEnv(internalOptions) cmd.Stdout = stdout cmd.Stdin = stdin @@ -142,7 +143,7 @@ func RunWithIO(options RunOptions, stdin io.Reader, stdout, stderr io.Writer) er return nil } -func RunDaemon(options RunOptions, stderr io.Writer) (string, error) { +func RunDaemon(ctx context.Context, options RunOptions, stderr io.Writer) (string, error) { internalOptions := internalRunOptions{RunOptions: options} internalOptions.Detach = true @@ -150,7 +151,7 @@ func RunDaemon(options RunOptions, stderr io.Writer) (string, error) { stderrMultiWriter := io.MultiWriter(stderr, stderrCopy) dockerArgs := generateDockerArgs(internalOptions) - cmd := exec.Command("docker", dockerArgs...) + cmd := exec.CommandContext(ctx, "docker", dockerArgs...) cmd.Env = generateEnv(internalOptions) cmd.Stderr = stderrMultiWriter @@ -170,8 +171,8 @@ func RunDaemon(options RunOptions, stderr io.Writer) (string, error) { return strings.TrimSpace(string(containerID)), nil } -func GetPort(containerID string, containerPort int) (int, error) { - cmd := exec.Command("docker", "port", containerID, fmt.Sprintf("%d", containerPort)) //#nosec G204 +func GetPort(ctx context.Context, containerID string, containerPort int) (int, error) { + cmd := exec.CommandContext(ctx, "docker", "port", containerID, fmt.Sprintf("%d", containerPort)) //#nosec G204 cmd.Env = os.Environ() cmd.Stderr = os.Stderr @@ -211,9 +212,9 @@ func GetPort(containerID string, containerPort int) (int, error) { } -func FillInWeightsManifestVolumes(dockerCommand command.Command, runOptions RunOptions) (RunOptions, error) { +func FillInWeightsManifestVolumes(ctx context.Context, dockerCommand command.Command, runOptions RunOptions) (RunOptions, error) { // Check if the image has a weights manifest - manifest, err := dockerCommand.Inspect(runOptions.Image) + manifest, err := dockerCommand.Inspect(ctx, runOptions.Image) if err != nil { return runOptions, err } diff --git a/pkg/docker/standard_push.go b/pkg/docker/standard_push.go index 4c4a217919..c301637990 100644 --- a/pkg/docker/standard_push.go +++ b/pkg/docker/standard_push.go @@ -1,14 +1,15 @@ package docker import ( + "context" "strings" "github.com/replicate/cog/pkg/docker/command" "github.com/replicate/cog/pkg/util" ) -func StandardPush(image string, command command.Command) error { - err := command.Push(image) +func StandardPush(ctx context.Context, image string, command command.Command) error { + err := command.Push(ctx, image) if err != nil && strings.Contains(err.Error(), "NAME_UNKNOWN") { return util.WrapError(err, "Bad response from registry: 404") } diff --git a/pkg/docker/standard_push_test.go b/pkg/docker/standard_push_test.go index d960af22ab..0516338254 100644 --- a/pkg/docker/standard_push_test.go +++ b/pkg/docker/standard_push_test.go @@ -11,13 +11,13 @@ import ( func TestStandardPush(t *testing.T) { command := dockertest.NewMockCommand() dockertest.PushError = nil - err := StandardPush("test", command) + err := StandardPush(t.Context(), "test", command) require.NoError(t, err) } func TestStandardPushWithFullDockerCommand(t *testing.T) { t.Setenv(DockerCommandEnvVarName, "echo") command := NewDockerCommand() - err := StandardPush("test", command) + err := StandardPush(t.Context(), "test", command) require.NoError(t, err) } diff --git a/pkg/docker/stop.go b/pkg/docker/stop.go index 5583def688..af8f62bfbf 100644 --- a/pkg/docker/stop.go +++ b/pkg/docker/stop.go @@ -1,12 +1,13 @@ package docker import ( + "context" "os" "os/exec" ) -func Stop(id string) error { - cmd := exec.Command("docker", "container", "stop", "--time", "3", id) +func Stop(ctx context.Context, id string) error { + cmd := exec.CommandContext(ctx, "docker", "container", "stop", "--time", "3", id) cmd.Env = os.Environ() cmd.Stderr = os.Stderr diff --git a/pkg/dockerfile/base.go b/pkg/dockerfile/base.go index bd62de5b24..85f2860d77 100644 --- a/pkg/dockerfile/base.go +++ b/pkg/dockerfile/base.go @@ -1,6 +1,7 @@ package dockerfile import ( + "context" "encoding/json" "fmt" "strings" @@ -174,7 +175,7 @@ func NewBaseImageGenerator(cudaVersion string, pythonVersion string, torchVersio return nil, fmt.Errorf("unsupported base image configuration: CUDA: %s / Python: %s / Torch: %s", printNone(cudaVersion), printNone(pythonVersion), printNone(torchVersion)) } -func (g *BaseImageGenerator) GenerateDockerfile() (string, error) { +func (g *BaseImageGenerator) GenerateDockerfile(ctx context.Context) (string, error) { conf, err := g.makeConfig() if err != nil { return "", err @@ -187,7 +188,7 @@ func (g *BaseImageGenerator) GenerateDockerfile() (string, error) { useCogBaseImage := false generator.SetUseCogBaseImagePtr(&useCogBaseImage) - dockerfile, err := generator.GenerateInitialSteps() + dockerfile, err := generator.GenerateInitialSteps(ctx) if err != nil { return "", err } diff --git a/pkg/dockerfile/base_test.go b/pkg/dockerfile/base_test.go index 3a7d6f3619..31997a891d 100644 --- a/pkg/dockerfile/base_test.go +++ b/pkg/dockerfile/base_test.go @@ -42,7 +42,7 @@ func TestGenerateDockerfile(t *testing.T) { command, ) require.NoError(t, err) - dockerfile, err := generator.GenerateDockerfile() + dockerfile, err := generator.GenerateDockerfile(t.Context()) require.NoError(t, err) require.True(t, strings.Contains(dockerfile, "FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04")) } diff --git a/pkg/dockerfile/fast_generator.go b/pkg/dockerfile/fast_generator.go index 1b378bc5ad..e35ab84b41 100644 --- a/pkg/dockerfile/fast_generator.go +++ b/pkg/dockerfile/fast_generator.go @@ -1,6 +1,7 @@ package dockerfile import ( + "context" "encoding/json" "errors" "fmt" @@ -54,7 +55,7 @@ func NewFastGenerator(config *config.Config, dir string, dockerCommand command.C }, nil } -func (g *FastGenerator) GenerateInitialSteps() (string, error) { +func (g *FastGenerator) GenerateInitialSteps(ctx context.Context) (string, error) { return "", errors.New("GenerateInitialSteps not supported in FastGenerator") } @@ -66,19 +67,19 @@ func (g *FastGenerator) Cleanup() error { return nil } -func (g *FastGenerator) GenerateDockerfileWithoutSeparateWeights() (string, error) { - return g.generate() +func (g *FastGenerator) GenerateDockerfileWithoutSeparateWeights(ctx context.Context) (string, error) { + return g.generate(ctx) } -func (g *FastGenerator) GenerateModelBase() (string, error) { +func (g *FastGenerator) GenerateModelBase(ctx context.Context) (string, error) { return "", errors.New("GenerateModelBase not supported in FastGenerator") } -func (g *FastGenerator) GenerateModelBaseWithSeparateWeights(imageName string) (weightsBase string, dockerfile string, dockerignoreContents string, err error) { +func (g *FastGenerator) GenerateModelBaseWithSeparateWeights(ctx context.Context, imageName string) (weightsBase string, dockerfile string, dockerignoreContents string, err error) { return "", "", "", errors.New("GenerateModelBaseWithSeparateWeights not supported in FastGenerator") } -func (g *FastGenerator) GenerateWeightsManifest() (*weights.Manifest, error) { +func (g *FastGenerator) GenerateWeightsManifest(ctx context.Context) (*weights.Manifest, error) { return nil, errors.New("GenerateWeightsManifest not supported in FastGenerator") } @@ -141,14 +142,14 @@ func (g *FastGenerator) BuildContexts() (map[string]string, error) { }, nil } -func (g *FastGenerator) generate() (string, error) { +func (g *FastGenerator) generate(ctx context.Context) (string, error) { err := g.validateConfig() if err != nil { return "", err } // Always pull latest monobase as we rely on it for build logic - if err := g.dockerCommand.Pull(MONOBASE_IMAGE); err != nil { + if err := g.dockerCommand.Pull(ctx, MONOBASE_IMAGE); err != nil { return "", err } @@ -172,7 +173,7 @@ func (g *FastGenerator) generate() (string, error) { if err != nil { return "", err } - aptTarFile, err := g.generateAptTarball(tmpAptDir) + aptTarFile, err := g.generateAptTarball(ctx, tmpAptDir) if err != nil { return "", fmt.Errorf("generate apt tarball: %w", err) } @@ -432,8 +433,8 @@ func (g *FastGenerator) entrypoint(lines []string) ([]string, error) { }...), nil } -func (g *FastGenerator) generateAptTarball(tmpDir string) (string, error) { - return docker.CreateAptTarball(tmpDir, g.dockerCommand, g.Config.Build.SystemPackages...) +func (g *FastGenerator) generateAptTarball(ctx context.Context, tmpDir string) (string, error) { + return docker.CreateAptTarball(ctx, tmpDir, g.dockerCommand, g.Config.Build.SystemPackages...) } func (g *FastGenerator) validateConfig() error { diff --git a/pkg/dockerfile/fast_generator_test.go b/pkg/dockerfile/fast_generator_test.go index 00f551dcd9..0e92197724 100644 --- a/pkg/dockerfile/fast_generator_test.go +++ b/pkg/dockerfile/fast_generator_test.go @@ -49,7 +49,7 @@ func TestGenerate(t *testing.T) { generator, err := NewFastGenerator(&config, dir, command, &matrix, true) require.NoError(t, err) - dockerfile, err := generator.GenerateDockerfileWithoutSeparateWeights() + dockerfile, err := generator.GenerateDockerfileWithoutSeparateWeights(t.Context()) require.NoError(t, err) dockerfileLines := strings.Split(dockerfile, "\n") require.Equal(t, "# syntax=docker/dockerfile:1-labs", dockerfileLines[0]) @@ -83,7 +83,7 @@ func TestGenerateUVCacheMount(t *testing.T) { command := dockertest.NewMockCommand() generator, err := NewFastGenerator(&config, dir, command, &matrix, true) require.NoError(t, err) - dockerfile, err := generator.GenerateDockerfileWithoutSeparateWeights() + dockerfile, err := generator.GenerateDockerfileWithoutSeparateWeights(t.Context()) require.NoError(t, err) dockerfileLines := strings.Split(dockerfile, "\n") require.Equal(t, "RUN --mount=from=monobase,target=/buildtmp --mount=type=cache,target=/var/cache/monobase,id=monobase-cache --mount=type=cache,target=/srv/r8/monobase/uv/cache,id=uv-cache UV_CACHE_DIR=\"/srv/r8/monobase/uv/cache\" UV_LINK_MODE=copy /opt/r8/monobase/run.sh monobase.build --mini --cache=/var/cache/monobase", dockerfileLines[4]) @@ -119,7 +119,7 @@ func TestGenerateCUDA(t *testing.T) { generator, err := NewFastGenerator(&config, dir, command, &matrix, true) require.NoError(t, err) - dockerfile, err := generator.GenerateDockerfileWithoutSeparateWeights() + dockerfile, err := generator.GenerateDockerfileWithoutSeparateWeights(t.Context()) require.NoError(t, err) dockerfileLines := strings.Split(dockerfile, "\n") require.Equal(t, "ENV R8_CUDA_VERSION=12.4", dockerfileLines[3]) @@ -154,7 +154,7 @@ func TestGeneratePythonPackages(t *testing.T) { generator, err := NewFastGenerator(&config, dir, command, &matrix, true) require.NoError(t, err) - dockerfile, err := generator.GenerateDockerfileWithoutSeparateWeights() + dockerfile, err := generator.GenerateDockerfileWithoutSeparateWeights(t.Context()) require.NoError(t, err) dockerfileLines := strings.Split(dockerfile, "\n") require.Equal(t, "RUN --mount=from=requirements,target=/buildtmp --mount=type=bind,src=\".\",target=/src,rw --mount=type=cache,target=/srv/r8/monobase/uv/cache,id=uv-cache cd /src && UV_CACHE_DIR=\"/srv/r8/monobase/uv/cache\" UV_LINK_MODE=copy UV_COMPILE_BYTECODE=0 /opt/r8/monobase/run.sh monobase.user --requirements=/buildtmp/requirements.txt", dockerfileLines[5]) @@ -189,7 +189,7 @@ func TestGenerateVerboseEnv(t *testing.T) { generator, err := NewFastGenerator(&config, dir, command, &matrix, true) require.NoError(t, err) - dockerfile, err := generator.GenerateDockerfileWithoutSeparateWeights() + dockerfile, err := generator.GenerateDockerfileWithoutSeparateWeights(t.Context()) require.NoError(t, err) dockerfileLines := strings.Split(dockerfile, "\n") require.Equal(t, "ENV VERBOSE=0", dockerfileLines[8]) @@ -224,7 +224,7 @@ func TestAptInstall(t *testing.T) { generator, err := NewFastGenerator(&config, dir, command, &matrix, true) require.NoError(t, err) - dockerfile, err := generator.GenerateDockerfileWithoutSeparateWeights() + dockerfile, err := generator.GenerateDockerfileWithoutSeparateWeights(t.Context()) require.NoError(t, err) dockerfileLines := strings.Split(dockerfile, "\n") require.Equal(t, "RUN --mount=from=apt,target=/buildtmp tar --keep-directory-symlink -xf \"/buildtmp/apt.9a881b9b9f23849475296a8cd768ea1965bc3152df7118e60c145975af6aa58a.tar.zst\" -C /", dockerfileLines[5]) diff --git a/pkg/dockerfile/generator.go b/pkg/dockerfile/generator.go index 5227592d12..d7f25d3ca0 100644 --- a/pkg/dockerfile/generator.go +++ b/pkg/dockerfile/generator.go @@ -1,21 +1,25 @@ package dockerfile -import "github.com/replicate/cog/pkg/weights" +import ( + "context" + + "github.com/replicate/cog/pkg/weights" +) type Generator interface { - GenerateInitialSteps() (string, error) + GenerateInitialSteps(ctx context.Context) (string, error) SetUseCogBaseImage(bool) SetUseCogBaseImagePtr(*bool) - GenerateModelBaseWithSeparateWeights(string) (string, string, string, error) + GenerateModelBaseWithSeparateWeights(ctx context.Context, imageName string) (string, string, string, error) Cleanup() error SetStrip(bool) SetPrecompile(bool) SetUseCudaBaseImage(string) IsUsingCogBaseImage() bool BaseImage() (string, error) - GenerateWeightsManifest() (*weights.Manifest, error) - GenerateDockerfileWithoutSeparateWeights() (string, error) - GenerateModelBase() (string, error) + GenerateWeightsManifest(ctx context.Context) (*weights.Manifest, error) + GenerateDockerfileWithoutSeparateWeights(ctx context.Context) (string, error) + GenerateModelBase(ctx context.Context) (string, error) Name() string BuildDir() (string, error) BuildContexts() (map[string]string, error) diff --git a/pkg/dockerfile/standard_generator.go b/pkg/dockerfile/standard_generator.go index b07a035896..eb28a47c86 100644 --- a/pkg/dockerfile/standard_generator.go +++ b/pkg/dockerfile/standard_generator.go @@ -1,6 +1,7 @@ package dockerfile import ( + "context" "fmt" "os" "path" @@ -129,7 +130,7 @@ func (g *StandardGenerator) SetPrecompile(precompile bool) { g.precompile = precompile } -func (g *StandardGenerator) GenerateInitialSteps() (string, error) { +func (g *StandardGenerator) GenerateInitialSteps(ctx context.Context) (string, error) { baseImage, err := g.BaseImage() if err != nil { return "", err @@ -189,8 +190,8 @@ func (g *StandardGenerator) GenerateInitialSteps() (string, error) { return joinStringsWithoutLineSpace(steps), nil } -func (g *StandardGenerator) GenerateModelBase() (string, error) { - initialSteps, err := g.GenerateInitialSteps() +func (g *StandardGenerator) GenerateModelBase(ctx context.Context) (string, error) { + initialSteps, err := g.GenerateInitialSteps(ctx) if err != nil { return "", err } @@ -203,8 +204,8 @@ func (g *StandardGenerator) GenerateModelBase() (string, error) { } // GenerateDockerfileWithoutSeparateWeights generates a Dockerfile that doesn't write model weights to a separate layer. -func (g *StandardGenerator) GenerateDockerfileWithoutSeparateWeights() (string, error) { - base, err := g.GenerateModelBase() +func (g *StandardGenerator) GenerateDockerfileWithoutSeparateWeights(ctx context.Context) (string, error) { + base, err := g.GenerateModelBase(ctx) if err != nil { return "", err } @@ -220,12 +221,12 @@ func (g *StandardGenerator) GenerateDockerfileWithoutSeparateWeights() (string, // - dockerfile: A string that represents the Dockerfile content generated by the function. // - dockerignoreContents: A string that represents the .dockerignore content. // - err: An error object if an error occurred during Dockerfile generation; otherwise nil. -func (g *StandardGenerator) GenerateModelBaseWithSeparateWeights(imageName string) (weightsBase string, dockerfile string, dockerignoreContents string, err error) { +func (g *StandardGenerator) GenerateModelBaseWithSeparateWeights(ctx context.Context, imageName string) (weightsBase string, dockerfile string, dockerignoreContents string, err error) { weightsBase, g.modelDirs, g.modelFiles, err = g.generateForWeights() if err != nil { return "", "", "", fmt.Errorf("Failed to generate Dockerfile for model weights files: %w", err) } - initialSteps, err := g.GenerateInitialSteps() + initialSteps, err := g.GenerateInitialSteps(ctx) if err != nil { return "", "", "", err } @@ -544,7 +545,7 @@ func filterEmpty(list []string) []string { return filtered } -func (g *StandardGenerator) GenerateWeightsManifest() (*weights.Manifest, error) { +func (g *StandardGenerator) GenerateWeightsManifest(ctx context.Context) (*weights.Manifest, error) { m := weights.NewManifest() for _, dir := range g.modelDirs { diff --git a/pkg/dockerfile/standard_generator_test.go b/pkg/dockerfile/standard_generator_test.go index 9d5bbc1499..9829e471a9 100644 --- a/pkg/dockerfile/standard_generator_test.go +++ b/pkg/dockerfile/standard_generator_test.go @@ -98,7 +98,7 @@ predict: predict.py:Predictor gen, err := NewStandardGenerator(conf, tmpDir, command) require.NoError(t, err) gen.SetUseCogBaseImage(false) - _, actual, _, err := gen.GenerateModelBaseWithSeparateWeights("r8.im/replicate/cog-test") + _, actual, _, err := gen.GenerateModelBaseWithSeparateWeights(t.Context(), "r8.im/replicate/cog-test") require.NoError(t, err) expected := `#syntax=docker/dockerfile:1.4 @@ -133,7 +133,7 @@ predict: predict.py:Predictor gen, err := NewStandardGenerator(conf, tmpDir, command) require.NoError(t, err) gen.SetUseCogBaseImage(false) - _, actual, _, err := gen.GenerateModelBaseWithSeparateWeights("r8.im/replicate/cog-test") + _, actual, _, err := gen.GenerateModelBaseWithSeparateWeights(t.Context(), "r8.im/replicate/cog-test") require.NoError(t, err) expected := `#syntax=docker/dockerfile:1.4 @@ -177,7 +177,7 @@ predict: predict.py:Predictor gen, err := NewStandardGenerator(conf, tmpDir, command) require.NoError(t, err) gen.SetUseCogBaseImage(false) - _, actual, _, err := gen.GenerateModelBaseWithSeparateWeights("r8.im/replicate/cog-test") + _, actual, _, err := gen.GenerateModelBaseWithSeparateWeights(t.Context(), "r8.im/replicate/cog-test") require.NoError(t, err) expected := `#syntax=docker/dockerfile:1.4 @@ -232,7 +232,7 @@ predict: predict.py:Predictor gen, err := NewStandardGenerator(conf, tmpDir, command) require.NoError(t, err) gen.SetUseCogBaseImage(false) - _, actual, _, err := gen.GenerateModelBaseWithSeparateWeights("r8.im/replicate/cog-test") + _, actual, _, err := gen.GenerateModelBaseWithSeparateWeights(t.Context(), "r8.im/replicate/cog-test") require.NoError(t, err) expected := `#syntax=docker/dockerfile:1.4 @@ -283,7 +283,7 @@ build: gen, err := NewStandardGenerator(conf, tmpDir, command) require.NoError(t, err) gen.SetUseCogBaseImage(false) - _, actual, _, err := gen.GenerateModelBaseWithSeparateWeights("r8.im/replicate/cog-test") + _, actual, _, err := gen.GenerateModelBaseWithSeparateWeights(t.Context(), "r8.im/replicate/cog-test") require.NoError(t, err) expected := `#syntax=docker/dockerfile:1.4 @@ -320,7 +320,7 @@ build: gen, err := NewStandardGenerator(conf, tmpDir, command) require.NoError(t, err) gen.SetUseCogBaseImage(false) - _, actual, _, err := gen.GenerateModelBaseWithSeparateWeights("r8.im/replicate/cog-test") + _, actual, _, err := gen.GenerateModelBaseWithSeparateWeights(t.Context(), "r8.im/replicate/cog-test") require.NoError(t, err) fmt.Println(actual) require.Contains(t, actual, `pip install -r /tmp/requirements.txt`) @@ -383,7 +383,7 @@ predict: predict.py:Predictor return nil } - modelDockerfile, runnerDockerfile, dockerignore, err := gen.GenerateModelBaseWithSeparateWeights("r8.im/replicate/cog-test") + modelDockerfile, runnerDockerfile, dockerignore, err := gen.GenerateModelBaseWithSeparateWeights(t.Context(), "r8.im/replicate/cog-test") require.NoError(t, err) expected := `#syntax=docker/dockerfile:1.4 @@ -473,7 +473,7 @@ predict: predict.py:Predictor gen, err := NewStandardGenerator(conf, tmpDir, command) require.NoError(t, err) gen.SetUseCogBaseImage(false) - actual, err := gen.GenerateDockerfileWithoutSeparateWeights() + actual, err := gen.GenerateDockerfileWithoutSeparateWeights(t.Context()) require.NoError(t, err) expected := `#syntax=docker/dockerfile:1.4 @@ -507,7 +507,7 @@ predict: predict.py:Predictor gen, err := NewStandardGenerator(conf, tmpDir, command) require.NoError(t, err) gen.SetUseCogBaseImage(true) - _, actual, _, err := gen.GenerateModelBaseWithSeparateWeights("r8.im/replicate/cog-test") + _, actual, _, err := gen.GenerateModelBaseWithSeparateWeights(t.Context(), "r8.im/replicate/cog-test") require.NoError(t, err) expected := `#syntax=docker/dockerfile:1.4 @@ -544,7 +544,7 @@ predict: predict.py:Predictor gen, err := NewStandardGenerator(conf, tmpDir, command) require.NoError(t, err) gen.SetUseCogBaseImage(true) - _, actual, _, err := gen.GenerateModelBaseWithSeparateWeights("r8.im/replicate/cog-test") + _, actual, _, err := gen.GenerateModelBaseWithSeparateWeights(t.Context(), "r8.im/replicate/cog-test") require.NoError(t, err) expected := `#syntax=docker/dockerfile:1.4 @@ -595,7 +595,7 @@ predict: predict.py:Predictor gen, err := NewStandardGenerator(conf, tmpDir, command) require.NoError(t, err) gen.SetUseCogBaseImage(true) - _, actual, _, err := gen.GenerateModelBaseWithSeparateWeights("r8.im/replicate/cog-test") + _, actual, _, err := gen.GenerateModelBaseWithSeparateWeights(t.Context(), "r8.im/replicate/cog-test") require.NoError(t, err) // We add the patch version to the expected torch version @@ -654,7 +654,7 @@ predict: predict.py:Predictor gen, err := NewStandardGenerator(conf, tmpDir, command) require.NoError(t, err) gen.SetUseCogBaseImage(true) - _, actual, _, err := gen.GenerateModelBaseWithSeparateWeights("r8.im/replicate/cog-test") + _, actual, _, err := gen.GenerateModelBaseWithSeparateWeights(t.Context(), "r8.im/replicate/cog-test") require.NoError(t, err) expected := `#syntax=docker/dockerfile:1.4 @@ -707,7 +707,7 @@ predict: predict.py:Predictor require.NoError(t, err) gen.SetUseCogBaseImage(true) gen.SetStrip(true) - _, actual, _, err := gen.GenerateModelBaseWithSeparateWeights("r8.im/replicate/cog-test") + _, actual, _, err := gen.GenerateModelBaseWithSeparateWeights(t.Context(), "r8.im/replicate/cog-test") require.NoError(t, err) expected := `#syntax=docker/dockerfile:1.4 @@ -759,7 +759,7 @@ predict: predict.py:Predictor gen, err := NewStandardGenerator(conf, tmpDir, command) require.NoError(t, err) gen.SetUseCogBaseImage(true) - _, actual, _, err := gen.GenerateModelBaseWithSeparateWeights("r8.im/replicate/cog-test") + _, actual, _, err := gen.GenerateModelBaseWithSeparateWeights(t.Context(), "r8.im/replicate/cog-test") require.NoError(t, err) require.NotContains(t, actual, "-march=native") @@ -793,7 +793,7 @@ predict: predict.py:Predictor gen.SetUseCogBaseImage(true) gen.SetStrip(true) gen.SetPrecompile(true) - _, actual, _, err := gen.GenerateModelBaseWithSeparateWeights("r8.im/replicate/cog-test") + _, actual, _, err := gen.GenerateModelBaseWithSeparateWeights(t.Context(), "r8.im/replicate/cog-test") require.NoError(t, err) expected := `#syntax=docker/dockerfile:1.4 diff --git a/pkg/http/client.go b/pkg/http/client.go index 167af3eeb6..c9ef794877 100644 --- a/pkg/http/client.go +++ b/pkg/http/client.go @@ -1,6 +1,7 @@ package http import ( + "context" "net/http" "github.com/replicate/cog/pkg/docker/command" @@ -9,8 +10,8 @@ import ( const UserAgentHeader = "User-Agent" -func ProvideHTTPClient(dockerCommand command.Command) (*http.Client, error) { - userInfo, err := dockerCommand.LoadUserInformation(global.ReplicateRegistryHost) +func ProvideHTTPClient(ctx context.Context, dockerCommand command.Command) (*http.Client, error) { + userInfo, err := dockerCommand.LoadUserInformation(ctx, global.ReplicateRegistryHost) if err != nil { return nil, err } diff --git a/pkg/http/client_test.go b/pkg/http/client_test.go index 1c8ef2778b..c68b60a821 100644 --- a/pkg/http/client_test.go +++ b/pkg/http/client_test.go @@ -20,7 +20,7 @@ func TestClientDecoratesUserAgent(t *testing.T) { defer server.Close() command := dockertest.NewMockCommand() - client, err := ProvideHTTPClient(command) + client, err := ProvideHTTPClient(t.Context(), command) require.NoError(t, err) _, err = client.Get(server.URL) diff --git a/pkg/image/build.go b/pkg/image/build.go index e26698fc01..cd82386f30 100644 --- a/pkg/image/build.go +++ b/pkg/image/build.go @@ -36,7 +36,7 @@ var errGit = errors.New("git error") // Build a Cog model from a config // // This is separated out from docker.Build(), so that can be as close as possible to the behavior of 'docker build'. -func Build(cfg *config.Config, dir, imageName string, secrets []string, noCache, separateWeights bool, useCudaBaseImage string, progressOutput string, schemaFile string, dockerfileFile string, useCogBaseImage *bool, strip bool, precompile bool, fastFlag bool, annotations map[string]string, localImage bool) error { +func Build(ctx context.Context, cfg *config.Config, dir, imageName string, secrets []string, noCache, separateWeights bool, useCudaBaseImage string, progressOutput string, schemaFile string, dockerfileFile string, useCogBaseImage *bool, strip bool, precompile bool, fastFlag bool, annotations map[string]string, localImage bool) error { console.Infof("Building Docker image from environment in cog.yaml as %s...", imageName) if fastFlag { console.Info("Fast build enabled.") @@ -57,7 +57,7 @@ func Build(cfg *config.Config, dir, imageName string, secrets []string, noCache, if err != nil { return fmt.Errorf("Failed to read Dockerfile at %s: %w", dockerfileFile, err) } - if err := docker.Build(dir, string(dockerfileContents), imageName, secrets, noCache, progressOutput, config.BuildSourceEpochTimestamp, dockercontext.StandardBuildDirectory, nil); err != nil { + if err := docker.Build(ctx, dir, string(dockerfileContents), imageName, secrets, noCache, progressOutput, config.BuildSourceEpochTimestamp, dockercontext.StandardBuildDirectory, nil); err != nil { return fmt.Errorf("Failed to build Docker image: %w", err) } } else { @@ -94,7 +94,7 @@ func Build(cfg *config.Config, dir, imageName string, secrets []string, noCache, } if separateWeights { - weightsDockerfile, runnerDockerfile, dockerignore, err := generator.GenerateModelBaseWithSeparateWeights(imageName) + weightsDockerfile, runnerDockerfile, dockerignore, err := generator.GenerateModelBaseWithSeparateWeights(ctx, imageName) if err != nil { return fmt.Errorf("Failed to generate Dockerfile: %w", err) } @@ -103,14 +103,14 @@ func Build(cfg *config.Config, dir, imageName string, secrets []string, noCache, return fmt.Errorf("Failed to backup .dockerignore file: %w", err) } - weightsManifest, err := generator.GenerateWeightsManifest() + weightsManifest, err := generator.GenerateWeightsManifest(ctx) if err != nil { return fmt.Errorf("Failed to generate weights manifest: %w", err) } cachedManifest, _ := weights.LoadManifest(weightsManifestPath) changed := cachedManifest == nil || !weightsManifest.Equal(cachedManifest) if changed { - if err := buildWeightsImage(dir, weightsDockerfile, imageName+"-weights", secrets, noCache, progressOutput, contextDir, buildContexts); err != nil { + if err := buildWeightsImage(ctx, dir, weightsDockerfile, imageName+"-weights", secrets, noCache, progressOutput, contextDir, buildContexts); err != nil { return fmt.Errorf("Failed to build model weights Docker image: %w", err) } err := weightsManifest.Save(weightsManifestPath) @@ -121,15 +121,15 @@ func Build(cfg *config.Config, dir, imageName string, secrets []string, noCache, console.Info("Weights unchanged, skip rebuilding and use cached image...") } - if err := buildRunnerImage(dir, runnerDockerfile, dockerignore, imageName, secrets, noCache, progressOutput, contextDir, buildContexts); err != nil { + if err := buildRunnerImage(ctx, dir, runnerDockerfile, dockerignore, imageName, secrets, noCache, progressOutput, contextDir, buildContexts); err != nil { return fmt.Errorf("Failed to build runner Docker image: %w", err) } } else { - dockerfileContents, err := generator.GenerateDockerfileWithoutSeparateWeights() + dockerfileContents, err := generator.GenerateDockerfileWithoutSeparateWeights(ctx) if err != nil { return fmt.Errorf("Failed to generate Dockerfile: %w", err) } - if err := docker.Build(dir, dockerfileContents, imageName, secrets, noCache, progressOutput, config.BuildSourceEpochTimestamp, contextDir, buildContexts); err != nil { + if err := docker.Build(ctx, dir, dockerfileContents, imageName, secrets, noCache, progressOutput, config.BuildSourceEpochTimestamp, contextDir, buildContexts); err != nil { return fmt.Errorf("Failed to build Docker image: %w", err) } } @@ -146,7 +146,7 @@ func Build(cfg *config.Config, dir, imageName string, secrets []string, noCache, schemaJSON = data } else { console.Info("Validating model schema...") - schema, err := GenerateOpenAPISchema(imageName, cfg.Build.GPU) + schema, err := GenerateOpenAPISchema(ctx, imageName, cfg.Build.GPU) if err != nil { return fmt.Errorf("Failed to get type signature: %w", err) } @@ -185,7 +185,7 @@ func Build(cfg *config.Config, dir, imageName string, secrets []string, noCache, return fmt.Errorf("Failed to convert config to JSON: %w", err) } - pipFreeze, err := GeneratePipFreeze(imageName, fastFlag) + pipFreeze, err := GeneratePipFreeze(ctx, imageName, fastFlag) if err != nil { return fmt.Errorf("Failed to generate pip freeze from image: %w", err) } @@ -235,13 +235,13 @@ func Build(cfg *config.Config, dir, imageName string, secrets []string, noCache, labels[global.LabelNamespace+"cog-base-image-last-layer-idx"] = fmt.Sprintf("%d", lastLayerIndex) } - if commit, err := gitHead(dir); commit != "" && err == nil { + if commit, err := gitHead(ctx, dir); commit != "" && err == nil { labels["org.opencontainers.image.revision"] = commit } else { console.Info("Unable to determine Git commit") } - if tag, err := gitTag(dir); tag != "" && err == nil { + if tag, err := gitTag(ctx, dir); tag != "" && err == nil { labels["org.opencontainers.image.version"] = tag } else { console.Info("Unable to determine Git tag") @@ -251,13 +251,13 @@ func Build(cfg *config.Config, dir, imageName string, secrets []string, noCache, labels[key] = val } - if err := docker.BuildAddLabelsAndSchemaToImage(imageName, labels, bundledSchemaFile, bundledSchemaPy); err != nil { + if err := docker.BuildAddLabelsAndSchemaToImage(ctx, imageName, labels, bundledSchemaFile, bundledSchemaPy); err != nil { return fmt.Errorf("Failed to add labels to image: %w", err) } return nil } -func BuildBase(cfg *config.Config, dir string, useCudaBaseImage string, useCogBaseImage *bool, progressOutput string) (string, error) { +func BuildBase(ctx context.Context, cfg *config.Config, dir string, useCudaBaseImage string, useCogBaseImage *bool, progressOutput string) (string, error) { // TODO: better image management so we don't eat up disk space // https://github.com/replicate/cog/issues/80 imageName := config.BaseDockerImageName(dir) @@ -287,18 +287,18 @@ func BuildBase(cfg *config.Config, dir string, useCudaBaseImage string, useCogBa generator.SetUseCogBaseImage(*useCogBaseImage) } - dockerfileContents, err := generator.GenerateModelBase() + dockerfileContents, err := generator.GenerateModelBase(ctx) if err != nil { return "", fmt.Errorf("Failed to generate Dockerfile: %w", err) } - if err := docker.Build(dir, dockerfileContents, imageName, []string{}, false, progressOutput, config.BuildSourceEpochTimestamp, contextDir, buildContexts); err != nil { + if err := docker.Build(ctx, dir, dockerfileContents, imageName, []string{}, false, progressOutput, config.BuildSourceEpochTimestamp, contextDir, buildContexts); err != nil { return "", fmt.Errorf("Failed to build Docker image: %w", err) } return imageName, nil } -func isGitWorkTree(dir string) bool { - ctx, cancel := context.WithTimeout(context.TODO(), 3*time.Second) +func isGitWorkTree(ctx context.Context, dir string) bool { + ctx, cancel := context.WithTimeout(ctx, 3*time.Second) defer cancel() out, err := exec.CommandContext(ctx, "git", "-C", dir, "rev-parse", "--is-inside-work-tree").Output() @@ -309,13 +309,13 @@ func isGitWorkTree(dir string) bool { return strings.TrimSpace(string(out)) == "true" } -func gitHead(dir string) (string, error) { +func gitHead(ctx context.Context, dir string) (string, error) { if v, ok := os.LookupEnv("GITHUB_SHA"); ok && v != "" { return v, nil } - if isGitWorkTree(dir) { - ctx, cancel := context.WithTimeout(context.TODO(), 3*time.Second) + if isGitWorkTree(ctx, dir) { + ctx, cancel := context.WithTimeout(ctx, 3*time.Second) defer cancel() out, err := exec.CommandContext(ctx, "git", "-C", dir, "rev-parse", "HEAD").Output() @@ -329,13 +329,13 @@ func gitHead(dir string) (string, error) { return "", fmt.Errorf("Failed to find HEAD commit: %w", errGit) } -func gitTag(dir string) (string, error) { +func gitTag(ctx context.Context, dir string) (string, error) { if v, ok := os.LookupEnv("GITHUB_REF_NAME"); ok && v != "" { return v, nil } - if isGitWorkTree(dir) { - ctx, cancel := context.WithTimeout(context.TODO(), 3*time.Second) + if isGitWorkTree(ctx, dir) { + ctx, cancel := context.WithTimeout(ctx, 3*time.Second) defer cancel() out, err := exec.CommandContext(ctx, "git", "-C", dir, "describe", "--tags", "--dirty").Output() @@ -349,21 +349,21 @@ func gitTag(dir string) (string, error) { return "", fmt.Errorf("Failed to find ref name: %w", errGit) } -func buildWeightsImage(dir, dockerfileContents, imageName string, secrets []string, noCache bool, progressOutput string, contextDir string, buildContexts map[string]string) error { +func buildWeightsImage(ctx context.Context, dir, dockerfileContents, imageName string, secrets []string, noCache bool, progressOutput string, contextDir string, buildContexts map[string]string) error { if err := makeDockerignoreForWeightsImage(); err != nil { return fmt.Errorf("Failed to create .dockerignore file: %w", err) } - if err := docker.Build(dir, dockerfileContents, imageName, secrets, noCache, progressOutput, config.BuildSourceEpochTimestamp, contextDir, buildContexts); err != nil { + if err := docker.Build(ctx, dir, dockerfileContents, imageName, secrets, noCache, progressOutput, config.BuildSourceEpochTimestamp, contextDir, buildContexts); err != nil { return fmt.Errorf("Failed to build Docker image for model weights: %w", err) } return nil } -func buildRunnerImage(dir, dockerfileContents, dockerignoreContents, imageName string, secrets []string, noCache bool, progressOutput string, contextDir string, buildContexts map[string]string) error { +func buildRunnerImage(ctx context.Context, dir, dockerfileContents, dockerignoreContents, imageName string, secrets []string, noCache bool, progressOutput string, contextDir string, buildContexts map[string]string) error { if err := writeDockerignore(dockerignoreContents); err != nil { return fmt.Errorf("Failed to write .dockerignore file with weights included: %w", err) } - if err := docker.Build(dir, dockerfileContents, imageName, secrets, noCache, progressOutput, config.BuildSourceEpochTimestamp, contextDir, buildContexts); err != nil { + if err := docker.Build(ctx, dir, dockerfileContents, imageName, secrets, noCache, progressOutput, config.BuildSourceEpochTimestamp, contextDir, buildContexts); err != nil { return fmt.Errorf("Failed to build Docker image: %w", err) } if err := restoreDockerignore(); err != nil { diff --git a/pkg/image/build_test.go b/pkg/image/build_test.go index 55a0a86f47..5cc35f375b 100644 --- a/pkg/image/build_test.go +++ b/pkg/image/build_test.go @@ -16,8 +16,8 @@ var hasGit = (func() bool { return err == nil })() -func gitRun(argv []string, t *testing.T) { - ctx, cancel := context.WithTimeout(context.TODO(), 2*time.Second) +func gitRun(ctx context.Context, argv []string, t *testing.T) { + ctx, cancel := context.WithTimeout(ctx, 2*time.Second) t.Cleanup(cancel) out, err := exec.CommandContext(ctx, "git", argv...).CombinedOutput() @@ -27,6 +27,7 @@ func gitRun(argv []string, t *testing.T) { } func setupGitWorkTree(t *testing.T) string { + ctx := t.Context() if !hasGit { t.Skip("no git executable available") return "" @@ -37,27 +38,28 @@ func setupGitWorkTree(t *testing.T) string { tmp := filepath.Join(t.TempDir(), "wd") r.NoError(os.MkdirAll(tmp, 0o755)) - gitRun([]string{"init", tmp}, t) - gitRun([]string{"-C", tmp, "config", "user.email", "cog@localhost"}, t) - gitRun([]string{"-C", tmp, "config", "user.name", "Cog Tests"}, t) - gitRun([]string{"-C", tmp, "commit", "--allow-empty", "-m", "walrus"}, t) - gitRun([]string{"-C", tmp, "tag", "-a", "v0.0.1+walrus", "-m", "walrus time"}, t) + gitRun(ctx, []string{"init", tmp}, t) + gitRun(ctx, []string{"-C", tmp, "config", "user.email", "cog@localhost"}, t) + gitRun(ctx, []string{"-C", tmp, "config", "user.name", "Cog Tests"}, t) + gitRun(ctx, []string{"-C", tmp, "commit", "--allow-empty", "-m", "walrus"}, t) + gitRun(ctx, []string{"-C", tmp, "tag", "-a", "v0.0.1+walrus", "-m", "walrus time"}, t) return tmp } func TestIsGitWorkTree(t *testing.T) { + ctx := t.Context() r := require.New(t) - r.False(isGitWorkTree("/dev/null")) - r.True(isGitWorkTree(setupGitWorkTree(t))) + r.False(isGitWorkTree(ctx, "/dev/null")) + r.True(isGitWorkTree(ctx, setupGitWorkTree(t))) } func TestGitHead(t *testing.T) { t.Run("via github env", func(t *testing.T) { t.Setenv("GITHUB_SHA", "fafafaf") - head, err := gitHead("/dev/null") + head, err := gitHead(t.Context(), "/dev/null") require.NoError(t, err) require.Equal(t, "fafafaf", head) @@ -71,7 +73,7 @@ func TestGitHead(t *testing.T) { t.Setenv("GITHUB_SHA", "") - head, err := gitHead(tmp) + head, err := gitHead(t.Context(), tmp) require.NoError(t, err) require.NotEqual(t, "", head) }) @@ -79,7 +81,7 @@ func TestGitHead(t *testing.T) { t.Run("unavailable", func(t *testing.T) { t.Setenv("GITHUB_SHA", "") - head, err := gitHead("/dev/null") + head, err := gitHead(t.Context(), "/dev/null") require.Error(t, err) require.Equal(t, "", head) }) @@ -89,7 +91,7 @@ func TestGitTag(t *testing.T) { t.Run("via github env", func(t *testing.T) { t.Setenv("GITHUB_REF_NAME", "v0.0.1+manatee") - tag, err := gitTag("/dev/null") + tag, err := gitTag(t.Context(), "/dev/null") require.NoError(t, err) require.Equal(t, "v0.0.1+manatee", tag) }) @@ -102,7 +104,7 @@ func TestGitTag(t *testing.T) { t.Setenv("GITHUB_REF_NAME", "") - tag, err := gitTag(tmp) + tag, err := gitTag(t.Context(), tmp) require.NoError(t, err) require.Equal(t, "v0.0.1+walrus", tag) }) @@ -110,7 +112,7 @@ func TestGitTag(t *testing.T) { t.Run("unavailable", func(t *testing.T) { t.Setenv("GITHUB_REF_NAME", "") - tag, err := gitTag("/dev/null") + tag, err := gitTag(t.Context(), "/dev/null") require.Error(t, err) require.Equal(t, "", tag) }) diff --git a/pkg/image/config.go b/pkg/image/config.go index cee793e4fc..5fcb18776c 100644 --- a/pkg/image/config.go +++ b/pkg/image/config.go @@ -1,6 +1,7 @@ package image import ( + "context" "encoding/json" "fmt" @@ -9,8 +10,8 @@ import ( "github.com/replicate/cog/pkg/docker/command" ) -func GetConfig(imageName string) (*config.Config, error) { - image, err := docker.ImageInspect(imageName) +func GetConfig(ctx context.Context, imageName string) (*config.Config, error) { + image, err := docker.ImageInspect(ctx, imageName) if err != nil { return nil, fmt.Errorf("Failed to inspect %s: %w", imageName, err) } diff --git a/pkg/image/openapi_schema.go b/pkg/image/openapi_schema.go index f51b6cc0d6..6c3e97d134 100644 --- a/pkg/image/openapi_schema.go +++ b/pkg/image/openapi_schema.go @@ -2,6 +2,7 @@ package image import ( "bytes" + "context" "encoding/json" "fmt" @@ -14,7 +15,7 @@ import ( // GenerateOpenAPISchema by running the image and executing Cog // This will be run as part of the build process then added as a label to the image. It can be retrieved more efficiently with the label by using GetOpenAPISchema -func GenerateOpenAPISchema(imageName string, enableGPU bool) (map[string]any, error) { +func GenerateOpenAPISchema(ctx context.Context, imageName string, enableGPU bool) (map[string]any, error) { var stdout bytes.Buffer var stderr bytes.Buffer @@ -24,7 +25,7 @@ func GenerateOpenAPISchema(imageName string, enableGPU bool) (map[string]any, er gpus = "all" } - err := docker.RunWithIO(docker.RunOptions{ + err := docker.RunWithIO(ctx, docker.RunOptions{ Image: imageName, Args: []string{ "python", "-m", "cog.command.openapi_schema", @@ -36,7 +37,7 @@ func GenerateOpenAPISchema(imageName string, enableGPU bool) (map[string]any, er console.Debug(stdout.String()) console.Debug(stderr.String()) console.Debug("Missing device driver, re-trying without GPU") - return GenerateOpenAPISchema(imageName, false) + return GenerateOpenAPISchema(ctx, imageName, false) } if err != nil { @@ -55,8 +56,8 @@ func GenerateOpenAPISchema(imageName string, enableGPU bool) (map[string]any, er return schema, nil } -func GetOpenAPISchema(imageName string) (*openapi3.T, error) { - image, err := docker.ImageInspect(imageName) +func GetOpenAPISchema(ctx context.Context, imageName string) (*openapi3.T, error) { + image, err := docker.ImageInspect(ctx, imageName) if err != nil { return nil, fmt.Errorf("Failed to inspect %s: %w", imageName, err) } diff --git a/pkg/image/pip_freeze.go b/pkg/image/pip_freeze.go index cdcdf78186..fee33e8aad 100644 --- a/pkg/image/pip_freeze.go +++ b/pkg/image/pip_freeze.go @@ -2,6 +2,7 @@ package image import ( "bytes" + "context" "github.com/replicate/cog/pkg/docker" "github.com/replicate/cog/pkg/util/console" @@ -9,7 +10,7 @@ import ( // GeneratePipFreeze by running a pip freeze on the image. // This will be run as part of the build process then added as a label to the image. -func GeneratePipFreeze(imageName string, fastFlag bool) (string, error) { +func GeneratePipFreeze(ctx context.Context, imageName string, fastFlag bool) (string, error) { var stdout bytes.Buffer var stderr bytes.Buffer @@ -21,7 +22,7 @@ func GeneratePipFreeze(imageName string, fastFlag bool) (string, error) { args = []string{"uv", "pip", "freeze"} env = []string{"VIRTUAL_ENV=/root/.venv"} } - err := docker.RunWithIO(docker.RunOptions{ + err := docker.RunWithIO(ctx, docker.RunOptions{ Image: imageName, Args: args, Env: env, diff --git a/pkg/monobeam/client_test.go b/pkg/monobeam/client_test.go index beb4048f3b..9d025e6f03 100644 --- a/pkg/monobeam/client_test.go +++ b/pkg/monobeam/client_test.go @@ -1,7 +1,6 @@ package monobeam import ( - "context" "net/http" "net/http/httptest" "net/url" @@ -51,15 +50,14 @@ func TestUploadFile(t *testing.T) { command := dockertest.NewMockCommand() // Setup http client - httpClient, err := r8HTTP.ProvideHTTPClient(command) + httpClient, err := r8HTTP.ProvideHTTPClient(t.Context(), command) require.NoError(t, err) client := NewClient(httpClient) - ctx := context.Background() p := mpb.New( mpb.WithRefreshRate(180 * time.Millisecond), ) - err = client.UploadFile(ctx, "weights", "111", weightPath, p, "weights - "+weightPath) + err = client.UploadFile(t.Context(), "weights", "111", weightPath, p, "weights - "+weightPath) require.NoError(t, err) } @@ -79,11 +77,10 @@ func TestPreUpload(t *testing.T) { command := dockertest.NewMockCommand() // Setup http client - httpClient, err := r8HTTP.ProvideHTTPClient(command) + httpClient, err := r8HTTP.ProvideHTTPClient(t.Context(), command) require.NoError(t, err) client := NewClient(httpClient) - ctx := context.Background() - err = client.PostPreUpload(ctx) + err = client.PostPreUpload(t.Context()) require.NoError(t, err) } diff --git a/pkg/predict/predictor.go b/pkg/predict/predictor.go index b2e13d0450..7a7e1bd8a8 100644 --- a/pkg/predict/predictor.go +++ b/pkg/predict/predictor.go @@ -2,6 +2,7 @@ package predict import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -51,7 +52,7 @@ type Predictor struct { port int } -func NewPredictor(runOptions docker.RunOptions, isTrain bool, fastFlag bool, dockerCommand command.Command) (*Predictor, error) { +func NewPredictor(ctx context.Context, runOptions docker.RunOptions, isTrain bool, fastFlag bool, dockerCommand command.Command) (*Predictor, error) { if fastFlag { console.Info("Fast predictor enabled.") } @@ -62,7 +63,7 @@ func NewPredictor(runOptions docker.RunOptions, isTrain bool, fastFlag bool, doc runOptions.Env = append(runOptions.Env, "COG_LOG_LEVEL=warning") } - runOptions, err := docker.FillInWeightsManifestVolumes(dockerCommand, runOptions) + runOptions, err := docker.FillInWeightsManifestVolumes(ctx, dockerCommand, runOptions) if err != nil { return nil, err } @@ -70,24 +71,24 @@ func NewPredictor(runOptions docker.RunOptions, isTrain bool, fastFlag bool, doc return &Predictor{runOptions: runOptions, isTrain: isTrain}, nil } -func (p *Predictor) Start(logsWriter io.Writer, timeout time.Duration) error { +func (p *Predictor) Start(ctx context.Context, logsWriter io.Writer, timeout time.Duration) error { var err error containerPort := 5000 p.runOptions.Ports = append(p.runOptions.Ports, docker.Port{HostPort: 0, ContainerPort: containerPort}) - p.containerID, err = docker.RunDaemon(p.runOptions, logsWriter) + p.containerID, err = docker.RunDaemon(ctx, p.runOptions, logsWriter) if err != nil { return fmt.Errorf("Failed to start container: %w", err) } - p.port, err = docker.GetPort(p.containerID, containerPort) + p.port, err = docker.GetPort(ctx, p.containerID, containerPort) if err != nil { return fmt.Errorf("Failed to determine container port: %w", err) } go func() { - if err := docker.ContainerLogsFollow(p.containerID, logsWriter); err != nil { + if err := docker.ContainerLogsFollow(ctx, p.containerID, logsWriter); err != nil { // if user hits ctrl-c we expect an error signal if !strings.Contains(err.Error(), "signal: interrupt") { console.Warnf("Error getting container logs: %s", err) @@ -95,10 +96,10 @@ func (p *Predictor) Start(logsWriter io.Writer, timeout time.Duration) error { } }() - return p.waitForContainerReady(timeout) + return p.waitForContainerReady(ctx, timeout) } -func (p *Predictor) waitForContainerReady(timeout time.Duration) error { +func (p *Predictor) waitForContainerReady(ctx context.Context, timeout time.Duration) error { url := fmt.Sprintf("http://localhost:%d/health-check", p.port) start := time.Now() @@ -110,7 +111,7 @@ func (p *Predictor) waitForContainerReady(timeout time.Duration) error { time.Sleep(100 * time.Millisecond) - cont, err := docker.ContainerInspect(p.containerID) + cont, err := docker.ContainerInspect(ctx, p.containerID) if err != nil { return fmt.Errorf("Failed to get container status: %w", err) } @@ -143,8 +144,8 @@ func (p *Predictor) waitForContainerReady(timeout time.Duration) error { } } -func (p *Predictor) Stop() error { - return docker.Stop(p.containerID) +func (p *Predictor) Stop(ctx context.Context) error { + return docker.Stop(ctx, p.containerID) } func (p *Predictor) Predict(inputs Inputs) (*Response, error) { diff --git a/pkg/update/update.go b/pkg/update/update.go index 7a537db5dd..dda1564de1 100644 --- a/pkg/update/update.go +++ b/pkg/update/update.go @@ -21,7 +21,7 @@ func isUpdateEnabled() bool { // DisplayAndCheckForRelease will display an update message if an update is available and will check for a new update in the background // The result of that check will then be displayed the next time the user runs Cog // Returns errors which the caller is assumed to ignore so as not to break the client -func DisplayAndCheckForRelease() error { +func DisplayAndCheckForRelease(ctx context.Context) error { if !isUpdateEnabled() { return fmt.Errorf("update check disabled") } @@ -37,7 +37,7 @@ func DisplayAndCheckForRelease() error { } if time.Since(s.LastChecked) > time.Hour { - startCheckingForRelease() + startCheckingForRelease(ctx) } if s.Message != "" { console.Info(s.Message) @@ -46,10 +46,10 @@ func DisplayAndCheckForRelease() error { return nil } -func startCheckingForRelease() { +func startCheckingForRelease(ctx context.Context) { go func() { console.Debugf("Checking for updates...") - ctx, cancel := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() switch r, err := checkForRelease(ctx); { case err == nil: diff --git a/pkg/web/client.go b/pkg/web/client.go index fa23e3a45c..883ff16387 100644 --- a/pkg/web/client.go +++ b/pkg/web/client.go @@ -150,7 +150,7 @@ func (c *Client) PostPushStart(ctx context.Context, pushID string, buildTime tim } func (c *Client) PostNewVersion(ctx context.Context, image string, weights []File, files []File, fileChallenges []FileChallengeAnswer) error { - version, err := c.versionFromManifest(image, weights, files, fileChallenges) + version, err := c.versionFromManifest(ctx, image, weights, files, fileChallenges) if err != nil { return util.WrapError(err, "failed to build new version from manifest") } @@ -192,8 +192,8 @@ func (c *Client) PostNewVersion(ctx context.Context, image string, weights []Fil return nil } -func (c *Client) versionFromManifest(image string, weights []File, files []File, fileChallenges []FileChallengeAnswer) (*Version, error) { - manifest, err := c.dockerCommand.Inspect(image) +func (c *Client) versionFromManifest(ctx context.Context, image string, weights []File, files []File, fileChallenges []FileChallengeAnswer) (*Version, error) { + manifest, err := c.dockerCommand.Inspect(ctx, image) if err != nil { return nil, util.WrapError(err, "failed to inspect docker image") } diff --git a/pkg/web/client_test.go b/pkg/web/client_test.go index e0302b8b23..6a0feb53ec 100644 --- a/pkg/web/client_test.go +++ b/pkg/web/client_test.go @@ -1,7 +1,6 @@ package web import ( - "context" "encoding/json" "net/http" "net/http/httptest" @@ -42,8 +41,7 @@ func TestPostNewVersion(t *testing.T) { command := dockertest.NewMockCommand() client := NewClient(command, http.DefaultClient) - ctx := context.Background() - err = client.PostNewVersion(ctx, "r8.im/user/test", []File{}, []File{}, nil) + err = client.PostNewVersion(t.Context(), "r8.im/user/test", []File{}, []File{}, nil) require.NoError(t, err) } @@ -61,7 +59,7 @@ func TestVersionFromManifest(t *testing.T) { dockertest.MockOpenAPISchema = "{\"test\": true}" client := NewClient(command, http.DefaultClient) - version, err := client.versionFromManifest("r8.im/user/test", []File{}, []File{}, nil) + version, err := client.versionFromManifest(t.Context(), "r8.im/user/test", []File{}, []File{}, nil) require.NoError(t, err) var openAPISchema map[string]any @@ -152,8 +150,7 @@ func TestDoFileChallenge(t *testing.T) { // Setup mock command command := dockertest.NewMockCommand() client := NewClient(command, http.DefaultClient) - ctx := context.Background() - response, err := client.InitiateAndDoFileChallenge(ctx, weights, files) + response, err := client.InitiateAndDoFileChallenge(t.Context(), weights, files) require.NoError(t, err) assert.ElementsMatch(t, response, []FileChallengeAnswer{ {