Skip to content

Commit 519c3a4

Browse files
authored
Change /unload to not wait for inflight requests (#125)
Sometimes an upstream accepts an HTTP request but never responds, causing requests to build up while waiting for a response. This can block Process.Stop(), because it waits for inflight requests to finish. This change refactors the code so that it does not wait when attempting to shut down the process.
1 parent 9dc4bcb commit 519c3a4

File tree

7 files changed

+66
-24
lines changed

7 files changed

+66
-24
lines changed

llama-swap.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ func main() {
8484
case newManager := <-reloadChan:
8585
log.Println("Config change detected, waiting for in-flight requests to complete...")
8686
// Stop old manager processes gracefully (this waits for in-flight requests)
87-
currentManager.StopProcesses()
87+
currentManager.StopProcesses(proxy.StopWaitForInflightRequest)
8888
// Now do a full shutdown to clear the process map
8989
currentManager.Shutdown()
9090
currentManager = newManager

proxy/process.go

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,13 @@ const (
3030
StateShutdown ProcessState = ProcessState("shutdown")
3131
)
3232

33+
type StopStrategy int
34+
35+
const (
36+
StopImmediately StopStrategy = iota
37+
StopWaitForInflightRequest
38+
)
39+
3340
type Process struct {
3441
ID string
3542
config ModelConfig
@@ -313,13 +320,25 @@ func (p *Process) start() error {
313320
}
314321
}
315322

323+
// Stop will wait for inflight requests to complete before stopping the process.
316324
func (p *Process) Stop() {
317325
if !isValidTransition(p.CurrentState(), StateStopping) {
318326
return
319327
}
320328

321329
// wait for any inflight requests before proceeding
330+
p.proxyLogger.Debugf("<%s> Stop(): Waiting for inflight requests to complete", p.ID)
322331
p.inFlightRequests.Wait()
332+
p.StopImmediately()
333+
}
334+
335+
// StopImmediately will transition the process to the stopping state and stop the process with a SIGTERM.
336+
// If the process does not stop within the specified timeout, it will be forcefully stopped with a SIGKILL.
337+
func (p *Process) StopImmediately() {
338+
if !isValidTransition(p.CurrentState(), StateStopping) {
339+
return
340+
}
341+
323342
p.proxyLogger.Debugf("<%s> Stopping process", p.ID)
324343

325344
// calling Stop() when state is invalid is a no-op
@@ -338,7 +357,8 @@ func (p *Process) Stop() {
338357

339358
// Shutdown is called when llama-swap is shutting down. It will give a little bit
340359
// of time for any inflight requests to complete before shutting down. If the Process
341-
// is in the state of starting, it will cancel it and shut it down
360+
// is in the state of starting, it will cancel it and shut it down. Once a process is in
361+
// the StateShutdown state, it cannot be started again.
342362
func (p *Process) Shutdown() {
343363
p.shutdownCancel()
344364
p.stopCommand(5 * time.Second)

proxy/process_test.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,3 +372,24 @@ func TestProcess_ConcurrencyLimit(t *testing.T) {
372372
process.ProxyRequest(w, denied)
373373
assert.Equal(t, http.StatusTooManyRequests, w.Code)
374374
}
375+
376+
func TestProcess_StopImmediately(t *testing.T) {
377+
expectedMessage := "test_stop_immediate"
378+
config := getTestSimpleResponderConfig(expectedMessage)
379+
380+
process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger)
381+
defer process.Stop()
382+
383+
err := process.start()
384+
assert.Nil(t, err)
385+
assert.Equal(t, process.CurrentState(), StateReady)
386+
go func() {
387+
// slow, but will get killed by StopImmediately
388+
req := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=1s", nil)
389+
w := httptest.NewRecorder()
390+
process.ProxyRequest(w, req)
391+
}()
392+
<-time.After(time.Millisecond)
393+
process.StopImmediately()
394+
assert.Equal(t, process.CurrentState(), StateStopped)
395+
}

proxy/processgroup.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,14 +76,10 @@ func (pg *ProcessGroup) HasMember(modelName string) bool {
7676
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
7777
}
7878

79-
func (pg *ProcessGroup) StopProcesses() {
79+
func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) {
8080
pg.Lock()
8181
defer pg.Unlock()
82-
pg.stopProcesses()
83-
}
8482

85-
// stopProcesses stops all processes in the group
86-
func (pg *ProcessGroup) stopProcesses() {
8783
if len(pg.processes) == 0 {
8884
return
8985
}
@@ -94,7 +90,12 @@ func (pg *ProcessGroup) stopProcesses() {
9490
wg.Add(1)
9591
go func(process *Process) {
9692
defer wg.Done()
97-
process.Stop()
93+
switch strategy {
94+
case StopImmediately:
95+
process.StopImmediately()
96+
default:
97+
process.Stop()
98+
}
9899
}(process)
99100
}
100101
wg.Wait()

proxy/processgroup_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ func TestProcessGroup_HasMember(t *testing.T) {
4646

4747
func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) {
4848
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
49-
defer pg.StopProcesses()
49+
defer pg.StopProcesses(StopWaitForInflightRequest)
5050

5151
tests := []string{"model1", "model2"}
5252

@@ -74,7 +74,7 @@ func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) {
7474

7575
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
7676
pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger)
77-
defer pg.StopProcesses()
77+
defer pg.StopProcesses(StopWaitForInflightRequest)
7878

7979
tests := []string{"model3", "model4"}
8080

proxy/proxymanager.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ func (pm *ProxyManager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
208208
// This is the public method safe for concurrent calls.
209209
// Unlike Shutdown, this method only stops the processes but doesn't perform
210210
// a complete shutdown, allowing for process replacement without full termination.
211-
func (pm *ProxyManager) StopProcesses() {
211+
func (pm *ProxyManager) StopProcesses(strategy StopStrategy) {
212212
pm.Lock()
213213
defer pm.Unlock()
214214

@@ -218,7 +218,7 @@ func (pm *ProxyManager) StopProcesses() {
218218
wg.Add(1)
219219
go func(processGroup *ProcessGroup) {
220220
defer wg.Done()
221-
processGroup.stopProcesses()
221+
processGroup.StopProcesses(strategy)
222222
}(processGroup)
223223
}
224224

@@ -260,7 +260,7 @@ func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup,
260260
pm.proxyLogger.Debugf("Exclusive mode for group %s, stopping other process groups", processGroup.id)
261261
for groupId, otherGroup := range pm.processGroups {
262262
if groupId != processGroup.id && !otherGroup.persistent {
263-
otherGroup.StopProcesses()
263+
otherGroup.StopProcesses(StopWaitForInflightRequest)
264264
}
265265
}
266266
}
@@ -504,7 +504,7 @@ func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, messag
504504
}
505505

506506
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
507-
pm.StopProcesses()
507+
pm.StopProcesses(StopImmediately)
508508
c.String(http.StatusOK, "OK")
509509
}
510510

proxy/proxymanager_test.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
2727
})
2828

2929
proxy := New(config)
30-
defer proxy.StopProcesses()
30+
defer proxy.StopProcesses(StopWaitForInflightRequest)
3131

3232
for _, modelName := range []string{"model1", "model2"} {
3333
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
@@ -63,7 +63,7 @@ func TestProxyManager_SwapMultiProcess(t *testing.T) {
6363
})
6464

6565
proxy := New(config)
66-
defer proxy.StopProcesses()
66+
defer proxy.StopProcesses(StopWaitForInflightRequest)
6767

6868
tests := []string{"model1", "model2"}
6969
for _, requestedModel := range tests {
@@ -105,7 +105,7 @@ func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) {
105105
})
106106

107107
proxy := New(config)
108-
defer proxy.StopProcesses()
108+
defer proxy.StopProcesses(StopWaitForInflightRequest)
109109

110110
// make requests to load all models, loading model1 should not affect model2
111111
tests := []string{"model2", "model1"}
@@ -141,7 +141,7 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
141141
})
142142

