From c8364246406b3db15243e12f11f83c98e71f9249 Mon Sep 17 00:00:00 2001
From: andreasjansson
Date: Fri, 14 May 2021 14:21:26 -0700
Subject: [PATCH] Queue test

Signed-off-by: andreasjansson
---
 pkg/docker/mock.go       |  33 ++++++++
 pkg/server/queue.go      |  29 ++++---
 pkg/server/queue_test.go | 177 +++++++++++++++++++++++++++++++++++++++
 pkg/server/server.go     |   3 +-
 4 files changed, 231 insertions(+), 11 deletions(-)
 create mode 100644 pkg/docker/mock.go
 create mode 100644 pkg/server/queue_test.go

diff --git a/pkg/docker/mock.go b/pkg/docker/mock.go
new file mode 100644
index 0000000000..80a5e7c131
--- /dev/null
+++ b/pkg/docker/mock.go
@@ -0,0 +1,33 @@
+package docker
+
+import (
+	"context"
+
+	"github.com/replicate/cog/pkg/logger"
+)
+
+// MockBuildFunc is the callback that MockImageBuilder.Build delegates to. It
+// takes the same arguments and returns the same results as Build.
+type MockBuildFunc func(ctx context.Context, dir string, dockerfileContents string, name string, useGPU bool, logWriter logger.Logger) (tag string, err error)
+
+// MockImageBuilder is an image builder test double: Build calls a
+// caller-supplied function and Push does nothing.
+type MockImageBuilder struct {
+	buildFunc MockBuildFunc
+}
+
+// NewMockImageBuilder returns a MockImageBuilder whose Build calls buildFunc.
+func NewMockImageBuilder(buildFunc MockBuildFunc) *MockImageBuilder {
+	return &MockImageBuilder{buildFunc: buildFunc}
+}
+
+// Build passes its arguments through to the configured MockBuildFunc and
+// returns its results unchanged.
+func (m *MockImageBuilder) Build(ctx context.Context, dir string, dockerfileContents string, name string, useGPU bool, logWriter logger.Logger) (tag string, err error) {
+	return m.buildFunc(ctx, dir, dockerfileContents, name, useGPU, logWriter)
+}
+
+// Push is a no-op that always returns nil.
+func (m *MockImageBuilder) Push(ctx context.Context, tag string, logWriter logger.Logger) error {
+	return nil
+}
diff --git a/pkg/server/queue.go b/pkg/server/queue.go
index b2ee84f6de..246c8c539b 100644
--- a/pkg/server/queue.go
+++ b/pkg/server/queue.go
@@ -71,21 +71,30 @@ func NewBuildQueue(servingPlatform serving.Platform, dockerImageBuilder docker.I
 	}
 }
 
