From b60e5b8fe0e0c61e298849c26874f4038ea132ac Mon Sep 17 00:00:00 2001 From: asaadbalum Date: Thu, 12 Mar 2026 14:06:16 +0200 Subject: [PATCH] [bbr] Skip body re-serialization and mutation when unchanged Add body mutation tracking to InferenceMessage via SetBody, SetBodyField, RemoveBodyField, and BodyMutated so the request and response handlers skip json.Marshal and Content-Length update when no plugin modified the body. SetBody replaces the entire body map for plugins that transform the full payload (e.g., API translation). SetBodyField and RemoveBodyField handle granular field-level mutations. For unary mode, omit BodyMutation from the ext_proc response when unchanged. For streaming mode, forward original bytes instead of re-marshaling. Add integration tests that verify body mutation tracking over real gRPC for both unary and streaming paths. Signed-off-by: Asaad Balum --- pkg/bbr/framework/types.go | 22 ++ pkg/bbr/framework/types_test.go | 103 ++++++++ pkg/bbr/handlers/request.go | 50 ++-- pkg/bbr/handlers/request_test.go | 268 +++++++++++++++------ pkg/bbr/handlers/response.go | 48 ++-- pkg/bbr/handlers/response_test.go | 112 ++++++++- pkg/bbr/handlers/server_test.go | 19 -- test/integration/bbr/body_mutation_test.go | 210 ++++++++++++++++ test/integration/bbr/harness.go | 16 +- test/integration/bbr/hermetic_test.go | 8 +- test/integration/bbr/util.go | 41 +--- 11 files changed, 711 insertions(+), 186 deletions(-) create mode 100644 pkg/bbr/framework/types_test.go create mode 100644 test/integration/bbr/body_mutation_test.go diff --git a/pkg/bbr/framework/types.go b/pkg/bbr/framework/types.go index b4d4acaf7c..2dd12524a3 100644 --- a/pkg/bbr/framework/types.go +++ b/pkg/bbr/framework/types.go @@ -37,6 +37,7 @@ type InferenceMessage struct { // mutations mutatedHeaders map[string]string removedHeaders sets.Set[string] + bodyMutated bool } func (r *InferenceMessage) SetHeader(key string, value string) { @@ -63,6 +64,27 @@ func (r *InferenceMessage) RemovedHeaders() []string { return r.removedHeaders.UnsortedList() } +func (r *InferenceMessage) SetBody(body map[string]any) { + r.Body = body + r.bodyMutated = true +} + +func (r *InferenceMessage) SetBodyField(key string, value any) { + r.Body[key] = value + r.bodyMutated = true +} + +func (r *InferenceMessage) RemoveBodyField(key string) { + if _, ok := r.Body[key]; ok { + delete(r.Body, key) + r.bodyMutated = true + } +} + +func (r *InferenceMessage) BodyMutated() bool { + return r.bodyMutated +} + type InferenceRequest struct { InferenceMessage } diff --git a/pkg/bbr/framework/types_test.go b/pkg/bbr/framework/types_test.go new file mode 100644 index 0000000000..0dc2b63afd --- /dev/null +++ b/pkg/bbr/framework/types_test.go @@ -0,0 +1,103 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package framework + +import ( + "testing" +) + +func TestSetBodyField(t *testing.T) { + msg := newInferenceMessage() + if msg.BodyMutated() { + t.Error("new message should not be marked as body-mutated") + } + + msg.SetBodyField("key", "value") + + if !msg.BodyMutated() { + t.Error("expected BodyMutated() to return true after SetBodyField") + } + if got, ok := msg.Body["key"]; !ok || got != "value" { + t.Errorf("Body[\"key\"] = %v, %v; want \"value\", true", got, ok) + } +} + +func TestSetBodyField_Overwrite(t *testing.T) { + msg := newInferenceMessage() + msg.Body["existing"] = "old" + + msg.SetBodyField("existing", "new") + + if !msg.BodyMutated() { + t.Error("expected BodyMutated() to return true after overwriting a field") + } + if got := msg.Body["existing"]; got != "new" { + t.Errorf("Body[\"existing\"] = %v; want \"new\"", got) + } +} + +func TestRemoveBodyField(t *testing.T) { + msg := newInferenceMessage() + msg.Body["key"] = "value" + + msg.RemoveBodyField("key") + + if !msg.BodyMutated() { + t.Error("expected BodyMutated() to return true after RemoveBodyField") + } + if _, ok := msg.Body["key"]; ok { + t.Error("expected key to be removed from Body") + } +} + +func TestRemoveBodyField_NonExistent(t *testing.T) { + msg := newInferenceMessage() + + msg.RemoveBodyField("missing") + + if msg.BodyMutated() { + t.Error("removing a non-existent field should not mark body as mutated") + } +} + +func TestSetBody(t *testing.T) { + msg := newInferenceMessage() + + msg.SetBody(map[string]any{"model": "llama", "prompt": "hello"}) + + if !msg.BodyMutated() { + t.Error("expected BodyMutated() to return true after SetBody") + } + if got, ok := msg.Body["model"]; !ok || got != "llama" { + t.Errorf("Body[\"model\"] = %v, %v; want \"llama\", true", got, ok) + } + if got, ok := msg.Body["prompt"]; !ok || got != "hello" { + t.Errorf("Body[\"prompt\"] = %v, %v; want \"hello\", true", got, ok) + } +} + +func TestBodyMutated_FalseByDefault(t *testing.T) { + req := NewInferenceRequest() + if req.BodyMutated() { + t.Error("new InferenceRequest should not be marked as body-mutated") + } + + resp := NewInferenceResponse() + if resp.BodyMutated() { + t.Error("new InferenceResponse should not be marked as body-mutated") + } +} diff --git a/pkg/bbr/handlers/request.go b/pkg/bbr/handlers/request.go index 80170736ce..7bb3b4c49e 100644 --- a/pkg/bbr/handlers/request.go +++ b/pkg/bbr/handlers/request.go @@ -66,12 +66,16 @@ func (s *Server) HandleRequestBody(ctx context.Context, reqCtx *RequestContext, reqCtx.Request.SetHeader(BaseModelHeader, baseModel) logger.Info("Base model from datastore", "baseModel", baseModel) - // TODO: check and do this only if the request body actually changed. - mutatedBodyBytes, err := json.Marshal(reqCtx.Request.Body) - if err != nil { - return nil, err + bodyMutated := reqCtx.Request.BodyMutated() + var mutatedBodyBytes []byte + if bodyMutated { + var err error + mutatedBodyBytes, err = json.Marshal(reqCtx.Request.Body) + if err != nil { + return nil, err + } + reqCtx.Request.SetHeader(contentLengthHeader, strconv.Itoa(len(mutatedBodyBytes))) } - reqCtx.Request.SetHeader(contentLengthHeader, strconv.Itoa(len(mutatedBodyBytes))) metrics.RecordSuccessCounter() @@ -89,27 +93,35 @@ func (s *Server) HandleRequestBody(ctx context.Context, reqCtx *RequestContext, }, }, }) - ret = addStreamedBodyResponse(ret, mutatedBodyBytes) + if bodyMutated { + ret = addStreamedBodyResponse(ret, mutatedBodyBytes) + } else { + ret = addStreamedBodyResponse(ret, requestBodyBytes) + } return ret, nil } + // Necessary so that the new headers are used in the routing decision. + response := &eppb.CommonResponse{ + ClearRouteCache: true, + HeaderMutation: &eppb.HeaderMutation{ + SetHeaders: envoy.GenerateHeadersMutation(reqCtx.Request.MutatedHeaders()), + RemoveHeaders: reqCtx.Request.RemovedHeaders(), + }, + } + if bodyMutated { + response.BodyMutation = &eppb.BodyMutation{ + Mutation: &eppb.BodyMutation_Body{ + Body: mutatedBodyBytes, + }, + } + } + return []*eppb.ProcessingResponse{ { Response: &eppb.ProcessingResponse_RequestBody{ RequestBody: &eppb.BodyResponse{ - Response: &eppb.CommonResponse{ - // Necessary so that the new headers are used in the routing decision. - ClearRouteCache: true, - HeaderMutation: &eppb.HeaderMutation{ - SetHeaders: envoy.GenerateHeadersMutation(reqCtx.Request.MutatedHeaders()), - RemoveHeaders: reqCtx.Request.RemovedHeaders(), - }, - BodyMutation: &eppb.BodyMutation{ - Mutation: &eppb.BodyMutation_Body{ - Body: mutatedBodyBytes, - }, - }, - }, + Response: response, }, }, }, diff --git a/pkg/bbr/handlers/request_test.go b/pkg/bbr/handlers/request_test.go index a43c4b04c2..c506037869 100644 --- a/pkg/bbr/handlers/request_test.go +++ b/pkg/bbr/handlers/request_test.go @@ -35,6 +35,7 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/plugins" envoytest "sigs.k8s.io/gateway-api-inference-extension/pkg/common/envoy/test" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/observability/logging" + epp "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" ) func TestHandleRequestHeaders(t *testing.T) { @@ -212,47 +213,33 @@ func TestHandleRequestBody(t *testing.T) { "model": 1, "prompt": "Tell me a joke", }, - want: func() []*extProcPb.ProcessingResponse { - b, _ := json.Marshal(map[string]any{"model": 1, "prompt": "Tell me a joke"}) - return []*extProcPb.ProcessingResponse{ - { - Response: &extProcPb.ProcessingResponse_RequestBody{ - RequestBody: &extProcPb.BodyResponse{ - Response: &extProcPb.CommonResponse{ - ClearRouteCache: true, - HeaderMutation: &extProcPb.HeaderMutation{ - SetHeaders: []*basepb.HeaderValueOption{ - { - Header: &basepb.HeaderValue{ - Key: ModelHeader, - RawValue: []byte("1"), - }, - }, - { - Header: &basepb.HeaderValue{ - Key: BaseModelHeader, - RawValue: []byte(""), - }, - }, - { - Header: &basepb.HeaderValue{ - Key: contentLengthHeader, - RawValue: []byte(strconv.Itoa(len(b))), - }, + want: []*extProcPb.ProcessingResponse{ + { + Response: &extProcPb.ProcessingResponse_RequestBody{ + RequestBody: &extProcPb.BodyResponse{ + Response: &extProcPb.CommonResponse{ + ClearRouteCache: true, + HeaderMutation: &extProcPb.HeaderMutation{ + SetHeaders: []*basepb.HeaderValueOption{ + { + Header: &basepb.HeaderValue{ + Key: ModelHeader, + RawValue: []byte("1"), }, }, - }, - BodyMutation: &extProcPb.BodyMutation{ - Mutation: &extProcPb.BodyMutation_Body{ - Body: b, + { + Header: &basepb.HeaderValue{ + Key: BaseModelHeader, + RawValue: []byte(""), + }, }, }, }, }, }, }, - } - }(), + }, + }, }, { name: "success", @@ -260,47 +247,33 @@ func TestHandleRequestBody(t *testing.T) { "model": "foo", "prompt": "Tell me a joke", }, - want: func() []*extProcPb.ProcessingResponse { - b, _ := json.Marshal(map[string]any{"model": "foo", "prompt": "Tell me a joke"}) - return []*extProcPb.ProcessingResponse{ - { - Response: &extProcPb.ProcessingResponse_RequestBody{ - RequestBody: &extProcPb.BodyResponse{ - Response: &extProcPb.CommonResponse{ - ClearRouteCache: true, - HeaderMutation: &extProcPb.HeaderMutation{ - SetHeaders: []*basepb.HeaderValueOption{ - { - Header: &basepb.HeaderValue{ - Key: ModelHeader, - RawValue: []byte("foo"), - }, - }, - { - Header: &basepb.HeaderValue{ - Key: BaseModelHeader, - RawValue: []byte(""), - }, - }, - { - Header: &basepb.HeaderValue{ - Key: contentLengthHeader, - RawValue: []byte(strconv.Itoa(len(b))), - }, + want: []*extProcPb.ProcessingResponse{ + { + Response: &extProcPb.ProcessingResponse_RequestBody{ + RequestBody: &extProcPb.BodyResponse{ + Response: &extProcPb.CommonResponse{ + ClearRouteCache: true, + HeaderMutation: &extProcPb.HeaderMutation{ + SetHeaders: []*basepb.HeaderValueOption{ + { + Header: &basepb.HeaderValue{ + Key: ModelHeader, + RawValue: []byte("foo"), }, }, - }, - BodyMutation: &extProcPb.BodyMutation{ - Mutation: &extProcPb.BodyMutation_Body{ - Body: b, + { + Header: &basepb.HeaderValue{ + Key: BaseModelHeader, + RawValue: []byte(""), + }, }, }, }, }, }, }, - } - }(), + }, + }, }, { name: "success-with-streaming", @@ -331,12 +304,6 @@ func TestHandleRequestBody(t *testing.T) { RawValue: []byte(""), }, }, - { - Header: &basepb.HeaderValue{ - Key: contentLengthHeader, - RawValue: []byte(strconv.Itoa(len(b))), - }, - }, }, }, }, @@ -398,12 +365,6 @@ func TestHandleRequestBody(t *testing.T) { RawValue: []byte(""), }, }, - { - Header: &basepb.HeaderValue{ - Key: contentLengthHeader, - RawValue: []byte(strconv.Itoa(len(b))), - }, - }, }, }, }, @@ -541,10 +502,159 @@ func TestHandleRequestBodyWithPluginMetrics(t *testing.T) { } func mapToBytes(t *testing.T, m map[string]any) []byte { - // Convert map to JSON byte array bytes, err := json.Marshal(m) if err != nil { t.Fatalf("Marshal(): %v", err) } return bytes } + +type bodyMutatingPlugin struct { + name string + mutateFn func(ctx context.Context, request *framework.InferenceRequest) error +} + +func (p *bodyMutatingPlugin) TypedName() epp.TypedName { + return epp.TypedName{Type: "fake", Name: p.name} +} + +func (p *bodyMutatingPlugin) ProcessRequest(ctx context.Context, request *framework.InferenceRequest) error { + return p.mutateFn(ctx, request) +} + +var _ framework.RequestProcessor = &bodyMutatingPlugin{} + +func TestHandleRequestBody_BodyMutation(t *testing.T) { + metrics.Register() + ctx := logutil.NewTestLoggerIntoContext(context.Background()) + + plugin := &bodyMutatingPlugin{ + name: "body-mutator", + mutateFn: func(_ context.Context, request *framework.InferenceRequest) error { + request.SetBodyField("injected", "value") + return nil + }, + } + + tests := []struct { + name string + streaming bool + body map[string]any + want []*extProcPb.ProcessingResponse + }{ + { + name: "unary with body mutation", + body: map[string]any{ + "prompt": "test", + }, + want: func() []*extProcPb.ProcessingResponse { + b, _ := json.Marshal(map[string]any{"prompt": "test", "injected": "value"}) + return []*extProcPb.ProcessingResponse{ + { + Response: &extProcPb.ProcessingResponse_RequestBody{ + RequestBody: &extProcPb.BodyResponse{ + Response: &extProcPb.CommonResponse{ + ClearRouteCache: true, + HeaderMutation: &extProcPb.HeaderMutation{ + SetHeaders: []*basepb.HeaderValueOption{ + { + Header: &basepb.HeaderValue{ + Key: BaseModelHeader, + RawValue: []byte(""), + }, + }, + { + Header: &basepb.HeaderValue{ + Key: contentLengthHeader, + RawValue: []byte(strconv.Itoa(len(b))), + }, + }, + }, + }, + BodyMutation: &extProcPb.BodyMutation{ + Mutation: &extProcPb.BodyMutation_Body{ + Body: b, + }, + }, + }, + }, + }, + }, + } + }(), + }, + { + name: "streaming with body mutation", + streaming: true, + body: map[string]any{ + "prompt": "test", + }, + want: func() []*extProcPb.ProcessingResponse { + b, _ := json.Marshal(map[string]any{"prompt": "test", "injected": "value"}) + return []*extProcPb.ProcessingResponse{ + { + Response: &extProcPb.ProcessingResponse_RequestHeaders{ + RequestHeaders: &extProcPb.HeadersResponse{ + Response: &extProcPb.CommonResponse{ + ClearRouteCache: true, + HeaderMutation: &extProcPb.HeaderMutation{ + SetHeaders: []*basepb.HeaderValueOption{ + { + Header: &basepb.HeaderValue{ + Key: BaseModelHeader, + RawValue: []byte(""), + }, + }, + { + Header: &basepb.HeaderValue{ + Key: contentLengthHeader, + RawValue: []byte(strconv.Itoa(len(b))), + }, + }, + }, + }, + }, + }, + }, + }, + { + Response: &extProcPb.ProcessingResponse_RequestBody{ + RequestBody: &extProcPb.BodyResponse{ + Response: &extProcPb.CommonResponse{ + BodyMutation: &extProcPb.BodyMutation{ + Mutation: &extProcPb.BodyMutation_StreamedResponse{ + StreamedResponse: &extProcPb.StreamedBodyResponse{ + Body: b, + EndOfStream: true, + }, + }, + }, + }, + }, + }, + }, + } + }(), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + server := NewServer(tc.streaming, &fakeDatastore{}, []framework.RequestProcessor{plugin}, []framework.ResponseProcessor{}) + reqCtx := &RequestContext{ + Request: framework.NewInferenceRequest(), + } + bodyBytes, _ := json.Marshal(tc.body) + resp, err := server.HandleRequestBody(ctx, reqCtx, bodyBytes) + if err != nil { + t.Fatalf("HandleRequestBody returned unexpected error: %v", err) + } + + envoytest.SortSetHeadersInResponses(tc.want) + envoytest.SortSetHeadersInResponses(resp) + if diff := cmp.Diff(tc.want, resp, protocmp.Transform()); diff != "" { + t.Errorf("HandleRequestBody returned unexpected response, diff(-want, +got): %v", diff) + } + }) + } +} diff --git a/pkg/bbr/handlers/response.go b/pkg/bbr/handlers/response.go index 05a79d7d4d..07b0d8bcc2 100644 --- a/pkg/bbr/handlers/response.go +++ b/pkg/bbr/handlers/response.go @@ -78,13 +78,17 @@ func (s *Server) HandleResponseBody(ctx context.Context, reqCtx *RequestContext, return nil, fmt.Errorf("failed to execute response plugins - %w", err) } - mutatedBytes, err := json.Marshal(reqCtx.Response.Body) - if err != nil { - return nil, fmt.Errorf("failed to marshal mutated response body - %w", err) + bodyMutated := reqCtx.Response.BodyMutated() + var mutatedBytes []byte + if bodyMutated { + var err error + mutatedBytes, err = json.Marshal(reqCtx.Response.Body) + if err != nil { + return nil, fmt.Errorf("failed to marshal mutated response body - %w", err) + } + reqCtx.Response.SetHeader(contentLengthHeader, strconv.Itoa(len(mutatedBytes))) } - reqCtx.Response.SetHeader(contentLengthHeader, strconv.Itoa(len(mutatedBytes))) - if s.streaming { var ret []*eppb.ProcessingResponse ret = append(ret, &eppb.ProcessingResponse{ @@ -100,26 +104,34 @@ func (s *Server) HandleResponseBody(ctx context.Context, reqCtx *RequestContext, }, }, }) - ret = envoy.AddStreamedResponseBody(ret, mutatedBytes) + if bodyMutated { + ret = envoy.AddStreamedResponseBody(ret, mutatedBytes) + } else { + ret = envoy.AddStreamedResponseBody(ret, responseBodyBytes) + } return ret, nil } + response := &eppb.CommonResponse{ + ClearRouteCache: true, + HeaderMutation: &eppb.HeaderMutation{ + SetHeaders: envoy.GenerateHeadersMutation(reqCtx.Response.MutatedHeaders()), + RemoveHeaders: reqCtx.Response.RemovedHeaders(), + }, + } + if bodyMutated { + response.BodyMutation = &eppb.BodyMutation{ + Mutation: &eppb.BodyMutation_Body{ + Body: mutatedBytes, + }, + } + } + return []*eppb.ProcessingResponse{ { Response: &eppb.ProcessingResponse_ResponseBody{ ResponseBody: &eppb.BodyResponse{ - Response: &eppb.CommonResponse{ - ClearRouteCache: true, - HeaderMutation: &eppb.HeaderMutation{ - SetHeaders: envoy.GenerateHeadersMutation(reqCtx.Response.MutatedHeaders()), - RemoveHeaders: reqCtx.Response.RemovedHeaders(), - }, - BodyMutation: &eppb.BodyMutation{ - Mutation: &eppb.BodyMutation_Body{ - Body: mutatedBytes, - }, - }, - }, + Response: response, }, }, }, diff --git a/pkg/bbr/handlers/response_test.go b/pkg/bbr/handlers/response_test.go index 2610fe7c29..0ae7c5c54f 100644 --- a/pkg/bbr/handlers/response_test.go +++ b/pkg/bbr/handlers/response_test.go @@ -88,7 +88,7 @@ func TestHandleResponseBody_SinglePlugin(t *testing.T) { mutatePlugin := &fakeResponsePlugin{ name: "mutator", mutateFn: func(_ context.Context, response *framework.InferenceResponse) error { - response.Body["mutated"] = true + response.SetBodyField("mutated", true) return nil }, } @@ -121,14 +121,14 @@ func TestHandleResponseBody_MultiplePlugins(t *testing.T) { plugin1 := &fakeResponsePlugin{ name: "plugin1", mutateFn: func(_ context.Context, response *framework.InferenceResponse) error { - response.Body["p1"] = testPluginValue + response.SetBodyField("p1", testPluginValue) return nil }, } plugin2 := &fakeResponsePlugin{ name: "plugin2", mutateFn: func(_ context.Context, response *framework.InferenceResponse) error { - response.Body["p2"] = testPluginValue + response.SetBodyField("p2", testPluginValue) return nil }, } @@ -184,7 +184,7 @@ func TestHandleResponseBody_StreamingWithPlugin(t *testing.T) { mutatePlugin := &fakeResponsePlugin{ name: "mutator", mutateFn: func(_ context.Context, response *framework.InferenceResponse) error { - response.Body["mutated"] = true + response.SetBodyField("mutated", true) return nil }, } @@ -242,6 +242,110 @@ func TestProcessResponseBody_Streaming(t *testing.T) { } } +func TestHandleResponseBody_PluginNoBodyMutation(t *testing.T) { + ctx := logutil.NewTestLoggerIntoContext(context.Background()) + + headerOnlyPlugin := &fakeResponsePlugin{ + name: "header-only", + mutateFn: func(_ context.Context, response *framework.InferenceResponse) error { + response.SetHeader("X-Custom-Response", "added") + return nil + }, + } + + tests := []struct { + name string + streaming bool + want []*extProcPb.ProcessingResponse + }{ + { + name: "unary - header-only plugin skips body mutation", + want: []*extProcPb.ProcessingResponse{ + { + Response: &extProcPb.ProcessingResponse_ResponseBody{ + ResponseBody: &extProcPb.BodyResponse{ + Response: &extProcPb.CommonResponse{ + ClearRouteCache: true, + HeaderMutation: &extProcPb.HeaderMutation{ + SetHeaders: []*basepb.HeaderValueOption{ + { + Header: &basepb.HeaderValue{ + Key: "X-Custom-Response", + RawValue: []byte("added"), + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + { + name: "streaming - header-only plugin passes original body", + streaming: true, + want: func() []*extProcPb.ProcessingResponse { + responseBody := []byte(`{"choices":[{"text":"Hello!"}]}`) + return []*extProcPb.ProcessingResponse{ + { + Response: &extProcPb.ProcessingResponse_ResponseHeaders{ + ResponseHeaders: &extProcPb.HeadersResponse{ + Response: &extProcPb.CommonResponse{ + ClearRouteCache: true, + HeaderMutation: &extProcPb.HeaderMutation{ + SetHeaders: []*basepb.HeaderValueOption{ + { + Header: &basepb.HeaderValue{ + Key: "X-Custom-Response", + RawValue: []byte("added"), + }, + }, + }, + }, + }, + }, + }, + }, + { + Response: &extProcPb.ProcessingResponse_ResponseBody{ + ResponseBody: &extProcPb.BodyResponse{ + Response: &extProcPb.CommonResponse{ + BodyMutation: &extProcPb.BodyMutation{ + Mutation: &extProcPb.BodyMutation_StreamedResponse{ + StreamedResponse: &extProcPb.StreamedBodyResponse{ + Body: responseBody, + EndOfStream: true, + }, + }, + }, + }, + }, + }, + }, + } + }(), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + server := NewServer(tc.streaming, &fakeDatastore{}, []framework.RequestProcessor{}, []framework.ResponseProcessor{headerOnlyPlugin}) + responseBody := []byte(`{"choices":[{"text":"Hello!"}]}`) + resp, err := server.HandleResponseBody(ctx, newTestRequestContext(), responseBody) + if err != nil { + t.Fatalf("HandleResponseBody returned unexpected error: %v", err) + } + + envoytest.SortSetHeadersInResponses(tc.want) + envoytest.SortSetHeadersInResponses(resp) + if diff := cmp.Diff(tc.want, resp, protocmp.Transform()); diff != "" { + t.Errorf("HandleResponseBody returned unexpected response, diff(-want, +got): %v", diff) + } + }) + } +} + // expectedResponseBodyMutation builds the expected unary response for a mutated body, // including the content-length header mutation. func expectedResponseBodyMutation(bodyBytes []byte) *extProcPb.ProcessingResponse { diff --git a/pkg/bbr/handlers/server_test.go b/pkg/bbr/handlers/server_test.go index b1c787eabc..e9f07fd86a 100644 --- a/pkg/bbr/handlers/server_test.go +++ b/pkg/bbr/handlers/server_test.go @@ -19,7 +19,6 @@ package handlers import ( "context" "encoding/json" - "strconv" "testing" basepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" @@ -51,7 +50,6 @@ func TestHandleRequestBodyStreaming(t *testing.T) { Response: &extProcPb.ProcessingResponse_RequestBody{ RequestBody: &extProcPb.BodyResponse{ Response: &extProcPb.CommonResponse{ - // Necessary so that the new headers are used in the routing decision. ClearRouteCache: true, HeaderMutation: &extProcPb.HeaderMutation{ SetHeaders: []*basepb.HeaderValueOption{ @@ -67,17 +65,6 @@ func TestHandleRequestBodyStreaming(t *testing.T) { RawValue: []byte(""), }, }, - { - Header: &basepb.HeaderValue{ - Key: contentLengthHeader, - RawValue: []byte(strconv.Itoa(len(b))), - }, - }, - }, - }, - BodyMutation: &extProcPb.BodyMutation{ - Mutation: &extProcPb.BodyMutation_Body{ - Body: b, }, }, }, @@ -110,12 +97,6 @@ func TestHandleRequestBodyStreaming(t *testing.T) { RawValue: []byte(""), }, }, - { - Header: &basepb.HeaderValue{ - Key: contentLengthHeader, - RawValue: []byte(strconv.Itoa(len(b))), - }, - }, }, }, }, diff --git a/test/integration/bbr/body_mutation_test.go b/test/integration/bbr/body_mutation_test.go new file mode 100644 index 0000000000..37360d45af --- /dev/null +++ b/test/integration/bbr/body_mutation_test.go @@ -0,0 +1,210 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package bbr + +import ( + "context" + "encoding/json" + "strconv" + "testing" + + envoyCorev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/testing/protocmp" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/framework" + envoytest "sigs.k8s.io/gateway-api-inference-extension/pkg/common/envoy/test" + epp "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" + "sigs.k8s.io/gateway-api-inference-extension/test/integration" +) + +// bodyMutatingPlugin is a test plugin that injects a field into the request body. +type bodyMutatingPlugin struct { + fieldName string + fieldValue any +} + +func (p *bodyMutatingPlugin) TypedName() epp.TypedName { + return epp.TypedName{Type: "test-body-mutator", Name: "test-body-mutator"} +} + +func (p *bodyMutatingPlugin) ProcessRequest(_ context.Context, request *framework.InferenceRequest) error { + request.SetBodyField(p.fieldName, p.fieldValue) + return nil +} + +var _ framework.RequestProcessor = &bodyMutatingPlugin{} + +// TestBodyMutation_Unary verifies that when a plugin mutates the body, the response +// includes a BodyMutation with the serialized body and an updated Content-Length header. +func TestBodyMutation_Unary(t *testing.T) { + t.Parallel() + ctx := context.Background() + + plugin := &bodyMutatingPlugin{fieldName: "injected", fieldValue: "test-value"} + h := NewBBRHarnessWithPlugins(t, ctx, false, []framework.RequestProcessor{plugin}) + + body := map[string]any{"prompt": "hello"} + bodyBytes, _ := json.Marshal(body) + + req := &extProcPb.ProcessingRequest{ + Request: &extProcPb.ProcessingRequest_RequestBody{ + RequestBody: &extProcPb.HttpBody{ + Body: bodyBytes, + EndOfStream: true, + }, + }, + } + + resp, err := integration.SendRequest(t, h.Client, req) + require.NoError(t, err, "unexpected error during request processing") + + wantBody, _ := json.Marshal(map[string]any{ + "prompt": "hello", + "injected": "test-value", + }) + want := &extProcPb.ProcessingResponse{ + Response: &extProcPb.ProcessingResponse_RequestBody{ + RequestBody: &extProcPb.BodyResponse{ + Response: &extProcPb.CommonResponse{ + ClearRouteCache: true, + HeaderMutation: &extProcPb.HeaderMutation{ + SetHeaders: []*envoyCorev3.HeaderValueOption{ + { + Header: &envoyCorev3.HeaderValue{ + Key: "X-Gateway-Base-Model-Name", + RawValue: []byte(""), + }, + }, + { + Header: &envoyCorev3.HeaderValue{ + Key: "Content-Length", + RawValue: []byte(strconv.Itoa(len(wantBody))), + }, + }, + }, + }, + BodyMutation: &extProcPb.BodyMutation{ + Mutation: &extProcPb.BodyMutation_Body{ + Body: wantBody, + }, + }, + }, + }, + }, + } + + envoytest.SortSetHeadersInResponses([]*extProcPb.ProcessingResponse{want}) + envoytest.SortSetHeadersInResponses([]*extProcPb.ProcessingResponse{resp}) + if diff := cmp.Diff(want, resp, protocmp.Transform()); diff != "" { + t.Errorf("Response mismatch (-want +got): %v", diff) + } +} + +// TestBodyMutation_Streaming verifies the streaming path: when a plugin mutates the body, +// the header response includes Content-Length and the body response carries the mutated bytes. +func TestBodyMutation_Streaming(t *testing.T) { + t.Parallel() + ctx := context.Background() + + plugin := &bodyMutatingPlugin{fieldName: "injected", fieldValue: "test-value"} + h := NewBBRHarnessWithPlugins(t, ctx, true, []framework.RequestProcessor{plugin}) + + body := map[string]any{"prompt": "hello"} + bodyBytes, _ := json.Marshal(body) + + reqs := []*extProcPb.ProcessingRequest{ + { + Request: &extProcPb.ProcessingRequest_RequestHeaders{ + RequestHeaders: &extProcPb.HttpHeaders{ + Headers: &envoyCorev3.HeaderMap{ + Headers: []*envoyCorev3.HeaderValue{ + {Key: "content-type", RawValue: []byte("application/json")}, + }, + }, + }, + }, + }, + { + Request: &extProcPb.ProcessingRequest_RequestBody{ + RequestBody: &extProcPb.HttpBody{ + Body: bodyBytes, + EndOfStream: true, + }, + }, + }, + } + + wantBody, _ := json.Marshal(map[string]any{ + "prompt": "hello", + "injected": "test-value", + }) + wantResponses := []*extProcPb.ProcessingResponse{ + { + Response: &extProcPb.ProcessingResponse_RequestHeaders{ + RequestHeaders: &extProcPb.HeadersResponse{ + Response: &extProcPb.CommonResponse{ + ClearRouteCache: true, + HeaderMutation: &extProcPb.HeaderMutation{ + SetHeaders: []*envoyCorev3.HeaderValueOption{ + { + Header: &envoyCorev3.HeaderValue{ + Key: "X-Gateway-Base-Model-Name", + RawValue: []byte(""), + }, + }, + { + Header: &envoyCorev3.HeaderValue{ + Key: "Content-Length", + RawValue: []byte(strconv.Itoa(len(wantBody))), + }, + }, + }, + }, + }, + }, + }, + }, + { + Response: &extProcPb.ProcessingResponse_RequestBody{ + RequestBody: &extProcPb.BodyResponse{ + Response: &extProcPb.CommonResponse{ + BodyMutation: &extProcPb.BodyMutation{ + Mutation: &extProcPb.BodyMutation_StreamedResponse{ + StreamedResponse: &extProcPb.StreamedBodyResponse{ + Body: wantBody, + EndOfStream: true, + }, + }, + }, + }, + }, + }, + }, + } + + responses, err := integration.StreamedRequest(t, h.Client, reqs, len(wantResponses)) + require.NoError(t, err, "unexpected stream error") + + envoytest.SortSetHeadersInResponses(wantResponses) + envoytest.SortSetHeadersInResponses(responses) + if diff := cmp.Diff(wantResponses, responses, protocmp.Transform()); diff != "" { + t.Errorf("Response mismatch (-want +got): %v", diff) + } +} diff --git a/test/integration/bbr/harness.go b/test/integration/bbr/harness.go index 7558e2ac0d..e7c15ae093 100644 --- a/test/integration/bbr/harness.go +++ b/test/integration/bbr/harness.go @@ -45,10 +45,18 @@ type BBRHarness struct { grpcConn *grpc.ClientConn } -// NewBBRHarness boots up an isolated BBR server on a random port. -// streaming: determines if the BBR server runs in streaming mode or unary/buffered mode. +// NewBBRHarness boots up an isolated BBR server on a random port with the default +// BodyFieldToHeaderPlugin for model extraction. func NewBBRHarness(t *testing.T, ctx context.Context, streaming bool) *BBRHarness { t.Helper() + modelToHeaderPlugin, err := plugins.NewBodyFieldToHeaderPlugin(handlers.ModelField, handlers.ModelHeader) + require.NoError(t, err, "failed to create body-field-to-header plugin") + return NewBBRHarnessWithPlugins(t, ctx, streaming, []framework.RequestProcessor{modelToHeaderPlugin}) +} + +// NewBBRHarnessWithPlugins boots up an isolated BBR server with custom request plugins. +func NewBBRHarnessWithPlugins(t *testing.T, ctx context.Context, streaming bool, requestPlugins []framework.RequestProcessor) *BBRHarness { + t.Helper() // 1. Allocate Free Port port, err := integration.GetFreePort() @@ -59,9 +67,7 @@ func NewBBRHarness(t *testing.T, ctx context.Context, streaming bool) *BBRHarnes runner.SecureServing = false runner.Streaming = streaming runner.Datastore = datastore.NewDatastore() - modelToHeaderPlugin, err := plugins.NewBodyFieldToHeaderPlugin(handlers.ModelField, handlers.ModelHeader) - require.NoError(t, err, "failed to create body-field-to-header plugin") - runner.RequestPlugins = []framework.RequestProcessor{modelToHeaderPlugin} + runner.RequestPlugins = requestPlugins // 3. Start Server in Background serverCtx, serverCancel := context.WithCancel(ctx) diff --git a/test/integration/bbr/hermetic_test.go b/test/integration/bbr/hermetic_test.go index 369e057467..203a961030 100644 --- a/test/integration/bbr/hermetic_test.go +++ b/test/integration/bbr/hermetic_test.go @@ -44,13 +44,13 @@ func TestBodyBasedRouting(t *testing.T) { { name: "success: extracts model and sets header", req: integration.ReqLLMUnary(logger, "test", "llama"), - wantResponse: ExpectBBRUnaryResponse("llama", "test"), + wantResponse: ExpectBBRUnaryResponse("llama"), wantErr: false, }, { name: "noop: no model parameter in body", req: integration.ReqLLMUnary(logger, "test1", ""), - wantResponse: ExpectBBRUnaryResponse("", ""), // Expect no headers. + wantResponse: ExpectBBRUnaryResponse(""), // Expect no headers. wantErr: false, }, } @@ -95,7 +95,7 @@ func TestFullDuplexStreamed_BodyBasedRouting(t *testing.T) { name: "success: adds model header from simple body", reqs: integration.ReqLLM(logger, "test", "foo", "bar"), wantResponses: []*extProcPb.ProcessingResponse{ - ExpectBBRHeader("foo", "test"), + ExpectBBRHeader("foo"), ExpectBBRBodyPassThrough("test", "foo"), }, }, @@ -107,7 +107,7 @@ func TestFullDuplexStreamed_BodyBasedRouting(t *testing.T) { `ra-sheddable","prompt":"test","temperature":0}`, ), wantResponses: []*extProcPb.ProcessingResponse{ - ExpectBBRHeader("sql-lora-sheddable", "test"), + ExpectBBRHeader("sql-lora-sheddable"), ExpectBBRBodyPassThrough("test", "sql-lora-sheddable"), }, }, diff --git a/test/integration/bbr/util.go b/test/integration/bbr/util.go index 92d8aef547..eae5a484af 100644 --- a/test/integration/bbr/util.go +++ b/test/integration/bbr/util.go @@ -18,7 +18,6 @@ package bbr import ( "encoding/json" - "strconv" envoyCorev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" @@ -27,9 +26,7 @@ import ( // --- Response Expectations (Streaming) --- // ExpectBBRHeader asserts that BBR set the specific model header and cleared the route cache. -func ExpectBBRHeader(modelName, prompt string) *extProcPb.ProcessingResponse { - b := marshalExpectedBody(prompt, modelName) - +func ExpectBBRHeader(modelName string) *extProcPb.ProcessingResponse { return &extProcPb.ProcessingResponse{ Response: &extProcPb.ProcessingResponse_RequestHeaders{ RequestHeaders: &extProcPb.HeadersResponse{ @@ -37,12 +34,6 @@ func ExpectBBRHeader(modelName, prompt string) *extProcPb.ProcessingResponse { ClearRouteCache: true, HeaderMutation: &extProcPb.HeaderMutation{ SetHeaders: []*envoyCorev3.HeaderValueOption{ - { - Header: &envoyCorev3.HeaderValue{ - Key: "Content-Length", - RawValue: []byte(strconv.Itoa(len(b))), - }, - }, { Header: &envoyCorev3.HeaderValue{ Key: "X-Gateway-Model-Name", @@ -103,26 +94,17 @@ func ExpectBBRNoOpHeader() *extProcPb.ProcessingResponse { // --- Response Expectations (Unary) --- -// ExpectBBRUnaryResponse creates expected response for unary tests where the body is mutated directly. -func ExpectBBRUnaryResponse(modelName, prompt string) *extProcPb.ProcessingResponse { +// ExpectBBRUnaryResponse creates expected response for unary tests. +func ExpectBBRUnaryResponse(modelName string) *extProcPb.ProcessingResponse { resp := &extProcPb.ProcessingResponse{} - // If modelName is present, we expect header mutations and body mutation. if modelName != "" { - b := marshalExpectedBody(prompt, modelName) - resp.Response = &extProcPb.ProcessingResponse_RequestBody{ RequestBody: &extProcPb.BodyResponse{ Response: &extProcPb.CommonResponse{ ClearRouteCache: true, HeaderMutation: &extProcPb.HeaderMutation{ SetHeaders: []*envoyCorev3.HeaderValueOption{ - { - Header: &envoyCorev3.HeaderValue{ - Key: "Content-Length", - RawValue: []byte(strconv.Itoa(len(b))), - }, - }, { Header: &envoyCorev3.HeaderValue{ Key: "X-Gateway-Model-Name", @@ -137,30 +119,13 @@ func ExpectBBRUnaryResponse(modelName, prompt string) *extProcPb.ProcessingRespo }, }, }, - BodyMutation: &extProcPb.BodyMutation{ - Mutation: &extProcPb.BodyMutation_Body{ - Body: b, - }, - }, }, }, } } else { - // Otherwise, expect a No-Op on the body. resp.Response = &extProcPb.ProcessingResponse_RequestBody{ RequestBody: &extProcPb.BodyResponse{}, } } return resp } - -func marshalExpectedBody(prompt, model string) []byte { - j := map[string]any{ - "max_tokens": 100, "prompt": prompt, "temperature": 0, - } - if model != "" { - j["model"] = model - } - b, _ := json.Marshal(j) - return b -}