143143
proxy := New(config)
144-
defer proxy.StopProcesses()
144+
defer proxy.StopProcesses(StopWaitForInflightRequest)
145145

146146
results := map[string]string{}
147147

@@ -352,7 +352,7 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
352352

353353
// Create proxy once for all tests
354354
proxy := New(config)
355-
defer proxy.StopProcesses()
355+
defer proxy.StopProcesses(StopWaitForInflightRequest)
356356

357357
t.Run("no models loaded", func(t *testing.T) {
358358
req := httptest.NewRequest("GET", "/running", nil)
@@ -407,7 +407,7 @@ func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
407407
})
408408

409409
proxy := New(config)
410-
defer proxy.StopProcesses()
410+
defer proxy.StopProcesses(StopWaitForInflightRequest)
411411

412412
// Create a buffer with multipart form data
413413
var b bytes.Buffer
@@ -461,7 +461,7 @@ func TestProxyManager_UseModelName(t *testing.T) {
461461
})
462462

463463
proxy := New(config)
464-
defer proxy.StopProcesses()
464+
defer proxy.StopProcesses(StopWaitForInflightRequest)
465465

466466
requestedModel := "model1"
467467

@@ -557,7 +557,7 @@ func TestProxyManager_CORSOptionsHandler(t *testing.T) {
557557
for _, tt := range tests {
558558
t.Run(tt.name, func(t *testing.T) {
559559
proxy := New(config)
560-
defer proxy.StopProcesses()
560+
defer proxy.StopProcesses(StopWaitForInflightRequest)
561561

562562
req := httptest.NewRequest(tt.method, "/v1/chat/completions", nil)
563563
for k, v := range tt.requestHeaders {
@@ -586,7 +586,7 @@ func TestProxyManager_Upstream(t *testing.T) {
586586
})
587587

588588
proxy := New(config)
589-
defer proxy.StopProcesses()
589+
defer proxy.StopProcesses(StopWaitForInflightRequest)
590590
req := httptest.NewRequest("GET", "/upstream/model1/test", nil)
591591
rec := httptest.NewRecorder()
592592
proxy.ServeHTTP(rec, req)
@@ -604,7 +604,7 @@ func TestProxyManager_ChatContentLength(t *testing.T) {
604604
})
605605

606606
proxy := New(config)
607-
defer proxy.StopProcesses()
607+
defer proxy.StopProcesses(StopWaitForInflightRequest)
608608

609609
reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1")
610610
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))

0 commit comments

Comments
 (0)