From 95d9f1b0459763edbf2f478abee6a2cca1b906fb Mon Sep 17 00:00:00 2001 From: Luke Van Drie Date: Thu, 18 Dec 2025 21:00:22 +0000 Subject: [PATCH] fix: harden header sanitization and handling logic This change improves the robustness of header manipulation in the ext_proc server. It implements strict sanitization for "system-owned" headers (such as Content-Length and internal routing metadata) in both the request and response paths. Previously, these headers were passed through transparently from the input. This change ensures the extension maintains authoritative control over protocol and routing headers, preventing potential ambiguity in downstream processing. --- pkg/epp/handlers/request.go | 15 +++------- pkg/epp/handlers/request_test.go | 50 +++++++++++++++++++++---------- pkg/epp/handlers/response.go | 5 +++- pkg/epp/handlers/response_test.go | 28 +++++++++++++++++ pkg/epp/util/request/headers.go | 29 ++++++++++++++++++ 5 files changed, 100 insertions(+), 27 deletions(-) diff --git a/pkg/epp/handlers/request.go b/pkg/epp/handlers/request.go index 346cd4c98a..f8d2ef8e32 100644 --- a/pkg/epp/handlers/request.go +++ b/pkg/epp/handlers/request.go @@ -60,20 +60,10 @@ func (s *StreamingServer) HandleRequestHeaders(reqCtx *RequestContext, req *extP switch header.Key { case metadata.FlowFairnessIDKey: reqCtx.FairnessID = reqCtx.Request.Headers[header.Key] - // remove the fairness ID header from the request headers, - // this is not data that should be manipulated or sent to the backend. - // It is only used for flow control. - delete(reqCtx.Request.Headers, header.Key) case metadata.ObjectiveKey: reqCtx.ObjectiveKey = reqCtx.Request.Headers[header.Key] - // remove the objective header from the request headers, - // this is not data that should be manipulated or sent to the backend. 
- delete(reqCtx.Request.Headers, header.Key) case metadata.ModelNameRewriteKey: reqCtx.TargetModelName = reqCtx.Request.Headers[header.Key] - // remove the rewrite header from the request headers, - // this is not data that should be manipulated or sent to the backend. - delete(reqCtx.Request.Headers, header.Key) } } @@ -140,8 +130,11 @@ func (s *StreamingServer) generateHeaders(reqCtx *RequestContext) []*configPb.He }) } - // include all headers + // Include any non-system-owned headers. for key, value := range reqCtx.Request.Headers { + if request.IsSystemOwnedHeader(key) { + continue + } headers = append(headers, &configPb.HeaderValueOption{ Header: &configPb.HeaderValue{ Key: key, diff --git a/pkg/epp/handlers/request_test.go b/pkg/epp/handlers/request_test.go index 36b6fb6463..60d2fa194f 100644 --- a/pkg/epp/handlers/request_test.go +++ b/pkg/epp/handlers/request_test.go @@ -29,11 +29,10 @@ func TestHandleRequestHeaders(t *testing.T) { t.Parallel() tests := []struct { - name string - headers []*configPb.HeaderValue - wantHeaders map[string]string - wantFairnessID string - wantDeletedKeys []string + name string + headers []*configPb.HeaderValue + wantHeaders map[string]string + wantFairnessID string }{ { name: "Extracts Fairness ID and Removes Header", @@ -41,17 +40,15 @@ func TestHandleRequestHeaders(t *testing.T) { {Key: "x-test", Value: "val"}, {Key: metadata.FlowFairnessIDKey, Value: "user-123"}, }, - wantHeaders: map[string]string{"x-test": "val"}, - wantFairnessID: "user-123", - wantDeletedKeys: []string{metadata.FlowFairnessIDKey}, + wantHeaders: map[string]string{"x-test": "val"}, + wantFairnessID: "user-123", }, { name: "Prefers RawValue over Value", headers: []*configPb.HeaderValue{ {Key: metadata.FlowFairnessIDKey, RawValue: []byte("binary-id"), Value: "wrong-id"}, }, - wantFairnessID: "binary-id", - wantDeletedKeys: []string{metadata.FlowFairnessIDKey}, + wantFairnessID: "binary-id", }, } @@ -77,11 +74,34 @@ func TestHandleRequestHeaders(t 
*testing.T) { assert.Equal(t, v, reqCtx.Request.Headers[k], "Header %q should match expected value", k) } } - - for _, key := range tc.wantDeletedKeys { - _, exists := reqCtx.Request.Headers[key] - assert.False(t, exists, "Expected header %q to be removed from map", key) - } }) } } + +func TestGenerateHeaders_Sanitization(t *testing.T) { + server := &StreamingServer{} + reqCtx := &RequestContext{ + TargetEndpoint: "1.2.3.4:8080", + RequestSize: 123, + Request: &Request{ + Headers: map[string]string{ + "x-user-data": "important", // should passthrough + metadata.ObjectiveKey: "sensitive-objective-id", // should be stripped + metadata.DestinationEndpointKey: "1.1.1.1:666", // should be stripped + "content-length": "99999", // should be stripped (re-added by logic) + }, + }, + } + + results := server.generateHeaders(reqCtx) + + gotHeaders := make(map[string]string) + for _, h := range results { + gotHeaders[h.Header.Key] = string(h.Header.RawValue) + } + + assert.Contains(t, gotHeaders, "x-user-data") + assert.NotContains(t, gotHeaders, metadata.ObjectiveKey) + assert.Equal(t, "1.2.3.4:8080", gotHeaders[metadata.DestinationEndpointKey]) + assert.Equal(t, "123", gotHeaders["Content-Length"]) +} diff --git a/pkg/epp/handlers/response.go b/pkg/epp/handlers/response.go index fdef8ea241..d259a7ccb3 100644 --- a/pkg/epp/handlers/response.go +++ b/pkg/epp/handlers/response.go @@ -155,8 +155,11 @@ func (s *StreamingServer) generateResponseHeaders(reqCtx *RequestContext) []*con }, } - // include all headers + // Include any non-system-owned headers. 
for key, value := range reqCtx.Response.Headers { + if request.IsSystemOwnedHeader(key) { + continue + } headers = append(headers, &configPb.HeaderValueOption{ Header: &configPb.HeaderValue{ Key: key, diff --git a/pkg/epp/handlers/response_test.go b/pkg/epp/handlers/response_test.go index 9608d06882..1960c72523 100644 --- a/pkg/epp/handlers/response_test.go +++ b/pkg/epp/handlers/response_test.go @@ -25,6 +25,7 @@ import ( "github.com/stretchr/testify/assert" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metadata" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) @@ -298,3 +299,30 @@ func TestHandleResponseBodyModelStreaming_TokenAccumulation(t *testing.T) { }) } } + +func TestGenerateResponseHeaders_Sanitization(t *testing.T) { + server := &StreamingServer{} + reqCtx := &RequestContext{ + Response: &Response{ + Headers: map[string]string{ + "x-backend-server": "vllm-v0.6.3", // should passthrough + metadata.ObjectiveKey: "sensitive-objective-id", // should be stripped + metadata.DestinationEndpointKey: "10.2.0.5:8080", // should be stripped + "content-length": "500", // should be stripped + }, + }, + } + + results := server.generateResponseHeaders(reqCtx) + + gotHeaders := make(map[string]string) + for _, h := range results { + gotHeaders[h.Header.Key] = string(h.Header.RawValue) + } + + assert.Contains(t, gotHeaders, "x-backend-server") + assert.Contains(t, gotHeaders, "x-went-into-resp-headers") + assert.NotContains(t, gotHeaders, metadata.ObjectiveKey) + assert.NotContains(t, gotHeaders, metadata.DestinationEndpointKey) + assert.NotContains(t, gotHeaders, "content-length") +} diff --git a/pkg/epp/util/request/headers.go b/pkg/epp/util/request/headers.go index fe9ebe78dd..b2b952b6fb 100644 --- a/pkg/epp/util/request/headers.go +++ b/pkg/epp/util/request/headers.go @@ -21,12 +21,41 @@ import ( corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metadata" ) const ( RequestIdHeaderKey = "x-request-id" ) +var ( + // InputControlHeaders are sent by the Gateway/User to control EPP behavior. + // We must extract these, then strip them so they don't leak to the backend. + InputControlHeaders = map[string]bool{ + strings.ToLower(metadata.FlowFairnessIDKey): true, + strings.ToLower(metadata.ObjectiveKey): true, + strings.ToLower(metadata.ModelNameRewriteKey): true, + strings.ToLower(metadata.SubsetFilterKey): true, + } + + // OutputInjectionHeaders are headers EPP injects for the backend. + // If the user sends these, they must be stripped to prevent ambiguity. + OutputInjectionHeaders = map[string]bool{ + strings.ToLower(metadata.DestinationEndpointKey): true, + strings.ToLower(metadata.DestinationEndpointServedKey): true, + } + + // ProtocolHeaders are managed by the proxy layer (Envoy/EPP). + ProtocolHeaders = map[string]bool{ + "content-length": true, + } +) + +func IsSystemOwnedHeader(key string) bool { + k := strings.ToLower(key) + return InputControlHeaders[k] || OutputInjectionHeaders[k] || ProtocolHeaders[k] +} + // GetHeaderValue safely extracts the string value from an Envoy HeaderValue field. func GetHeaderValue(header *corev3.HeaderValue) string { if len(header.RawValue) > 0 {