Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions pkg/docker/mock.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package docker

import (
"context"

"github.com/replicate/cog/pkg/logger"
)

type MockBuildFunc func(ctx context.Context, dir string, dockerfileContents string, name string, useGPU bool, logWriter logger.Logger) (tag string, err error)

type MockImageBuilder struct {
buildFunc MockBuildFunc
}

func NewMockImageBuilder(buildFunc MockBuildFunc) *MockImageBuilder {
return &MockImageBuilder{buildFunc}
}

func (m *MockImageBuilder) Build(ctx context.Context, dir string, dockerfileContents string, name string, useGPU bool, logWriter logger.Logger) (tag string, err error) {
return m.buildFunc(ctx, dir, dockerfileContents, name, useGPU, logWriter)
}

func (m *MockImageBuilder) Push(ctx context.Context, tag string, logWriter logger.Logger) error {
return nil
}
24 changes: 14 additions & 10 deletions pkg/server/queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,21 +71,25 @@ func NewBuildQueue(servingPlatform serving.Platform, dockerImageBuilder docker.I
}
}

func (q *BuildQueue) Start() {
func (q *BuildQueue) Start(ctx context.Context) {
for _, arch := range []string{"cpu", "gpu"} {
go q.startHandler(arch)
go q.startHandler(ctx, arch)
}
}

func (q *BuildQueue) startHandler(arch string) {
func (q *BuildQueue) startHandler(ctx context.Context, arch string) {
for {
job := <-q.jobChans[arch]
go func() {
sem := q.archSemaphores[arch]
sem <- struct{}{}
defer func() { <-sem }()
q.handleJob(job)
}()
select {
case job := <-q.jobChans[arch]:
go func() {
sem := q.archSemaphores[arch]
sem <- struct{}{}
defer func() { <-sem }()
q.handleJob(job)
}()
case <-ctx.Done():
return
}
}
}

Expand Down
165 changes: 165 additions & 0 deletions pkg/server/queue_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
package server

import (
"bytes"
"context"
"sync/atomic"
"testing"
"time"

"github.com/replicate/cog/pkg/docker"
"github.com/replicate/cog/pkg/logger"
"github.com/replicate/cog/pkg/model"
"github.com/replicate/cog/pkg/serving"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/assert"
)

func TestQueueBuild(t *testing.T) {
cpuConcurrency := 3
gpuConcurrency := 1
buildStartChans := map[string]chan int{
"cpu": make(chan int),
"gpu": make(chan int),
}
buildCompleteChans := map[string]chan int{
"cpu": make(chan int),
"gpu": make(chan int),
}
buildResultChans := map[string]chan *BuildResult{
"cpu": make(chan *BuildResult),
"gpu": make(chan *BuildResult),
}
counters := map[string]*int32{
"cpu": int32p(0),
"gpu": int32p(0),
}

config := &model.Config{
Model: "model.py:Model",
Environment: &model.Environment{
BuildRequiresGPU: true,
Architectures: []string{"cpu", "gpu"},
},
Examples: []*model.Example{{
Input: map[string]string{},
Output: "",
}},
}
err := config.ValidateAndCompleteConfig()
require.NoError(t, err)

testRunFunc := func(example *serving.Example) *serving.Result {
return &serving.Result{
Values: map[string]serving.ResultValue{
"output": {
Buffer: bytes.NewBuffer([]byte("hello world")),
MimeType: "text/plain",
},
},
SetupTime: 100,
RunTime: 0.01,
UsedMemoryBytes: 5000,
UsedCPUSecs: 0.1,
}
}
servingPlatform := serving.NewMockServingPlatform(0, testRunFunc, map[string]*model.RunArgument{})

imageBuildFunc := func(ctx context.Context, dir string, dockerfileContents string, name string, useGPU bool, logWriter logger.Logger) (tag string, err error) {
arch := "cpu"
if useGPU {
arch = "gpu"
}
counter := counters[arch]
value := atomic.AddInt32(counter, 1)
buildStartChans[arch] <- int(value)
completeCounter := <-buildCompleteChans[arch]
assert.Equal(t, int(value), completeCounter)
return "image-" + arch, nil
}
dockerImageBuilder := docker.NewMockImageBuilder(imageBuildFunc)

logWriter := logger.NewConsoleLogger()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
queue := NewBuildQueue(servingPlatform, dockerImageBuilder, cpuConcurrency, gpuConcurrency)
queue.Start(ctx)

submitBuild := func(arch string) {
result, err := queue.Build(ctx, "", "name", "id", arch, config, logWriter)
require.NoError(t, err)
buildResultChans[arch] <- result
}

go submitBuild("cpu")
require.Equal(t, <-buildStartChans["cpu"], 1)
require.Equal(t, 1, len(queue.archSemaphores["cpu"]))
require.Equal(t, 0, len(queue.archSemaphores["gpu"]))
buildCompleteChans["cpu"] <- 1

result := <-buildResultChans["cpu"]
require.Equal(t, 0, len(queue.archSemaphores["cpu"]))
require.Equal(t, "image-cpu", result.image.URI)
require.Equal(t, "cpu", result.image.Arch)

go submitBuild("cpu")
require.Equal(t, <-buildStartChans["cpu"], 2)
go submitBuild("gpu")
require.Equal(t, <-buildStartChans["gpu"], 1)
go submitBuild("cpu")
require.Equal(t, <-buildStartChans["cpu"], 3)

require.Equal(t, 2, len(queue.archSemaphores["cpu"]))
require.Equal(t, 1, len(queue.archSemaphores["gpu"]))

go submitBuild("cpu")
require.Equal(t, <-buildStartChans["cpu"], 4)
require.Equal(t, 3, len(queue.archSemaphores["cpu"]))

// block
go submitBuild("cpu")
go submitBuild("gpu")
time.Sleep(100 * time.Millisecond)
require.Equal(t, 3, len(queue.archSemaphores["cpu"]))
require.Equal(t, 1, len(queue.archSemaphores["gpu"]))

buildCompleteChans["cpu"] <- 2
<-buildResultChans["cpu"]

require.Equal(t, <-buildStartChans["cpu"], 5)
time.Sleep(10 * time.Millisecond)
require.Equal(t, 3, len(queue.archSemaphores["cpu"]))

buildCompleteChans["gpu"] <- 1
<-buildResultChans["gpu"]

require.Equal(t, <-buildStartChans["gpu"], 2)
time.Sleep(10 * time.Millisecond)
require.Equal(t, 1, len(queue.archSemaphores["gpu"]))

buildCompleteChans["cpu"] <- 3
<-buildResultChans["cpu"]

time.Sleep(10 * time.Millisecond)
require.Equal(t, 2, len(queue.archSemaphores["cpu"]))

buildCompleteChans["cpu"] <- 4
<-buildResultChans["cpu"]
time.Sleep(10 * time.Millisecond)
require.Equal(t, 1, len(queue.archSemaphores["cpu"]))

buildCompleteChans["gpu"] <- 2
<-buildResultChans["gpu"]
time.Sleep(10 * time.Millisecond)
require.Equal(t, 0, len(queue.archSemaphores["gpu"]))

buildCompleteChans["cpu"] <- 5
<-buildResultChans["cpu"]
time.Sleep(10 * time.Millisecond)
require.Equal(t, 0, len(queue.archSemaphores["cpu"]))
}

func int32p(x int32) *int32 {
return &x
}
3 changes: 2 additions & 1 deletion pkg/server/server.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package server

import (
"context"
"fmt"
"net/http"
"net/http/pprof"
Expand Down Expand Up @@ -60,7 +61,7 @@ func NewServer(cpuConcurrency int, gpuConcurrency int, rawPostUploadHooks []stri
}

func (s *Server) Start(port int) error {
s.buildQueue.Start()
s.buildQueue.Start(context.Background())

router := mux.NewRouter()

Expand Down