diff --git a/pkg/epp/handlers/request.go b/pkg/epp/handlers/request.go index 346cd4c98a..f8d2ef8e32 100644 --- a/pkg/epp/handlers/request.go +++ b/pkg/epp/handlers/request.go @@ -60,20 +60,10 @@ func (s *StreamingServer) HandleRequestHeaders(reqCtx *RequestContext, req *extP switch header.Key { case metadata.FlowFairnessIDKey: reqCtx.FairnessID = reqCtx.Request.Headers[header.Key] - // remove the fairness ID header from the request headers, - // this is not data that should be manipulated or sent to the backend. - // It is only used for flow control. - delete(reqCtx.Request.Headers, header.Key) case metadata.ObjectiveKey: reqCtx.ObjectiveKey = reqCtx.Request.Headers[header.Key] - // remove the objective header from the request headers, - // this is not data that should be manipulated or sent to the backend. - delete(reqCtx.Request.Headers, header.Key) case metadata.ModelNameRewriteKey: reqCtx.TargetModelName = reqCtx.Request.Headers[header.Key] - // remove the rewrite header from the request headers, - // this is not data that should be manipulated or sent to the backend. - delete(reqCtx.Request.Headers, header.Key) } } @@ -140,8 +130,11 @@ func (s *StreamingServer) generateHeaders(reqCtx *RequestContext) []*configPb.He }) } - // include all headers + // Include any non-system-owned headers. for key, value := range reqCtx.Request.Headers { + if request.IsSystemOwnedHeader(key) { + continue + } headers = append(headers, &configPb.HeaderValueOption{ Header: &configPb.HeaderValue{ Key: key, diff --git a/pkg/epp/handlers/request_test.go b/pkg/epp/handlers/request_test.go index 36b6fb6463..60d2fa194f 100644 --- a/pkg/epp/handlers/request_test.go +++ b/pkg/epp/handlers/request_test.go @@ -29,11 +29,10 @@ func TestHandleRequestHeaders(t *testing.T) { t.Parallel() tests := []struct { - name string - headers []*configPb.HeaderValue - wantHeaders map[string]string - wantFairnessID string - wantDeletedKeys []string + name string + headers []*configPb.HeaderValue + wantHeaders map[string]string + wantFairnessID string }{ { name: "Extracts Fairness ID and Removes Header", @@ -41,17 +40,15 @@ func TestHandleRequestHeaders(t *testing.T) { {Key: "x-test", Value: "val"}, {Key: metadata.FlowFairnessIDKey, Value: "user-123"}, }, - wantHeaders: map[string]string{"x-test": "val"}, - wantFairnessID: "user-123", - wantDeletedKeys: []string{metadata.FlowFairnessIDKey}, + wantHeaders: map[string]string{"x-test": "val"}, + wantFairnessID: "user-123", }, { name: "Prefers RawValue over Value", headers: []*configPb.HeaderValue{ {Key: metadata.FlowFairnessIDKey, RawValue: []byte("binary-id"), Value: "wrong-id"}, }, - wantFairnessID: "binary-id", - wantDeletedKeys: []string{metadata.FlowFairnessIDKey}, + wantFairnessID: "binary-id", }, } @@ -77,11 +74,34 @@ func TestHandleRequestHeaders(t *testing.T) { assert.Equal(t, v, reqCtx.Request.Headers[k], "Header %q should match expected value", k) } } - - for _, key := range tc.wantDeletedKeys { - _, exists := reqCtx.Request.Headers[key] - assert.False(t, exists, "Expected header %q to be removed from map", key) - } }) } } + +func TestGenerateHeaders_Sanitization(t *testing.T) { + server := &StreamingServer{} + reqCtx := &RequestContext{ + TargetEndpoint: "1.2.3.4:8080", + RequestSize: 123, + Request: &Request{ + Headers: map[string]string{ + "x-user-data": "important", // should passthrough + metadata.ObjectiveKey: "sensitive-objective-id", // should be stripped + metadata.DestinationEndpointKey: "1.1.1.1:666", // should be stripped + "content-length": "99999", // should be stripped (re-added by logic) + }, + }, + } + + results := server.generateHeaders(reqCtx) + + gotHeaders := make(map[string]string) + for _, h := range results { + gotHeaders[h.Header.Key] = string(h.Header.RawValue) + } + + assert.Contains(t, gotHeaders, "x-user-data") + assert.NotContains(t, gotHeaders, metadata.ObjectiveKey) + assert.Equal(t, "1.2.3.4:8080", gotHeaders[metadata.DestinationEndpointKey]) + assert.Equal(t, "123", gotHeaders["Content-Length"]) +} diff --git a/pkg/epp/handlers/response.go b/pkg/epp/handlers/response.go index fdef8ea241..d259a7ccb3 100644 --- a/pkg/epp/handlers/response.go +++ b/pkg/epp/handlers/response.go @@ -155,8 +155,11 @@ func (s *StreamingServer) generateResponseHeaders(reqCtx *RequestContext) []*con }, } - // include all headers + // Include any non-system-owned headers. for key, value := range reqCtx.Response.Headers { + if request.IsSystemOwnedHeader(key) { + continue + } headers = append(headers, &configPb.HeaderValueOption{ Header: &configPb.HeaderValue{ Key: key, diff --git a/pkg/epp/handlers/response_test.go b/pkg/epp/handlers/response_test.go index 9608d06882..1960c72523 100644 --- a/pkg/epp/handlers/response_test.go +++ b/pkg/epp/handlers/response_test.go @@ -25,6 +25,7 @@ import ( "github.com/stretchr/testify/assert" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metadata" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) @@ -298,3 +299,30 @@ func TestHandleResponseBodyModelStreaming_TokenAccumulation(t *testing.T) { }) } } + +func TestGenerateResponseHeaders_Sanitization(t *testing.T) { + server := &StreamingServer{} + reqCtx := &RequestContext{ + Response: &Response{ + Headers: map[string]string{ + "x-backend-server": "vllm-v0.6.3", // should passthrough + metadata.ObjectiveKey: "sensitive-objective-id", // should be stripped + metadata.DestinationEndpointKey: "10.2.0.5:8080", // should be stripped + "content-length": "500", // hould be stripped + }, + }, + } + + results := server.generateResponseHeaders(reqCtx) + + gotHeaders := make(map[string]string) + for _, h := range results { + gotHeaders[h.Header.Key] = string(h.Header.RawValue) + } + + assert.Contains(t, gotHeaders, "x-backend-server") + assert.Contains(t, gotHeaders, "x-went-into-resp-headers") + assert.NotContains(t, gotHeaders, metadata.ObjectiveKey) + assert.NotContains(t, gotHeaders, metadata.DestinationEndpointKey) + assert.NotContains(t, gotHeaders, "content-length") +} diff --git a/pkg/epp/util/request/headers.go b/pkg/epp/util/request/headers.go index fe9ebe78dd..b2b952b6fb 100644 --- a/pkg/epp/util/request/headers.go +++ b/pkg/epp/util/request/headers.go @@ -21,12 +21,41 @@ import ( corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metadata" ) const ( RequestIdHeaderKey = "x-request-id" ) +var ( + // InputControlHeaders are sent by the Gateway/User to control EPP behavior. + // We must extract these, then strip them so they don't leak to the backend. + InputControlHeaders = map[string]bool{ + strings.ToLower(metadata.FlowFairnessIDKey): true, + strings.ToLower(metadata.ObjectiveKey): true, + strings.ToLower(metadata.ModelNameRewriteKey): true, + strings.ToLower(metadata.SubsetFilterKey): true, + } + + // OutputInjectionHeaders are headers EPP injects for the backend. + // If the user sends these, they must be stripped to prevent ambiguity. + OutputInjectionHeaders = map[string]bool{ + strings.ToLower(metadata.DestinationEndpointKey): true, + strings.ToLower(metadata.DestinationEndpointServedKey): true, + } + + // ProtocolHeaders are managed by the proxy layer (Envoy/EPP). + ProtocolHeaders = map[string]bool{ + "content-length": true, + } +) + +func IsSystemOwnedHeader(key string) bool { + k := strings.ToLower(key) + return InputControlHeaders[k] || OutputInjectionHeaders[k] || ProtocolHeaders[k] +} + // GetHeaderValue safely extracts the string value from an Envoy HeaderValue field. func GetHeaderValue(header *corev3.HeaderValue) string { if len(header.RawValue) > 0 {