Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 4 additions & 11 deletions pkg/epp/handlers/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,20 +60,10 @@ func (s *StreamingServer) HandleRequestHeaders(reqCtx *RequestContext, req *extP
switch header.Key {
case metadata.FlowFairnessIDKey:
reqCtx.FairnessID = reqCtx.Request.Headers[header.Key]
// remove the fairness ID header from the request headers,
// this is not data that should be manipulated or sent to the backend.
// It is only used for flow control.
delete(reqCtx.Request.Headers, header.Key)
case metadata.ObjectiveKey:
reqCtx.ObjectiveKey = reqCtx.Request.Headers[header.Key]
// remove the objective header from the request headers,
// this is not data that should be manipulated or sent to the backend.
delete(reqCtx.Request.Headers, header.Key)
case metadata.ModelNameRewriteKey:
reqCtx.TargetModelName = reqCtx.Request.Headers[header.Key]
// remove the rewrite header from the request headers,
// this is not data that should be manipulated or sent to the backend.
delete(reqCtx.Request.Headers, header.Key)
}
}

Expand Down Expand Up @@ -140,8 +130,11 @@ func (s *StreamingServer) generateHeaders(reqCtx *RequestContext) []*configPb.He
})
}

// include all headers
// Include any non-system-owned headers.
for key, value := range reqCtx.Request.Headers {
if request.IsSystemOwnedHeader(key) {
continue
}
headers = append(headers, &configPb.HeaderValueOption{
Header: &configPb.HeaderValue{
Key: key,
Expand Down
50 changes: 35 additions & 15 deletions pkg/epp/handlers/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,29 +29,26 @@ func TestHandleRequestHeaders(t *testing.T) {
t.Parallel()

tests := []struct {
name string
headers []*configPb.HeaderValue
wantHeaders map[string]string
wantFairnessID string
wantDeletedKeys []string
name string
headers []*configPb.HeaderValue
wantHeaders map[string]string
wantFairnessID string
}{
{
name: "Extracts Fairness ID and Removes Header",
headers: []*configPb.HeaderValue{
{Key: "x-test", Value: "val"},
{Key: metadata.FlowFairnessIDKey, Value: "user-123"},
},
wantHeaders: map[string]string{"x-test": "val"},
wantFairnessID: "user-123",
wantDeletedKeys: []string{metadata.FlowFairnessIDKey},
wantHeaders: map[string]string{"x-test": "val"},
wantFairnessID: "user-123",
},
{
name: "Prefers RawValue over Value",
headers: []*configPb.HeaderValue{
{Key: metadata.FlowFairnessIDKey, RawValue: []byte("binary-id"), Value: "wrong-id"},
},
wantFairnessID: "binary-id",
wantDeletedKeys: []string{metadata.FlowFairnessIDKey},
wantFairnessID: "binary-id",
},
}

Expand All @@ -77,11 +74,34 @@ func TestHandleRequestHeaders(t *testing.T) {
assert.Equal(t, v, reqCtx.Request.Headers[k], "Header %q should match expected value", k)
}
}

for _, key := range tc.wantDeletedKeys {
_, exists := reqCtx.Request.Headers[key]
assert.False(t, exists, "Expected header %q to be removed from map", key)
}
})
}
}

func TestGenerateHeaders_Sanitization(t *testing.T) {
server := &StreamingServer{}
reqCtx := &RequestContext{
TargetEndpoint: "1.2.3.4:8080",
RequestSize: 123,
Request: &Request{
Headers: map[string]string{
"x-user-data": "important", // should passthrough
metadata.ObjectiveKey: "sensitive-objective-id", // should be stripped
metadata.DestinationEndpointKey: "1.1.1.1:666", // should be stripped
"content-length": "99999", // should be stripped (re-added by logic)
},
},
}

results := server.generateHeaders(reqCtx)

gotHeaders := make(map[string]string)
for _, h := range results {
gotHeaders[h.Header.Key] = string(h.Header.RawValue)
}

assert.Contains(t, gotHeaders, "x-user-data")
assert.NotContains(t, gotHeaders, metadata.ObjectiveKey)
assert.Equal(t, "1.2.3.4:8080", gotHeaders[metadata.DestinationEndpointKey])
assert.Equal(t, "123", gotHeaders["Content-Length"])
}
5 changes: 4 additions & 1 deletion pkg/epp/handlers/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,11 @@ func (s *StreamingServer) generateResponseHeaders(reqCtx *RequestContext) []*con
},
}

