diff --git a/pkg/bbr/handlers/request_test.go b/pkg/bbr/handlers/request_test.go index 2825e41483..c827840f1c 100644 --- a/pkg/bbr/handlers/request_test.go +++ b/pkg/bbr/handlers/request_test.go @@ -109,7 +109,7 @@ func TestHandleRequestHeaders(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - server := NewServer(false, &fakeDatastore{}, []framework.PayloadProcessor{}) + server := NewServer(false, &fakeDatastore{}, []framework.PayloadProcessor{}, []framework.PayloadProcessor{}) reqCtx := &RequestContext{ Request: &Request{Headers: make(map[string]string)}, Response: &Response{Headers: make(map[string]string)}, @@ -369,7 +369,7 @@ func TestHandleRequestBody(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - server := NewServer(test.streaming, &fakeDatastore{}, []framework.PayloadProcessor{}) + server := NewServer(test.streaming, &fakeDatastore{}, []framework.PayloadProcessor{}, []framework.PayloadProcessor{}) bodyBytes, _ := json.Marshal(test.body) resp, err := server.HandleRequestBody(ctx, bodyBytes) if err != nil { @@ -407,7 +407,7 @@ func TestHandleRequestBodyWithPluginMetrics(t *testing.T) { ctx := logutil.NewTestLoggerIntoContext(context.Background()) noopPlugin := plugins.NewDefaultPlugin() - server := NewServer(false, &fakeDatastore{}, []framework.PayloadProcessor{noopPlugin}) + server := NewServer(false, &fakeDatastore{}, []framework.PayloadProcessor{noopPlugin}, []framework.PayloadProcessor{}) bodyBytes, _ := json.Marshal(map[string]any{ "model": "bar", diff --git a/pkg/bbr/handlers/response.go b/pkg/bbr/handlers/response.go index fbcb75d651..b2c51e6762 100644 --- a/pkg/bbr/handlers/response.go +++ b/pkg/bbr/handlers/response.go @@ -17,11 +17,25 @@ limitations under the License. package handlers import ( + "context" + "encoding/json" + "fmt" + eppb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + "sigs.k8s.io/controller-runtime/pkg/log" + + reqenvoy "sigs.k8s.io/gateway-api-inference-extension/pkg/common/envoy/request" ) -// HandleResponseHeaders handles response headers. -func (s *Server) HandleResponseHeaders(headers *eppb.HttpHeaders) ([]*eppb.ProcessingResponse, error) { +// HandleResponseHeaders extracts response headers into reqCtx and returns +// the ext-proc header response. +func (s *Server) HandleResponseHeaders(reqCtx *RequestContext, headers *eppb.HttpHeaders) ([]*eppb.ProcessingResponse, error) { + if headers != nil && headers.Headers != nil { + for _, header := range headers.Headers.Headers { + reqCtx.Response.Headers[header.Key] = reqenvoy.GetHeaderValue(header) + } + } + return []*eppb.ProcessingResponse{ { Response: &eppb.ProcessingResponse_ResponseHeaders{ @@ -31,8 +45,37 @@ func (s *Server) HandleResponseHeaders(headers *eppb.HttpHeaders) ([]*eppb.Proce }, nil } -// HandleResponseBody handles response bodies. -func (s *Server) HandleResponseBody(body *eppb.HttpBody) ([]*eppb.ProcessingResponse, error) { +// HandleResponseBody handles response bodies by executing response plugins in order. +func (s *Server) HandleResponseBody(ctx context.Context, reqCtx *RequestContext, responseBodyBytes []byte) ([]*eppb.ProcessingResponse, error) { + logger := log.FromContext(ctx) + if len(s.responsePlugins) == 0 { + return []*eppb.ProcessingResponse{ + { + Response: &eppb.ProcessingResponse_ResponseBody{ + ResponseBody: &eppb.BodyResponse{}, + }, + }, + }, nil + } + + var responseBody map[string]any + if err := json.Unmarshal(responseBodyBytes, &responseBody); err != nil { + logger.Error(err, "Failed to parse response body as JSON, skipping response plugins") + return []*eppb.ProcessingResponse{ + { + Response: &eppb.ProcessingResponse_ResponseBody{ + ResponseBody: &eppb.BodyResponse{}, + }, + }, + }, nil + } + + if err := s.executePlugins(ctx, reqCtx.Response.Headers, responseBody, s.responsePlugins); err != nil { + logger.Error(err, "Response plugin execution failed") + return nil, fmt.Errorf("failed to execute response plugins - %w", err) + } + + // TODO: apply mutated body/headers to the response (see #2449 follow-ups). return []*eppb.ProcessingResponse{ { Response: &eppb.ProcessingResponse_ResponseBody{ diff --git a/pkg/bbr/handlers/response_test.go b/pkg/bbr/handlers/response_test.go new file mode 100644 index 0000000000..1dcb3583d9 --- /dev/null +++ b/pkg/bbr/handlers/response_test.go @@ -0,0 +1,233 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package handlers + +import ( + "context" + "errors" + "testing" + + extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + "github.com/google/go-cmp/cmp" + "google.golang.org/protobuf/testing/protocmp" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/framework" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/observability/logging" + epp "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" +) + +const testPluginValue = "done" + +// fakeResponsePlugin implements framework.PayloadProcessor for testing response plugin execution. +type fakeResponsePlugin struct { + name string + mutateFn func(ctx context.Context, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) +} + +func (p *fakeResponsePlugin) TypedName() epp.TypedName { + return epp.TypedName{Type: "fake", Name: p.name} +} + +func (p *fakeResponsePlugin) Execute(ctx context.Context, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { + return p.mutateFn(ctx, headers, body) +} + +var _ framework.PayloadProcessor = &fakeResponsePlugin{} + +func newTestRequestContext() *RequestContext { + return &RequestContext{ + Request: &Request{Headers: make(map[string]string)}, + Response: &Response{Headers: make(map[string]string)}, + } +} + +func TestHandleResponseBody_NoPlugins(t *testing.T) { + ctx := logutil.NewTestLoggerIntoContext(context.Background()) + + server := NewServer(false, &fakeDatastore{}, []framework.PayloadProcessor{}, []framework.PayloadProcessor{}) + responseBody := []byte(`{"choices":[{"text":"Hello!"}]}`) + resp, err := server.HandleResponseBody(ctx, newTestRequestContext(), responseBody) + if err != nil { + t.Fatalf("HandleResponseBody returned unexpected error: %v", err) + } + + want := []*extProcPb.ProcessingResponse{ + { + Response: &extProcPb.ProcessingResponse_ResponseBody{ + ResponseBody: &extProcPb.BodyResponse{}, + }, + }, + } + + if diff := cmp.Diff(want, resp, protocmp.Transform()); diff != "" { + t.Errorf("HandleResponseBody returned unexpected response, diff(-want, +got): %v", diff) + } +} + +func TestHandleResponseBody_SinglePlugin(t *testing.T) { + ctx := logutil.NewTestLoggerIntoContext(context.Background()) + + mutatePlugin := &fakeResponsePlugin{ + name: "mutator", + mutateFn: func(_ context.Context, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { + body["mutated"] = true + return headers, body, nil + }, + } + + server := NewServer(false, &fakeDatastore{}, []framework.PayloadProcessor{}, []framework.PayloadProcessor{mutatePlugin}) + responseBody := []byte(`{"choices":[{"text":"Hello!"}]}`) + resp, err := server.HandleResponseBody(ctx, newTestRequestContext(), responseBody) + if err != nil { + t.Fatalf("HandleResponseBody returned unexpected error: %v", err) + } + + // Plugins are executed but mutations are not yet applied to the response. + want := []*extProcPb.ProcessingResponse{ + { + Response: &extProcPb.ProcessingResponse_ResponseBody{ + ResponseBody: &extProcPb.BodyResponse{}, + }, + }, + } + if diff := cmp.Diff(want, resp, protocmp.Transform()); diff != "" { + t.Errorf("HandleResponseBody returned unexpected response, diff(-want, +got): %v", diff) + } +} + +func TestHandleResponseBody_MultiplePlugins(t *testing.T) { + ctx := logutil.NewTestLoggerIntoContext(context.Background()) + + plugin1 := &fakeResponsePlugin{ + name: "plugin1", + mutateFn: func(_ context.Context, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { + body["p1"] = testPluginValue + return headers, body, nil + }, + } + plugin2 := &fakeResponsePlugin{ + name: "plugin2", + mutateFn: func(_ context.Context, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { + body["p2"] = testPluginValue + return headers, body, nil + }, + } + + server := NewServer(false, &fakeDatastore{}, []framework.PayloadProcessor{}, []framework.PayloadProcessor{plugin1, plugin2}) + responseBody := []byte(`{"original":true}`) + resp, err := server.HandleResponseBody(ctx, newTestRequestContext(), responseBody) + if err != nil { + t.Fatalf("HandleResponseBody returned unexpected error: %v", err) + } + + // Plugins are executed but mutations are not yet applied to the response. + want := []*extProcPb.ProcessingResponse{ + { + Response: &extProcPb.ProcessingResponse_ResponseBody{ + ResponseBody: &extProcPb.BodyResponse{}, + }, + }, + } + if diff := cmp.Diff(want, resp, protocmp.Transform()); diff != "" { + t.Errorf("HandleResponseBody returned unexpected response, diff(-want, +got): %v", diff) + } +} + +func TestHandleResponseBody_PluginError(t *testing.T) { + ctx := logutil.NewTestLoggerIntoContext(context.Background()) + + failingPlugin := &fakeResponsePlugin{ + name: "failing", + mutateFn: func(_ context.Context, _ map[string]string, _ map[string]any) (map[string]string, map[string]any, error) { + return nil, nil, errors.New("failed to execute plugin") + }, + } + + server := NewServer(false, &fakeDatastore{}, []framework.PayloadProcessor{}, []framework.PayloadProcessor{failingPlugin}) + responseBody := []byte(`{"choices":[{"text":"some response"}]}`) + _, err := server.HandleResponseBody(ctx, newTestRequestContext(), responseBody) + if err == nil { + t.Fatal("HandleResponseBody should have returned an error") + } + + if got := err.Error(); got == "" { + t.Error("Expected non-empty error message") + } +} + +func TestHandleResponseBody_StreamingWithPlugin(t *testing.T) { + ctx := logutil.NewTestLoggerIntoContext(context.Background()) + + noopPlugin := &fakeResponsePlugin{ + name: "noop", + mutateFn: func(_ context.Context, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { + return headers, body, nil + }, + } + + server := NewServer(true, &fakeDatastore{}, []framework.PayloadProcessor{}, []framework.PayloadProcessor{noopPlugin}) + responseBody := []byte(`{"choices":[{"text":"Hello!"}]}`) + resp, err := server.HandleResponseBody(ctx, newTestRequestContext(), responseBody) + if err != nil { + t.Fatalf("HandleResponseBody returned unexpected error: %v", err) + } + + // Plugins are executed but mutations are not yet applied to the response. + want := []*extProcPb.ProcessingResponse{ + { + Response: &extProcPb.ProcessingResponse_ResponseBody{ + ResponseBody: &extProcPb.BodyResponse{}, + }, + }, + } + if diff := cmp.Diff(want, resp, protocmp.Transform()); diff != "" { + t.Errorf("HandleResponseBody returned unexpected response, diff(-want, +got): %v", diff) + } +} + +func TestProcessResponseBody_Streaming(t *testing.T) { + ctx := logutil.NewTestLoggerIntoContext(context.Background()) + + server := NewServer(true, &fakeDatastore{}, []framework.PayloadProcessor{}, []framework.PayloadProcessor{}) + + chunk1 := &extProcPb.HttpBody{ + Body: []byte(`{"choices":[{"te`), + } + chunk2 := &extProcPb.HttpBody{ + Body: []byte(`xt":"Hello!"}]}`), + EndOfStream: true, + } + + reqCtx := newTestRequestContext() + respStreamedBody := &streamedBody{} + + resp1, err := server.processResponseBody(ctx, reqCtx, chunk1, respStreamedBody) + if err != nil { + t.Fatalf("processResponseBody chunk1 returned unexpected error: %v", err) + } + if resp1 != nil { + t.Fatalf("processResponseBody chunk1 should return nil while buffering, got: %v", resp1) + } + + resp2, err := server.processResponseBody(ctx, reqCtx, chunk2, respStreamedBody) + if err != nil { + t.Fatalf("processResponseBody chunk2 returned unexpected error: %v", err) + } + if resp2 == nil { + t.Fatal("processResponseBody chunk2 should return a response on EoS") + } +} diff --git a/pkg/bbr/handlers/server.go b/pkg/bbr/handlers/server.go index 9d95af7e30..8efa179192 100644 --- a/pkg/bbr/handlers/server.go +++ b/pkg/bbr/handlers/server.go @@ -23,7 +23,6 @@ import ( "time" extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" - "github.com/go-logr/logr" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "sigs.k8s.io/controller-runtime/pkg/log" @@ -38,20 +37,22 @@ type Datastore interface { GetBaseModel(modelName string) string } -func NewServer(streaming bool, ds Datastore, requestPlugins []framework.PayloadProcessor) *Server { +func NewServer(streaming bool, ds Datastore, requestPlugins []framework.PayloadProcessor, responsePlugins []framework.PayloadProcessor) *Server { return &Server{ - streaming: streaming, - ds: ds, - requestPlugins: requestPlugins, + streaming: streaming, + ds: ds, + requestPlugins: requestPlugins, + responsePlugins: responsePlugins, } } // Server implements the Envoy external processing server. // https://www.envoyproxy.io/docs/envoy/latest/api-v3/service/ext_proc/v3/external_processor.proto type Server struct { - streaming bool - ds Datastore - requestPlugins []framework.PayloadProcessor + streaming bool + ds Datastore + requestPlugins []framework.PayloadProcessor + responsePlugins []framework.PayloadProcessor } // RequestContext stores context information during the lifetime of an HTTP request. @@ -84,7 +85,8 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error { Request: &Request{Headers: make(map[string]string)}, Response: &Response{Headers: make(map[string]string)}, } - streamedBody := &streamedBody{} + reqStreamedBody := &streamedBody{} + respStreamedBody := &streamedBody{} for { select { @@ -122,13 +124,18 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error { } else { loggerVerbose.Info("Incoming body chunk", "EoS", v.RequestBody.EndOfStream) } - responses, err = s.processRequestBody(ctx, req.GetRequestBody(), streamedBody, logger) + responses, err = s.processRequestBody(ctx, req.GetRequestBody(), reqStreamedBody) case *extProcPb.ProcessingRequest_RequestTrailers: responses, err = s.HandleRequestTrailers(req.GetRequestTrailers()) case *extProcPb.ProcessingRequest_ResponseHeaders: - responses, err = s.HandleResponseHeaders(req.GetResponseHeaders()) + responses, err = s.HandleResponseHeaders(reqCtx, req.GetResponseHeaders()) case *extProcPb.ProcessingRequest_ResponseBody: - responses, err = s.HandleResponseBody(req.GetResponseBody()) + if logger.V(logutil.DEBUG).Enabled() { + logger.V(logutil.DEBUG).Info("Incoming response body chunk", "body", string(v.ResponseBody.Body), "EoS", v.ResponseBody.EndOfStream) + } else { + loggerVerbose.Info("Incoming response body chunk", "EoS", v.ResponseBody.EndOfStream) + } + responses, err = s.processResponseBody(ctx, reqCtx, req.GetResponseBody(), respStreamedBody) default: logger.V(logutil.DEFAULT).Error(nil, "Unknown Request type", "request", v) return status.Error(codes.Unknown, "unknown request type") @@ -161,8 +168,8 @@ type streamedBody struct { body []byte } -func (s *Server) processRequestBody(ctx context.Context, body *extProcPb.HttpBody, streamedBody *streamedBody, logger logr.Logger) ([]*extProcPb.ProcessingResponse, error) { - loggerVerbose := logger.V(logutil.VERBOSE) +func (s *Server) processRequestBody(ctx context.Context, body *extProcPb.HttpBody, streamedBody *streamedBody) ([]*extProcPb.ProcessingResponse, error) { + loggerVerbose := log.FromContext(ctx).V(logutil.VERBOSE) var requestBodyBytes []byte if s.streaming { @@ -180,3 +187,23 @@ func (s *Server) processRequestBody(ctx context.Context, body *extProcPb.HttpBod return s.HandleRequestBody(ctx, requestBodyBytes) } + +func (s *Server) processResponseBody(ctx context.Context, reqCtx *RequestContext, body *extProcPb.HttpBody, streamedRespBody *streamedBody) ([]*extProcPb.ProcessingResponse, error) { + loggerVerbose := log.FromContext(ctx).V(logutil.VERBOSE) + + var responseBodyBytes []byte + if s.streaming { + streamedRespBody.body = append(streamedRespBody.body, body.Body...) + // In the stream case, we can receive multiple response bodies. + if body.EndOfStream { + loggerVerbose.Info("Flushing response stream buffer") + responseBodyBytes = streamedRespBody.body + } else { + return nil, nil + } + } else { + responseBodyBytes = body.GetBody() + } + + return s.HandleResponseBody(ctx, reqCtx, responseBodyBytes) +} diff --git a/pkg/bbr/handlers/server_test.go b/pkg/bbr/handlers/server_test.go index 4fce40d77b..b2dc7cc454 100644 --- a/pkg/bbr/handlers/server_test.go +++ b/pkg/bbr/handlers/server_test.go @@ -24,7 +24,6 @@ import ( extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" "github.com/google/go-cmp/cmp" "google.golang.org/protobuf/testing/protocmp" - "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/framework" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/observability/logging" @@ -140,10 +139,10 @@ func TestProcessRequestBody(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - srv := NewServer(tc.streaming, &fakeDatastore{}, []framework.PayloadProcessor{}) + srv := NewServer(tc.streaming, &fakeDatastore{}, []framework.PayloadProcessor{}, []framework.PayloadProcessor{}) streamedBody := &streamedBody{} for i, body := range tc.bodys { - got, err := srv.processRequestBody(context.Background(), body, streamedBody, log.FromContext(ctx)) + got, err := srv.processRequestBody(ctx, body, streamedBody) if err != nil { t.Fatalf("processRequestBody(): %v", err) } diff --git a/pkg/bbr/server/runserver.go b/pkg/bbr/server/runserver.go index b270df691f..7adead1708 100644 --- a/pkg/bbr/server/runserver.go +++ b/pkg/bbr/server/runserver.go @@ -38,11 +38,12 @@ import ( // ExtProcServerRunner provides methods to manage an external process server. type ExtProcServerRunner struct { - GrpcPort int - Datastore datastore.Datastore - SecureServing bool - Streaming bool - RequestPlugins []framework.PayloadProcessor + GrpcPort int + Datastore datastore.Datastore + SecureServing bool + Streaming bool + RequestPlugins []framework.PayloadProcessor + ResponsePlugins []framework.PayloadProcessor } func NewDefaultExtProcServerRunner(port int, streaming bool) *ExtProcServerRunner { @@ -86,7 +87,7 @@ func (r *ExtProcServerRunner) AsRunnable(logger logr.Logger) manager.Runnable { srv = grpc.NewServer() } - extProcPb.RegisterExternalProcessorServer(srv, handlers.NewServer(r.Streaming, r.Datastore, r.RequestPlugins)) + extProcPb.RegisterExternalProcessorServer(srv, handlers.NewServer(r.Streaming, r.Datastore, r.RequestPlugins, r.ResponsePlugins)) // Forward to the gRPC runnable. return runnable.GRPCServer("ext-proc", srv, r.GrpcPort).Start(ctx) diff --git a/test/integration/bbr/harness.go b/test/integration/bbr/harness.go index 381ac96e82..3ecc74e441 100644 --- a/test/integration/bbr/harness.go +++ b/test/integration/bbr/harness.go @@ -52,7 +52,6 @@ func NewBBRHarness(t *testing.T, ctx context.Context, streaming bool) *BBRHarnes require.NoError(t, err, "failed to acquire free port for BBR server") // 2. Configure BBR Server - // BBR is simpler than EPP; it doesn't need a K8s Manager. runner := runserver.NewDefaultExtProcServerRunner(port, false) runner.SecureServing = false runner.Streaming = streaming diff --git a/test/integration/bbr/hermetic_test.go b/test/integration/bbr/hermetic_test.go index 3edcb70f45..bbef23dd29 100644 --- a/test/integration/bbr/hermetic_test.go +++ b/test/integration/bbr/hermetic_test.go @@ -25,6 +25,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" "google.golang.org/protobuf/testing/protocmp" + "sigs.k8s.io/gateway-api-inference-extension/test/integration" )