From 237b07fd29b5bf1aeaf70223de5764aa82c4f27d Mon Sep 17 00:00:00 2001 From: Benson Wong Date: Tue, 13 May 2025 11:07:29 -0700 Subject: [PATCH] Change /unload to not wait for inflight requests (#125) Sometimes upstreams can accept HTTP but never respond, causing requests to build up waiting for a response. This can block Process.Stop() as that waits for inflight requests to finish. This change refactors the code to not wait when attempting to shut down the process. --- llama-swap.go | 2 +- proxy/process.go | 22 +++++++++++++++++++++- proxy/process_test.go | 21 +++++++++++++++++++++ proxy/processgroup.go | 13 +++++++------ proxy/processgroup_test.go | 4 ++-- proxy/proxymanager.go | 8 ++++---- proxy/proxymanager_test.go | 20 ++++++++++---------- 7 files changed, 66 insertions(+), 24 deletions(-) diff --git a/llama-swap.go b/llama-swap.go index ce9c8ea..28914fb 100644 --- a/llama-swap.go +++ b/llama-swap.go @@ -84,7 +84,7 @@ func main() { case newManager := <-reloadChan: log.Println("Config change detected, waiting for in-flight requests to complete...") // Stop old manager processes gracefully (this waits for in-flight requests) - currentManager.StopProcesses() + currentManager.StopProcesses(proxy.StopWaitForInflightRequest) // Now do a full shutdown to clear the process map currentManager.Shutdown() currentManager = newManager diff --git a/proxy/process.go b/proxy/process.go index 4f07fc7..7db2d9e 100644 --- a/proxy/process.go +++ b/proxy/process.go @@ -30,6 +30,13 @@ const ( StateShutdown ProcessState = ProcessState("shutdown") ) +type StopStrategy int + +const ( + StopImmediately StopStrategy = iota + StopWaitForInflightRequest +) + type Process struct { ID string config ModelConfig @@ -313,13 +320,25 @@ func (p *Process) start() error { } } +// Stop will wait for inflight requests to complete before stopping the process.
func (p *Process) Stop() { if !isValidTransition(p.CurrentState(), StateStopping) { return } // wait for any inflight requests before proceeding + p.proxyLogger.Debugf("<%s> Stop(): Waiting for inflight requests to complete", p.ID) p.inFlightRequests.Wait() + p.StopImmediately() +} + +// StopImmediately will transition the process to the stopping state and stop the process with a SIGTERM. +// If the process does not stop within the specified timeout, it will be forcefully stopped with a SIGKILL. +func (p *Process) StopImmediately() { + if !isValidTransition(p.CurrentState(), StateStopping) { + return + } + p.proxyLogger.Debugf("<%s> Stopping process", p.ID) // calling Stop() when state is invalid is a no-op @@ -338,7 +357,8 @@ func (p *Process) Stop() { // Shutdown is called when llama-swap is shutting down. It will give a little bit // of time for any inflight requests to complete before shutting down. If the Process -// is in the state of starting, it will cancel it and shut it down +// is in the state of starting, it will cancel it and shut it down. Once a process is in +// the StateShutdown state, it cannot be started again.
func (p *Process) Shutdown() { p.shutdownCancel() p.stopCommand(5 * time.Second) diff --git a/proxy/process_test.go b/proxy/process_test.go index f45a404..a715215 100644 --- a/proxy/process_test.go +++ b/proxy/process_test.go @@ -372,3 +372,24 @@ func TestProcess_ConcurrencyLimit(t *testing.T) { process.ProxyRequest(w, denied) assert.Equal(t, http.StatusTooManyRequests, w.Code) } + +func TestProcess_StopImmediately(t *testing.T) { + expectedMessage := "test_stop_immediate" + config := getTestSimpleResponderConfig(expectedMessage) + + process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger) + defer process.Stop() + + err := process.start() + assert.Nil(t, err) + assert.Equal(t, process.CurrentState(), StateReady) + go func() { + // slow, but will get killed by StopImmediately + req := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=1s", nil) + w := httptest.NewRecorder() + process.ProxyRequest(w, req) + }() + <-time.After(time.Millisecond) + process.StopImmediately() + assert.Equal(t, process.CurrentState(), StateStopped) +} diff --git a/proxy/processgroup.go b/proxy/processgroup.go index 464dded..4f10d0a 100644 --- a/proxy/processgroup.go +++ b/proxy/processgroup.go @@ -76,14 +76,10 @@ func (pg *ProcessGroup) HasMember(modelName string) bool { return slices.Contains(pg.config.Groups[pg.id].Members, modelName) } -func (pg *ProcessGroup) StopProcesses() { +func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) { pg.Lock() defer pg.Unlock() - pg.stopProcesses() -} -// stopProcesses stops all processes in the group -func (pg *ProcessGroup) stopProcesses() { if len(pg.processes) == 0 { return } @@ -94,7 +90,12 @@ func (pg *ProcessGroup) stopProcesses() { wg.Add(1) go func(process *Process) { defer wg.Done() - process.Stop() + switch strategy { + case StopImmediately: + process.StopImmediately() + default: + process.Stop() + } }(process) } wg.Wait() diff --git a/proxy/processgroup_test.go b/proxy/processgroup_test.go index
c6d3670..8a1ace8 100644 --- a/proxy/processgroup_test.go +++ b/proxy/processgroup_test.go @@ -46,7 +46,7 @@ func TestProcessGroup_HasMember(t *testing.T) { func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) { pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger) - defer pg.StopProcesses() + defer pg.StopProcesses(StopWaitForInflightRequest) tests := []string{"model1", "model2"} @@ -74,7 +74,7 @@ func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) { func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) { pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger) - defer pg.StopProcesses() + defer pg.StopProcesses(StopWaitForInflightRequest) tests := []string{"model3", "model4"} diff --git a/proxy/proxymanager.go b/proxy/proxymanager.go index 54fc8a7..deb017b 100644 --- a/proxy/proxymanager.go +++ b/proxy/proxymanager.go @@ -208,7 +208,7 @@ func (pm *ProxyManager) ServeHTTP(w http.ResponseWriter, r *http.Request) { // This is the public method safe for concurrent calls. // Unlike Shutdown, this method only stops the processes but doesn't perform // a complete shutdown, allowing for process replacement without full termination. 
-func (pm *ProxyManager) StopProcesses() { +func (pm *ProxyManager) StopProcesses(strategy StopStrategy) { pm.Lock() defer pm.Unlock() @@ -218,7 +218,7 @@ func (pm *ProxyManager) StopProcesses() { wg.Add(1) go func(processGroup *ProcessGroup) { defer wg.Done() - processGroup.stopProcesses() + processGroup.StopProcesses(strategy) }(processGroup) } @@ -260,7 +260,7 @@ func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, pm.proxyLogger.Debugf("Exclusive mode for group %s, stopping other process groups", processGroup.id) for groupId, otherGroup := range pm.processGroups { if groupId != processGroup.id && !otherGroup.persistent { - otherGroup.StopProcesses() + otherGroup.StopProcesses(StopWaitForInflightRequest) } } } @@ -504,7 +504,7 @@ func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, messag } func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) { - pm.StopProcesses() + pm.StopProcesses(StopImmediately) c.String(http.StatusOK, "OK") } diff --git a/proxy/proxymanager_test.go b/proxy/proxymanager_test.go index 242e4a7..b0379a7 100644 --- a/proxy/proxymanager_test.go +++ b/proxy/proxymanager_test.go @@ -27,7 +27,7 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) { }) proxy := New(config) - defer proxy.StopProcesses() + defer proxy.StopProcesses(StopWaitForInflightRequest) for _, modelName := range []string{"model1", "model2"} { reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName) @@ -63,7 +63,7 @@ func TestProxyManager_SwapMultiProcess(t *testing.T) { }) proxy := New(config) - defer proxy.StopProcesses() + defer proxy.StopProcesses(StopWaitForInflightRequest) tests := []string{"model1", "model2"} for _, requestedModel := range tests { @@ -105,7 +105,7 @@ func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) { }) proxy := New(config) - defer proxy.StopProcesses() + defer proxy.StopProcesses(StopWaitForInflightRequest) // make requests to load all models, loading model1 should not affect 
model2 tests := []string{"model2", "model1"} @@ -141,7 +141,7 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) { }) proxy := New(config) - defer proxy.StopProcesses() + defer proxy.StopProcesses(StopWaitForInflightRequest) results := map[string]string{} @@ -352,7 +352,7 @@ func TestProxyManager_RunningEndpoint(t *testing.T) { // Create proxy once for all tests proxy := New(config) - defer proxy.StopProcesses() + defer proxy.StopProcesses(StopWaitForInflightRequest) t.Run("no models loaded", func(t *testing.T) { req := httptest.NewRequest("GET", "/running", nil) @@ -407,7 +407,7 @@ func TestProxyManager_AudioTranscriptionHandler(t *testing.T) { }) proxy := New(config) - defer proxy.StopProcesses() + defer proxy.StopProcesses(StopWaitForInflightRequest) // Create a buffer with multipart form data var b bytes.Buffer @@ -461,7 +461,7 @@ func TestProxyManager_UseModelName(t *testing.T) { }) proxy := New(config) - defer proxy.StopProcesses() + defer proxy.StopProcesses(StopWaitForInflightRequest) requestedModel := "model1" @@ -557,7 +557,7 @@ func TestProxyManager_CORSOptionsHandler(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { proxy := New(config) - defer proxy.StopProcesses() + defer proxy.StopProcesses(StopWaitForInflightRequest) req := httptest.NewRequest(tt.method, "/v1/chat/completions", nil) for k, v := range tt.requestHeaders { @@ -586,7 +586,7 @@ func TestProxyManager_Upstream(t *testing.T) { }) proxy := New(config) - defer proxy.StopProcesses() + defer proxy.StopProcesses(StopWaitForInflightRequest) req := httptest.NewRequest("GET", "/upstream/model1/test", nil) rec := httptest.NewRecorder() proxy.ServeHTTP(rec, req) @@ -604,7 +604,7 @@ func TestProxyManager_ChatContentLength(t *testing.T) { }) proxy := New(config) - defer proxy.StopProcesses() + defer proxy.StopProcesses(StopWaitForInflightRequest) reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, 
"model1") req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))