// include all headers
// Include any non-system-owned headers.
for key, value := range reqCtx.Response.Headers {
if request.IsSystemOwnedHeader(key) {
continue
}
headers = append(headers, &configPb.HeaderValueOption{
Header: &configPb.HeaderValue{
Key: key,
Expand Down
28 changes: 28 additions & 0 deletions pkg/epp/handlers/response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/stretchr/testify/assert"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metadata"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
)

Expand Down Expand Up @@ -298,3 +299,30 @@ func TestHandleResponseBodyModelStreaming_TokenAccumulation(t *testing.T) {
})
}
}

func TestGenerateResponseHeaders_Sanitization(t *testing.T) {
server := &StreamingServer{}
reqCtx := &RequestContext{
Response: &Response{
Headers: map[string]string{
"x-backend-server": "vllm-v0.6.3", // should passthrough
metadata.ObjectiveKey: "sensitive-objective-id", // should be stripped
metadata.DestinationEndpointKey: "10.2.0.5:8080", // should be stripped
"content-length": "500", // hould be stripped
},
},
}

results := server.generateResponseHeaders(reqCtx)

gotHeaders := make(map[string]string)
for _, h := range results {
gotHeaders[h.Header.Key] = string(h.Header.RawValue)
}

assert.Contains(t, gotHeaders, "x-backend-server")
assert.Contains(t, gotHeaders, "x-went-into-resp-headers")
assert.NotContains(t, gotHeaders, metadata.ObjectiveKey)
assert.NotContains(t, gotHeaders, metadata.DestinationEndpointKey)
assert.NotContains(t, gotHeaders, "content-length")
}
29 changes: 29 additions & 0 deletions pkg/epp/util/request/headers.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,41 @@ import (

corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metadata"
)

const (
RequestIdHeaderKey = "x-request-id"
)

var (
// InputControlHeaders are sent by the Gateway/User to control EPP behavior.
// We must extract these, then strip them so they don't leak to the backend.
InputControlHeaders = map[string]bool{
Copy link
Copy Markdown
Contributor

@nirrozenbaum nirrozenbaum Dec 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: there’s no meaning for the “true” here, and this map acts as a hash set.
could be nice to use “set” from k8s utils instead of this map, and use the set.Has function (it’s backed by a map to struct{}).

it reads clearer IMO.
(if you want to push in a follow up)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

example here:

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I considered this but figured map[string]bool or map[string]struct{} was idiomatic enough. The set utility reads nicer though, especially now that we have generics. Will send out a small followup!

strings.ToLower(metadata.FlowFairnessIDKey): true,
strings.ToLower(metadata.ObjectiveKey): true,
strings.ToLower(metadata.ModelNameRewriteKey): true,
strings.ToLower(metadata.SubsetFilterKey): true,
}

// OutputInjectionHeaders are headers EPP injects for the backend.
// If the user sends these, they must be stripped to prevent ambiguity.
OutputInjectionHeaders = map[string]bool{
strings.ToLower(metadata.DestinationEndpointKey): true,
strings.ToLower(metadata.DestinationEndpointServedKey): true,
}

// ProtocolHeaders are managed by the proxy layer (Envoy/EPP).
ProtocolHeaders = map[string]bool{
"content-length": true,
}
)

func IsSystemOwnedHeader(key string) bool {
k := strings.ToLower(key)
return InputControlHeaders[k] || OutputInjectionHeaders[k] || ProtocolHeaders[k]
}

// GetHeaderValue safely extracts the string value from an Envoy HeaderValue field.
func GetHeaderValue(header *corev3.HeaderValue) string {
if len(header.RawValue) > 0 {
Expand Down