diff --git a/proxy/config/config.go b/proxy/config/config.go index 97cb56bc..c6f478cd 100644 --- a/proxy/config/config.go +++ b/proxy/config/config.go @@ -3,6 +3,7 @@ package config import ( "fmt" "io" + "net/url" "os" "regexp" "runtime" @@ -342,6 +343,13 @@ func LoadConfigFromReader(r io.Reader) (Config, error) { } } + // Validate the proxy URL. + if _, err := url.Parse(modelConfig.Proxy); err != nil { + return Config{}, fmt.Errorf( + "model %s: invalid proxy URL: %w", modelId, err, + ) + } + config.Models[modelId] = modelConfig } diff --git a/proxy/process.go b/proxy/process.go index 51a5bc61..82400f06 100644 --- a/proxy/process.go +++ b/proxy/process.go @@ -4,12 +4,11 @@ import ( "context" "errors" "fmt" - "io" "net" "net/http" + "net/http/httputil" "net/url" "os/exec" - "strconv" "strings" "sync" "syscall" @@ -39,9 +38,10 @@ const ( ) type Process struct { - ID string - config config.ModelConfig - cmd *exec.Cmd + ID string + config config.ModelConfig + cmd *exec.Cmd + reverseProxy *httputil.ReverseProxy // PR #155 called to cancel the upstream process cancelUpstream context.CancelFunc @@ -81,10 +81,29 @@ func NewProcess(ID string, healthCheckTimeout int, config config.ModelConfig, pr concurrentLimit = config.ConcurrencyLimit } + // Setup the reverse proxy. + proxyURL, err := url.Parse(config.Proxy) + if err != nil { + proxyLogger.Errorf("<%s> invalid proxy URL %q: %v", ID, config.Proxy, err) + } + + var reverseProxy *httputil.ReverseProxy + if proxyURL != nil { + reverseProxy = httputil.NewSingleHostReverseProxy(proxyURL) + reverseProxy.ModifyResponse = func(resp *http.Response) error { + // prevent nginx from buffering streaming responses (e.g., SSE) + if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") { + resp.Header.Set("X-Accel-Buffering", "no") + } + return nil + } + } + return &Process{ ID: ID, config: config, cmd: nil, + reverseProxy: reverseProxy, cancelUpstream: nil, processLogger: processLogger, proxyLogger: proxyLogger, @@ -434,56 +453,10 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) { startDuration = time.Since(beginStartTime) } - proxyTo := p.config.Proxy - client := &http.Client{} - req, err := http.NewRequestWithContext(r.Context(), r.Method, proxyTo+r.URL.String(), r.Body) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - req.Header = r.Header.Clone() - - contentLength, err := strconv.ParseInt(req.Header.Get("content-length"), 10, 64) - if err == nil { - req.ContentLength = contentLength - } - - resp, err := client.Do(req) - if err != nil { - http.Error(w, err.Error(), http.StatusBadGateway) - return - } - defer resp.Body.Close() - for k, vv := range resp.Header { - for _, v := range vv { - w.Header().Add(k, v) - } - } - // prevent nginx from buffering streaming responses (e.g., SSE) - if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") { - w.Header().Set("X-Accel-Buffering", "no") - } - w.WriteHeader(resp.StatusCode) - - // faster than io.Copy when streaming - buf := make([]byte, 32*1024) - for { - n, err := resp.Body.Read(buf) - if n > 0 { - if _, writeErr := w.Write(buf[:n]); writeErr != nil { - return - } - if flusher, ok := w.(http.Flusher); ok { - flusher.Flush() - } - } - if err == io.EOF { - break - } - if err != nil { - http.Error(w, err.Error(), http.StatusBadGateway) - return - } + if p.reverseProxy != nil { + p.reverseProxy.ServeHTTP(w, r) + } else { + http.Error(w, fmt.Sprintf("No reverse proxy available for %s", p.ID), http.StatusInternalServerError) } totalTime := time.Since(requestBeginTime) diff --git a/proxy/process_test.go b/proxy/process_test.go index 574c5d9e..78922b30 100644 --- a/proxy/process_test.go +++ b/proxy/process_test.go @@ -436,7 +436,9 @@ func TestProcess_ForceStopWithKill(t *testing.T) { if runtime.GOOS == "windows" { assert.Contains(t, w.Body.String(), "wsarecv: An existing connection was forcibly closed by the remote host") } else { - assert.Contains(t, w.Body.String(), "unexpected EOF") + // Upstream may be killed mid-response. + // Assert an incomplete or partial response. + assert.NotEqual(t, "12345", w.Body.String()) } close(waitChan) diff --git a/proxy/proxymanager_test.go b/proxy/proxymanager_test.go index 3c8a1583..828067c3 100644 --- a/proxy/proxymanager_test.go +++ b/proxy/proxymanager_test.go @@ -21,6 +21,32 @@ import ( "github.com/tidwall/gjson" ) +// TestResponseRecorder adds CloseNotify to httptest.ResponseRecorder. +// "If you want to write your own tests around streams you will need a Recorder that can handle CloseNotifier." +// The tests can panic otherwise: +// panic: interface conversion: *httptest.ResponseRecorder is not http.CloseNotifier: missing method CloseNotify +// See: https://github.com/gin-gonic/gin/issues/1815 +// TestResponseRecorder is taken from gin's own tests: https://github.com/gin-gonic/gin/blob/ce20f107f5dc498ec7489d7739541a25dcd48463/context_test.go#L1747-L1765 +type TestResponseRecorder struct { + *httptest.ResponseRecorder + closeChannel chan bool +} + +func (r *TestResponseRecorder) CloseNotify() <-chan bool { + return r.closeChannel +} + +func (r *TestResponseRecorder) closeClient() { + r.closeChannel <- true +} + +func CreateTestResponseRecorder() *TestResponseRecorder { + return &TestResponseRecorder{ + httptest.NewRecorder(), + make(chan bool, 1), + } +} + func TestProxyManager_SwapProcessCorrectly(t *testing.T) { config := config.AddDefaultGroupToConfig(config.Config{ HealthCheckTimeout: 15, @@ -37,7 +63,7 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) { for _, modelName := range []string{"model1", "model2"} { reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - w := httptest.NewRecorder() + w := CreateTestResponseRecorder() proxy.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) @@ -74,7 +100,7 @@ func TestProxyManager_SwapMultiProcess(t *testing.T) { t.Run(requestedModel, func(t *testing.T) { reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - w := httptest.NewRecorder() + w := CreateTestResponseRecorder() proxy.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) @@ -116,7 +142,7 @@ func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) { for _, requestedModel := range tests { reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - w := httptest.NewRecorder() + w := CreateTestResponseRecorder() proxy.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) @@ -159,7 +185,7 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) { reqBody := fmt.Sprintf(`{"model":"%s"}`, key) req := httptest.NewRequest("POST", "/v1/chat/completions?wait=1000ms", bytes.NewBufferString(reqBody)) - w := httptest.NewRecorder() + w := CreateTestResponseRecorder() proxy.ServeHTTP(w, req) @@ -212,7 +238,7 @@ func TestProxyManager_ListModelsHandler(t *testing.T) { // Create a test request req := httptest.NewRequest("GET", "/v1/models", nil) req.Header.Add("Origin", "i-am-the-origin") - w := httptest.NewRecorder() + w := CreateTestResponseRecorder() // Call the listModelsHandler proxy.ServeHTTP(w, req) @@ -311,7 +337,7 @@ models: proxy := New(processedConfig) req := httptest.NewRequest("GET", "/v1/models", nil) - w := httptest.NewRecorder() + w := CreateTestResponseRecorder() proxy.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) @@ -387,7 +413,7 @@ func TestProxyManager_ListModelsHandler_SortedByID(t *testing.T) { // Request models list req := httptest.NewRequest("GET", "/v1/models", nil) - w := httptest.NewRecorder() + w := CreateTestResponseRecorder() proxy.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) @@ -448,7 +474,7 @@ func TestProxyManager_Shutdown(t *testing.T) { defer wg.Done() reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - w := httptest.NewRecorder() + w := CreateTestResponseRecorder() // send a request to trigger the proxy to load ... this should hang waiting for start up proxy.ServeHTTP(w, req) @@ -476,12 +502,12 @@ func TestProxyManager_Unload(t *testing.T) { proxy := New(conf) reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1") req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - w := httptest.NewRecorder() + w := CreateTestResponseRecorder() proxy.ServeHTTP(w, req) assert.Equal(t, proxy.processGroups[config.DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady) req = httptest.NewRequest("GET", "/unload", nil) - w = httptest.NewRecorder() + w = CreateTestResponseRecorder() proxy.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, w.Body.String(), "OK") @@ -519,7 +545,7 @@ func TestProxyManager_UnloadSingleModel(t *testing.T) { for _, modelName := range []string{"model1", "model2"} { reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - w := httptest.NewRecorder() + w := CreateTestResponseRecorder() proxy.ServeHTTP(w, req) } @@ -527,7 +553,7 @@ func TestProxyManager_UnloadSingleModel(t *testing.T) { assert.Equal(t, StateReady, proxy.processGroups[testGroupId].processes["model2"].CurrentState()) req := httptest.NewRequest("POST", "/api/models/unload/model1", nil) - w := httptest.NewRecorder() + w := CreateTestResponseRecorder() proxy.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) if !assert.Equal(t, w.Body.String(), "OK") { @@ -571,7 +597,7 @@ func TestProxyManager_RunningEndpoint(t *testing.T) { t.Run("no models loaded", func(t *testing.T) { req := httptest.NewRequest("GET", "/running", nil) - w := httptest.NewRecorder() + w := CreateTestResponseRecorder() proxy.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) @@ -589,13 +615,13 @@ func TestProxyManager_RunningEndpoint(t *testing.T) { // Load just a model. reqBody := `{"model":"model1"}` req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - w := httptest.NewRecorder() + w := CreateTestResponseRecorder() proxy.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) // Simulate browser call for the `/running` endpoint. req = httptest.NewRequest("GET", "/running", nil) - w = httptest.NewRecorder() + w = CreateTestResponseRecorder() proxy.ServeHTTP(w, req) var response RunningResponse @@ -647,7 +673,7 @@ func TestProxyManager_AudioTranscriptionHandler(t *testing.T) { // Create the request with the multipart form data req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b) req.Header.Set("Content-Type", w.FormDataContentType()) - rec := httptest.NewRecorder() + rec := CreateTestResponseRecorder() proxy.ServeHTTP(rec, req) // Verify the response @@ -682,7 +708,7 @@ func TestProxyManager_UseModelName(t *testing.T) { t.Run("useModelName over rides requested model: /v1/chat/completions", func(t *testing.T) { reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - w := httptest.NewRecorder() + w := CreateTestResponseRecorder() proxy.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) @@ -716,7 +742,7 @@ func TestProxyManager_UseModelName(t *testing.T) { // Create the request with the multipart form data req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b) req.Header.Set("Content-Type", w.FormDataContentType()) - rec := httptest.NewRecorder() + rec := CreateTestResponseRecorder() proxy.ServeHTTP(rec, req) // Verify the response @@ -784,7 +810,7 @@ func TestProxyManager_CORSOptionsHandler(t *testing.T) { req.Header.Set(k, v) } - w := httptest.NewRecorder() + w := CreateTestResponseRecorder() proxy.ServeHTTP(w, req) assert.Equal(t, tt.expectedStatus, w.Code) @@ -812,7 +838,7 @@ models: defer proxy.StopProcesses(StopWaitForInflightRequest) t.Run("main model name", func(t *testing.T) { req := httptest.NewRequest("GET", "/upstream/model1/test", nil) - rec := httptest.NewRecorder() + rec := CreateTestResponseRecorder() proxy.ServeHTTP(rec, req) assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, "model1", rec.Body.String()) @@ -820,7 +846,7 @@ models: t.Run("model alias", func(t *testing.T) { req := httptest.NewRequest("GET", "/upstream/model-alias/test", nil) - rec := httptest.NewRecorder() + rec := CreateTestResponseRecorder() proxy.ServeHTTP(rec, req) assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, "model1", rec.Body.String()) @@ -841,7 +867,7 @@ func TestProxyManager_ChatContentLength(t *testing.T) { reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1") req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - w := httptest.NewRecorder() + w := CreateTestResponseRecorder() proxy.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) @@ -869,7 +895,7 @@ func TestProxyManager_FiltersStripParams(t *testing.T) { defer proxy.StopProcesses(StopWaitForInflightRequest) reqBody := `{"model":"model1", "temperature":0.1, "x_param":"123", "y_param":"abc", "stream":true}` req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - w := httptest.NewRecorder() + w := CreateTestResponseRecorder() proxy.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) @@ -900,7 +926,7 @@ func TestProxyManager_MiddlewareWritesMetrics_NonStreaming(t *testing.T) { // Make a non-streaming request reqBody := `{"model":"model1", "stream": false}` req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - w := httptest.NewRecorder() + w := CreateTestResponseRecorder() proxy.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) @@ -935,7 +961,7 @@ func TestProxyManager_MiddlewareWritesMetrics_Streaming(t *testing.T) { // Make a streaming request reqBody := `{"model":"model1", "stream": true}` req := httptest.NewRequest("POST", "/v1/chat/completions?stream=true", bytes.NewBufferString(reqBody)) - w := httptest.NewRecorder() + w := CreateTestResponseRecorder() proxy.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) @@ -967,7 +993,7 @@ func TestProxyManager_HealthEndpoint(t *testing.T) { proxy := New(config) defer proxy.StopProcesses(StopWaitForInflightRequest) req := httptest.NewRequest("GET", "/health", nil) - rec := httptest.NewRecorder() + rec := CreateTestResponseRecorder() proxy.ServeHTTP(rec, req) assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, "OK", rec.Body.String()) @@ -988,7 +1014,7 @@ func TestProxyManager_CompletionEndpoint(t *testing.T) { reqBody := `{"model":"model1"}` req := httptest.NewRequest("POST", "/completion", bytes.NewBufferString(reqBody)) - w := httptest.NewRecorder() + w := CreateTestResponseRecorder() proxy.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) @@ -1080,7 +1106,7 @@ func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) { req := httptest.NewRequest("GET", endpoint, nil) req = req.WithContext(ctx) - rec := httptest.NewRecorder() + rec := CreateTestResponseRecorder() // We don't need the handler to fully complete, just to set the headers // so run it in a goroutine and check the headers after a short delay @@ -1109,7 +1135,7 @@ func TestProxyManager_ProxiedStreamingEndpointReturnsNoBufferingHeader(t *testin reqBody := `{"model":"streaming-model"}` // simple-responder will return text/event-stream when stream=true is in the query req := httptest.NewRequest("POST", "/v1/chat/completions?stream=true", bytes.NewBufferString(reqBody)) - rec := httptest.NewRecorder() + rec := CreateTestResponseRecorder() proxy.ServeHTTP(rec, req)