-func (q *BuildQueue) Start() {
+// Start launches one handler goroutine for each of the "cpu" and "gpu"
+// architectures. The handlers stop receiving jobs once ctx is done.
+func (q *BuildQueue) Start(ctx context.Context) {
 	for _, arch := range []string{"cpu", "gpu"} {
-		go q.startHandler(arch)
+		go q.startHandler(ctx, arch)
 	}
 }
 
-func (q *BuildQueue) startHandler(arch string) {
+// startHandler receives jobs for arch until ctx is done. Each job runs in its
+// own goroutine, which takes a slot in the arch semaphore before calling
+// handleJob and frees it when handleJob returns.
+func (q *BuildQueue) startHandler(ctx context.Context, arch string) {
 	for {
-		job := <-q.jobChans[arch]
-		go func() {
-			sem := q.archSemaphores[arch]
-			sem <- struct{}{}
-			defer func() { <-sem }()
-			q.handleJob(job)
-		}()
+		select {
+		case job := <-q.jobChans[arch]:
+			go func() {
+				sem := q.archSemaphores[arch]
+				sem <- struct{}{}
+				defer func() { <-sem }()
+				q.handleJob(job)
+			}()
+		case <-ctx.Done():
+			return
+		}
 	}
 }
 
diff --git a/pkg/server/queue_test.go b/pkg/server/queue_test.go
new file mode 100644
index 0000000000..1931309289
--- /dev/null
+++ b/pkg/server/queue_test.go
@@ -0,0 +1,177 @@
+package server
+
+import (
+	"bytes"
+	"context"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/replicate/cog/pkg/docker"
+	"github.com/replicate/cog/pkg/logger"
+	"github.com/replicate/cog/pkg/model"
+	"github.com/replicate/cog/pkg/serving"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// TestQueueBuild drives a BuildQueue with a mock image builder whose builds
+// block until the test releases them, and uses the length of the per-arch
+// semaphores to check that no more than cpuConcurrency CPU builds and
+// gpuConcurrency GPU builds run at once.
+func TestQueueBuild(t *testing.T) {
+	cpuConcurrency := 3
+	gpuConcurrency := 1
+	buildStartChans := map[string]chan int{
+		"cpu": make(chan int),
+		"gpu": make(chan int),
+	}
+	buildCompleteChans := map[string]chan int{
+		"cpu": make(chan int),
+		"gpu": make(chan int),
+	}
+	buildResultChans := map[string]chan *BuildResult{
+		"cpu": make(chan *BuildResult),
+		"gpu": make(chan *BuildResult),
+	}
+	counters := map[string]*int32{
+		"cpu": int32p(0),
+		"gpu": int32p(0),
+	}
+
+	config := &model.Config{
+		Model: "model.py:Model",
+		Environment: &model.Environment{
+			BuildRequiresGPU: true,
+			Architectures:    []string{"cpu", "gpu"},
+		},
+		Examples: []*model.Example{{
+			Input:  map[string]string{},
+			Output: "",
+		}},
+	}
+	err := config.ValidateAndCompleteConfig()
+	require.NoError(t, err)
+
+	testRunFunc := func(example *serving.Example) *serving.Result {
+		return &serving.Result{
+			Values: map[string]serving.ResultValue{
+				"output": {
+					Buffer:   bytes.NewBuffer([]byte("hello world")),
+					MimeType: "text/plain",
+				},
+			},
+			SetupTime:       100,
+			RunTime:         0.01,
+			UsedMemoryBytes: 5000,
+			UsedCPUSecs:     0.1,
+		}
+	}
+	servingPlatform := serving.NewMockServingPlatform(0, testRunFunc, map[string]*model.RunArgument{})
+
+	// Each mock build reports its per-arch sequence number on buildStartChans,
+	// then blocks until the test sends a value on buildCompleteChans.
+	imageBuildFunc := func(ctx context.Context, dir string, dockerfileContents string, name string, useGPU bool, logWriter logger.Logger) (tag string, err error) {
+		arch := "cpu"
+		if useGPU {
+			arch = "gpu"
+		}
+		counter := counters[arch]
+		value := atomic.AddInt32(counter, 1)
+		buildStartChans[arch] <- int(value)
+		completeCounter := <-buildCompleteChans[arch]
+		assert.Equal(t, int(value), completeCounter)
+		return "image-" + arch, nil
+	}
+	dockerImageBuilder := docker.NewMockImageBuilder(imageBuildFunc)
+
+	logWriter := logger.NewConsoleLogger()
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	queue := NewBuildQueue(servingPlatform, dockerImageBuilder, cpuConcurrency, gpuConcurrency)
+	queue.Start(ctx)
+
+	submitBuild := func(arch string) {
+		result, err := queue.Build(ctx, "", "name", "id", arch, config, logWriter)
+		// submitBuild runs on its own goroutine, so use assert rather than
+		// require: FailNow may only be called from the test goroutine.
+		assert.NoError(t, err)
+		buildResultChans[arch] <- result
+	}
+
+	go submitBuild("cpu")
+	require.Equal(t, 1, <-buildStartChans["cpu"])
+	require.Equal(t, 1, len(queue.archSemaphores["cpu"]))
+	require.Equal(t, 0, len(queue.archSemaphores["gpu"]))
+	buildCompleteChans["cpu"] <- 1
+
+	result := <-buildResultChans["cpu"]
+	require.NotNil(t, result)
+	// Allow the semaphore slot to be released first, as the later checks do.
+	time.Sleep(10 * time.Millisecond)
+	require.Equal(t, 0, len(queue.archSemaphores["cpu"]))
+	require.Equal(t, "image-cpu", result.image.URI)
+	require.Equal(t, "cpu", result.image.Arch)
+
+	go submitBuild("cpu")
+	require.Equal(t, 2, <-buildStartChans["cpu"])
+	go submitBuild("gpu")
+	require.Equal(t, 1, <-buildStartChans["gpu"])
+	go submitBuild("cpu")
+	require.Equal(t, 3, <-buildStartChans["cpu"])
+
+	require.Equal(t, 2, len(queue.archSemaphores["cpu"]))
+	require.Equal(t, 1, len(queue.archSemaphores["gpu"]))
+
+	go submitBuild("cpu")
+	require.Equal(t, 4, <-buildStartChans["cpu"])
+	require.Equal(t, 3, len(queue.archSemaphores["cpu"]))
+
+	// Every CPU and GPU slot is taken, so these two builds must wait.
+	go submitBuild("cpu")
+	go submitBuild("gpu")
+	time.Sleep(100 * time.Millisecond)
+	require.Equal(t, 3, len(queue.archSemaphores["cpu"]))
+	require.Equal(t, 1, len(queue.archSemaphores["gpu"]))
+
+	buildCompleteChans["cpu"] <- 2
+	<-buildResultChans["cpu"]
+
+	require.Equal(t, 5, <-buildStartChans["cpu"])
+	time.Sleep(10 * time.Millisecond)
+	require.Equal(t, 3, len(queue.archSemaphores["cpu"]))
+
+	buildCompleteChans["gpu"] <- 1
+	<-buildResultChans["gpu"]
+
+	require.Equal(t, 2, <-buildStartChans["gpu"])
+	time.Sleep(10 * time.Millisecond)
+	require.Equal(t, 1, len(queue.archSemaphores["gpu"]))
+
+	buildCompleteChans["cpu"] <- 3
+	<-buildResultChans["cpu"]
+
+	time.Sleep(10 * time.Millisecond)
+	require.Equal(t, 2, len(queue.archSemaphores["cpu"]))
+
+	buildCompleteChans["cpu"] <- 4
+	<-buildResultChans["cpu"]
+	time.Sleep(10 * time.Millisecond)
+	require.Equal(t, 1, len(queue.archSemaphores["cpu"]))
+
+	buildCompleteChans["gpu"] <- 2
+	<-buildResultChans["gpu"]
+	time.Sleep(10 * time.Millisecond)
+	require.Equal(t, 0, len(queue.archSemaphores["gpu"]))
+
+	buildCompleteChans["cpu"] <- 5
+	<-buildResultChans["cpu"]
+	time.Sleep(10 * time.Millisecond)
+	require.Equal(t, 0, len(queue.archSemaphores["cpu"]))
+}
+
+// int32p returns a pointer to a copy of x.
+func int32p(x int32) *int32 {
+	return &x
+}
diff --git a/pkg/server/server.go b/pkg/server/server.go
index b3e2305075..f23e568560 100644
--- a/pkg/server/server.go
+++ b/pkg/server/server.go
@@ -1,6 +1,7 @@
 package server
 
 import (
+	"context"
 	"fmt"
 	"net/http"
 	"net/http/pprof"
@@ -60,7 +61,7 @@ func NewServer(cpuConcurrency int, gpuConcurrency int, rawPostUploadHooks []stri
 }
 
 func (s *Server) Start(port int) error {
-	s.buildQueue.Start()
+	s.buildQueue.Start(context.Background())
 
 	router := mux.NewRouter()