Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion llama-swap.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func main() {
case newManager := <-reloadChan:
log.Println("Config change detected, waiting for in-flight requests to complete...")
// Stop old manager processes gracefully (this waits for in-flight requests)
currentManager.StopProcesses()
currentManager.StopProcesses(proxy.StopWaitForInflightRequest)
// Now do a full shutdown to clear the process map
currentManager.Shutdown()
currentManager = newManager
Expand Down
22 changes: 21 additions & 1 deletion proxy/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ const (
StateShutdown ProcessState = ProcessState("shutdown")
)

type StopStrategy int

const (
StopImmediately StopStrategy = iota
StopWaitForInflightRequest
)

type Process struct {
ID string
config ModelConfig
Expand Down Expand Up @@ -313,13 +320,25 @@ func (p *Process) start() error {
}
}

// Stop will wait for inflight requests to complete before stopping the process.
func (p *Process) Stop() {
if !isValidTransition(p.CurrentState(), StateStopping) {
return
}

// wait for any inflight requests before proceeding
p.proxyLogger.Debugf("<%s> Stop(): Waiting for inflight requests to complete", p.ID)
p.inFlightRequests.Wait()
p.StopImmediately()
}

// StopImmediately will transition the process to the stopping state and stop the process with a SIGTERM.
// If the process does not stop within the specified timeout, it will be forcefully stopped with a SIGKILL.
func (p *Process) StopImmediately() {
if !isValidTransition(p.CurrentState(), StateStopping) {
return
}

p.proxyLogger.Debugf("<%s> Stopping process", p.ID)

// calling Stop() when state is invalid is a no-op
Expand All @@ -338,7 +357,8 @@ func (p *Process) Stop() {

// Shutdown is called when llama-swap is shutting down. It will give a little bit
// of time for any inflight requests to complete before shutting down. If the Process
// is in the state of starting, it will cancel it and shut it down
// is in the state of starting, it will cancel it and shut it down. Once a process is in
// the StateShutdown state, it can not be started again.
func (p *Process) Shutdown() {
p.shutdownCancel()
p.stopCommand(5 * time.Second)
Expand Down
21 changes: 21 additions & 0 deletions proxy/process_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -372,3 +372,24 @@ func TestProcess_ConcurrencyLimit(t *testing.T) {
process.ProxyRequest(w, denied)
assert.Equal(t, http.StatusTooManyRequests, w.Code)
}

func TestProcess_StopImmediately(t *testing.T) {
expectedMessage := "test_stop_immediate"
config := getTestSimpleResponderConfig(expectedMessage)

process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger)
defer process.Stop()

err := process.start()
assert.Nil(t, err)
assert.Equal(t, process.CurrentState(), StateReady)
go func() {
// slow, but will get killed by StopImmediate
req := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=1s", nil)
w := httptest.NewRecorder()
process.ProxyRequest(w, req)
}()
<-time.After(time.Millisecond)
process.StopImmediately()
assert.Equal(t, process.CurrentState(), StateStopped)
}
13 changes: 7 additions & 6 deletions proxy/processgroup.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,10 @@ func (pg *ProcessGroup) HasMember(modelName string) bool {
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
}

func (pg *ProcessGroup) StopProcesses() {
func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) {
pg.Lock()
defer pg.Unlock()
pg.stopProcesses()
}

// stopProcesses stops all processes in the group
func (pg *ProcessGroup) stopProcesses() {
if len(pg.processes) == 0 {
return
}
Expand All @@ -94,7 +90,12 @@ func (pg *ProcessGroup) stopProcesses() {
wg.Add(1)
go func(process *Process) {
defer wg.Done()
process.Stop()
switch strategy {
case StopImmediately:
process.StopImmediately()
default:
process.Stop()
}
}(process)
}
wg.Wait()
Expand Down
4 changes: 2 additions & 2 deletions proxy/processgroup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func TestProcessGroup_HasMember(t *testing.T) {

func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) {
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
defer pg.StopProcesses()
defer pg.StopProcesses(StopWaitForInflightRequest)

tests := []string{"model1", "model2"}

Expand Down Expand Up @@ -74,7 +74,7 @@ func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) {

func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger)
defer pg.StopProcesses()
defer pg.StopProcesses(StopWaitForInflightRequest)

tests := []string{"model3", "model4"}

Expand Down
8 changes: 4 additions & 4 deletions proxy/proxymanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ func (pm *ProxyManager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// This is the public method safe for concurrent calls.
// Unlike Shutdown, this method only stops the processes but doesn't perform
// a complete shutdown, allowing for process replacement without full termination.
func (pm *ProxyManager) StopProcesses() {
func (pm *ProxyManager) StopProcesses(strategy StopStrategy) {
pm.Lock()
defer pm.Unlock()

Expand All @@ -218,7 +218,7 @@ func (pm *ProxyManager) StopProcesses() {
wg.Add(1)
go func(processGroup *ProcessGroup) {
defer wg.Done()
processGroup.stopProcesses()
processGroup.StopProcesses(strategy)
}(processGroup)
}

Expand Down Expand Up @@ -260,7 +260,7 @@ func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup,
pm.proxyLogger.Debugf("Exclusive mode for group %s, stopping other process groups", processGroup.id)
for groupId, otherGroup := range pm.processGroups {
if groupId != processGroup.id && !otherGroup.persistent {
otherGroup.StopProcesses()
otherGroup.StopProcesses(StopWaitForInflightRequest)
}
}
}
Expand Down Expand Up @@ -504,7 +504,7 @@ func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, messag
}

func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
pm.StopProcesses()
pm.StopProcesses(StopImmediately)
c.String(http.StatusOK, "OK")
}

Expand Down
20 changes: 10 additions & 10 deletions proxy/proxymanager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
})

proxy := New(config)
defer proxy.StopProcesses()
defer proxy.StopProcesses(StopWaitForInflightRequest)

for _, modelName := range []string{"model1", "model2"} {
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
Expand Down Expand Up @@ -63,7 +63,7 @@ func TestProxyManager_SwapMultiProcess(t *testing.T) {
})

proxy := New(config)
defer proxy.StopProcesses()
defer proxy.StopProcesses(StopWaitForInflightRequest)

tests := []string{"model1", "model2"}
for _, requestedModel := range tests {
Expand Down Expand Up @@ -105,7 +105,7 @@ func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) {
})

proxy := New(config)
defer proxy.StopProcesses()
defer proxy.StopProcesses(StopWaitForInflightRequest)

// make requests to load all models, loading model1 should not affect model2
tests := []string{"model2", "model1"}
Expand Down Expand Up @@ -141,7 +141,7 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
})

