Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ test: proxy/ui_dist/placeholder.txt

# for CI - full test (takes longer)
test-all: proxy/ui_dist/placeholder.txt
go test -count=1 ./proxy/...
go test -race -count=1 ./proxy/...

ui/node_modules:
cd ui && npm install
Expand Down
77 changes: 63 additions & 14 deletions proxy/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"

Expand Down Expand Up @@ -44,6 +45,7 @@ type Process struct {
cmd *exec.Cmd

// PR #155 called to cancel the upstream process
cmdMutex sync.RWMutex
cancelUpstream context.CancelFunc

// closed when command exits
Expand All @@ -55,12 +57,14 @@ type Process struct {
healthCheckTimeout int
healthCheckLoopInterval time.Duration

lastRequestHandled time.Time
lastRequestHandledMutex sync.RWMutex
lastRequestHandled time.Time

stateMutex sync.RWMutex
state ProcessState

inFlightRequests sync.WaitGroup
inFlightRequests sync.WaitGroup
inFlightRequestsCount atomic.Int32

// used to block on multiple start() calls
waitStarting sync.WaitGroup
Expand Down Expand Up @@ -107,6 +111,20 @@ func (p *Process) LogMonitor() *LogMonitor {
return p.processLogger
}

// setLastRequestHandled sets the last request handled time in a thread-safe manner.
func (p *Process) setLastRequestHandled(t time.Time) {
p.lastRequestHandledMutex.Lock()
defer p.lastRequestHandledMutex.Unlock()
p.lastRequestHandled = t
}

// getLastRequestHandled gets the last request handled time in a thread-safe manner.
func (p *Process) getLastRequestHandled() time.Time {
p.lastRequestHandledMutex.RLock()
defer p.lastRequestHandledMutex.RUnlock()
return p.lastRequestHandled
}

// custom error types for swapping state
var (
ErrExpectedStateMismatch = errors.New("expected state mismatch")
Expand All @@ -130,6 +148,13 @@ func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState,
}

p.state = newState

// Atomically increment waitStarting when entering StateStarting
// This ensures any thread that sees StateStarting will also see the WaitGroup counter incremented
if newState == StateStarting {
p.waitStarting.Add(1)
}

p.proxyLogger.Debugf("<%s> swapState() State transitioned from %s to %s", p.ID, expectedState, newState)
event.Emit(ProcessStateChangeEvent{ProcessName: p.ID, NewState: newState, OldState: expectedState})
return p.state, nil
Expand Down Expand Up @@ -158,6 +183,15 @@ func (p *Process) CurrentState() ProcessState {
return p.state
}

// forceState forces the process state to the new state with mutex protection.
// This should only be used in exceptional cases where the normal state transition
// validation via swapState() cannot be used.
func (p *Process) forceState(newState ProcessState) {
p.stateMutex.Lock()
defer p.stateMutex.Unlock()
p.state = newState
}

// start starts the upstream command, checks the health endpoint, and sets the state to Ready
// it is a private method because starting is automatic but stopping can be called
// at any time.
Expand Down Expand Up @@ -191,7 +225,7 @@ func (p *Process) start() error {
}
}

p.waitStarting.Add(1)
// waitStarting.Add(1) is now called atomically in swapState() when transitioning to StateStarting
defer p.waitStarting.Done()
cmdContext, ctxCancelUpstream := context.WithCancel(context.Background())

Expand All @@ -201,8 +235,11 @@ func (p *Process) start() error {
p.cmd.Env = append(p.cmd.Environ(), p.config.Env...)
p.cmd.Cancel = p.cmdStopUpstreamProcess
p.cmd.WaitDelay = p.gracefulStopTimeout

p.cmdMutex.Lock()
p.cancelUpstream = ctxCancelUpstream
p.cmdWaitChan = make(chan struct{})
p.cmdMutex.Unlock()

p.failedStartCount++ // this will be reset to zero when the process has successfully started

Expand All @@ -212,7 +249,7 @@ func (p *Process) start() error {
// Set process state to failed
if err != nil {
if curState, swapErr := p.swapState(StateStarting, StateStopped); swapErr != nil {
p.state = StateStopped // force it into a stopped state
p.forceState(StateStopped) // force it into a stopped state
return fmt.Errorf(
"failed to start command '%s' and state swap failed. command error: %v, current state: %v, state swap error: %v",
strings.Join(args, " "), err, curState, swapErr,
Expand Down Expand Up @@ -285,10 +322,12 @@ func (p *Process) start() error {
return
}

// wait for all inflight requests to complete and ticker
p.inFlightRequests.Wait()
// skip the TTL check if there are inflight requests
if p.inFlightRequestsCount.Load() != 0 {
continue
}

if time.Since(p.lastRequestHandled) > maxDuration {
if time.Since(p.getLastRequestHandled()) > maxDuration {
p.proxyLogger.Infof("<%s> Unloading model, TTL of %ds reached", p.ID, p.config.UnloadAfter)
p.Stop()
return
Expand Down Expand Up @@ -344,7 +383,7 @@ func (p *Process) Shutdown() {

p.stopCommand()
// just force it to this state since there is no recovery from shutdown
p.state = StateShutdown
p.forceState(StateShutdown)
}

// stopCommand will send a SIGTERM to the process and wait for it to exit.
Expand All @@ -355,13 +394,18 @@ func (p *Process) stopCommand() {
p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime))
}()

if p.cancelUpstream == nil {
p.cmdMutex.RLock()
cancelUpstream := p.cancelUpstream
cmdWaitChan := p.cmdWaitChan
p.cmdMutex.RUnlock()

if cancelUpstream == nil {
p.proxyLogger.Errorf("<%s> stopCommand has a nil p.cancelUpstream()", p.ID)
return
}

p.cancelUpstream()
<-p.cmdWaitChan
cancelUpstream()
<-cmdWaitChan
}

func (p *Process) checkHealthEndpoint(healthURL string) error {
Expand Down Expand Up @@ -418,8 +462,10 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
}

p.inFlightRequests.Add(1)
p.inFlightRequestsCount.Add(1)
defer func() {
p.lastRequestHandled = time.Now()
p.setLastRequestHandled(time.Now())
p.inFlightRequestsCount.Add(-1)
p.inFlightRequests.Done()
}()

Expand Down Expand Up @@ -519,13 +565,16 @@ func (p *Process) waitForCmd() {
case StateStopping:
if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
p.proxyLogger.Errorf("<%s> Process exited but could not swap to StateStopped. curState=%s, err: %v", p.ID, curState, err)
p.state = StateStopped
p.forceState(StateStopped)
}
default:
p.proxyLogger.Infof("<%s> process exited but not StateStopping, current state: %s", p.ID, currentState)
p.state = StateStopped // force it to be in this state
p.forceState(StateStopped) // force it to be in this state
}

p.cmdMutex.Lock()
close(p.cmdWaitChan)
p.cmdMutex.Unlock()
}

// cmdStopUpstreamProcess attemps to stop the upstream process gracefully
Expand Down
20 changes: 15 additions & 5 deletions proxy/proxymanager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1075,18 +1075,28 @@ func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) {

for _, endpoint := range endpoints {
t.Run(endpoint, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()

req := httptest.NewRequest("GET", endpoint, nil)
req = req.WithContext(ctx)
rec := httptest.NewRecorder()

// We don't need the handler to fully complete, just to set the headers
// so run it in a goroutine and check the headers after a short delay
go proxy.ServeHTTP(rec, req)
time.Sleep(10 * time.Millisecond) // give it time to start and write headers
// Run handler in goroutine and wait for context timeout
done := make(chan struct{})
go func() {
defer close(done)
proxy.ServeHTTP(rec, req)
}()

// Wait for either the handler to complete or context to timeout
<-ctx.Done()

// At this point, the handler has either finished or been cancelled
// Wait for the goroutine to fully exit before reading
<-done

// Now it's safe to read from rec - no more concurrent writes
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "no", rec.Header().Get("X-Accel-Buffering"))
})
Expand Down