proxy := New(config)
defer proxy.StopProcesses()
defer proxy.StopProcesses(StopWaitForInflightRequest)

results := map[string]string{}

Expand Down Expand Up @@ -352,7 +352,7 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {

// Create proxy once for all tests
proxy := New(config)
defer proxy.StopProcesses()
defer proxy.StopProcesses(StopWaitForInflightRequest)

t.Run("no models loaded", func(t *testing.T) {
req := httptest.NewRequest("GET", "/running", nil)
Expand Down Expand Up @@ -407,7 +407,7 @@ func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
})

proxy := New(config)
defer proxy.StopProcesses()
defer proxy.StopProcesses(StopWaitForInflightRequest)

// Create a buffer with multipart form data
var b bytes.Buffer
Expand Down Expand Up @@ -461,7 +461,7 @@ func TestProxyManager_UseModelName(t *testing.T) {
})

proxy := New(config)
defer proxy.StopProcesses()
defer proxy.StopProcesses(StopWaitForInflightRequest)

requestedModel := "model1"

Expand Down Expand Up @@ -557,7 +557,7 @@ func TestProxyManager_CORSOptionsHandler(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
proxy := New(config)
defer proxy.StopProcesses()
defer proxy.StopProcesses(StopWaitForInflightRequest)

req := httptest.NewRequest(tt.method, "/v1/chat/completions", nil)
for k, v := range tt.requestHeaders {
Expand Down Expand Up @@ -586,7 +586,7 @@ func TestProxyManager_Upstream(t *testing.T) {
})

proxy := New(config)
defer proxy.StopProcesses()
defer proxy.StopProcesses(StopWaitForInflightRequest)
req := httptest.NewRequest("GET", "/upstream/model1/test", nil)
rec := httptest.NewRecorder()
proxy.ServeHTTP(rec, req)
Expand All @@ -604,7 +604,7 @@ func TestProxyManager_ChatContentLength(t *testing.T) {
})

proxy := New(config)
defer proxy.StopProcesses()
defer proxy.StopProcesses(StopWaitForInflightRequest)

reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1")
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
Expand Down