Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pkg/bbr/handlers/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func TestHandleRequestHeaders(t *testing.T) {

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
server := NewServer(false, &fakeDatastore{}, []framework.PayloadProcessor{})
server := NewServer(false, &fakeDatastore{}, []framework.PayloadProcessor{}, []framework.PayloadProcessor{})
reqCtx := &RequestContext{
Request: &Request{Headers: make(map[string]string)},
Response: &Response{Headers: make(map[string]string)},
Expand Down Expand Up @@ -369,7 +369,7 @@ func TestHandleRequestBody(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
server := NewServer(test.streaming, &fakeDatastore{}, []framework.PayloadProcessor{})
server := NewServer(test.streaming, &fakeDatastore{}, []framework.PayloadProcessor{}, []framework.PayloadProcessor{})
bodyBytes, _ := json.Marshal(test.body)
resp, err := server.HandleRequestBody(ctx, bodyBytes)
if err != nil {
Expand Down Expand Up @@ -407,7 +407,7 @@ func TestHandleRequestBodyWithPluginMetrics(t *testing.T) {
ctx := logutil.NewTestLoggerIntoContext(context.Background())

noopPlugin := plugins.NewDefaultPlugin()
server := NewServer(false, &fakeDatastore{}, []framework.PayloadProcessor{noopPlugin})
server := NewServer(false, &fakeDatastore{}, []framework.PayloadProcessor{noopPlugin}, []framework.PayloadProcessor{})

bodyBytes, _ := json.Marshal(map[string]any{
"model": "bar",
Expand Down
51 changes: 47 additions & 4 deletions pkg/bbr/handlers/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,25 @@ limitations under the License.
package handlers

import (
"context"
"encoding/json"
"fmt"

eppb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
"sigs.k8s.io/controller-runtime/pkg/log"

reqenvoy "sigs.k8s.io/gateway-api-inference-extension/pkg/common/envoy/request"
)

// HandleResponseHeaders handles response headers.
func (s *Server) HandleResponseHeaders(headers *eppb.HttpHeaders) ([]*eppb.ProcessingResponse, error) {
// HandleResponseHeaders extracts response headers into reqCtx and returns
// the ext-proc header response.
func (s *Server) HandleResponseHeaders(reqCtx *RequestContext, headers *eppb.HttpHeaders) ([]*eppb.ProcessingResponse, error) {
if headers != nil && headers.Headers != nil {
for _, header := range headers.Headers.Headers {
reqCtx.Response.Headers[header.Key] = reqenvoy.GetHeaderValue(header)
}
}

return []*eppb.ProcessingResponse{
{
Response: &eppb.ProcessingResponse_ResponseHeaders{
Expand All @@ -31,8 +45,37 @@ func (s *Server) HandleResponseHeaders(headers *eppb.HttpHeaders) ([]*eppb.Proce
}, nil
}

// HandleResponseBody handles response bodies.
func (s *Server) HandleResponseBody(body *eppb.HttpBody) ([]*eppb.ProcessingResponse, error) {
// HandleResponseBody handles response bodies by executing response plugins in order.
func (s *Server) HandleResponseBody(ctx context.Context, reqCtx *RequestContext, responseBodyBytes []byte) ([]*eppb.ProcessingResponse, error) {
logger := log.FromContext(ctx)
if len(s.responsePlugins) == 0 {
return []*eppb.ProcessingResponse{
{
Response: &eppb.ProcessingResponse_ResponseBody{
ResponseBody: &eppb.BodyResponse{},
},
},
}, nil
}

var responseBody map[string]any
if err := json.Unmarshal(responseBodyBytes, &responseBody); err != nil {
logger.Error(err, "Failed to parse response body as JSON, skipping response plugins")
return []*eppb.ProcessingResponse{
{
Response: &eppb.ProcessingResponse_ResponseBody{
ResponseBody: &eppb.BodyResponse{},
},
},
}, nil
}

if err := s.executePlugins(ctx, reqCtx.Response.Headers, responseBody, s.responsePlugins); err != nil {
logger.Error(err, "Response plugin execution failed")
return nil, fmt.Errorf("failed to execute response plugins - %w", err)
}

// TODO: apply mutated body/headers to the response (see #2449 follow-ups).
return []*eppb.ProcessingResponse{
{
Response: &eppb.ProcessingResponse_ResponseBody{
Expand Down
233 changes: 233 additions & 0 deletions pkg/bbr/handlers/response_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
/*
Copyright 2025 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package handlers

import (
"context"
"errors"
"testing"

extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
"github.com/google/go-cmp/cmp"
"google.golang.org/protobuf/testing/protocmp"

"sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/framework"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/observability/logging"
epp "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin"
)

const testPluginValue = "done"

// fakeResponsePlugin implements framework.PayloadProcessor for testing response plugin execution.
type fakeResponsePlugin struct {
name string
mutateFn func(ctx context.Context, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error)
}

func (p *fakeResponsePlugin) TypedName() epp.TypedName {
return epp.TypedName{Type: "fake", Name: p.name}
}

func (p *fakeResponsePlugin) Execute(ctx context.Context, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) {
return p.mutateFn(ctx, headers, body)
}

var _ framework.PayloadProcessor = &fakeResponsePlugin{}

func newTestRequestContext() *RequestContext {
return &RequestContext{
Request: &Request{Headers: make(map[string]string)},
Response: &Response{Headers: make(map[string]string)},
}
}

func TestHandleResponseBody_NoPlugins(t *testing.T) {
ctx := logutil.NewTestLoggerIntoContext(context.Background())

server := NewServer(false, &fakeDatastore{}, []framework.PayloadProcessor{}, []framework.PayloadProcessor{})
responseBody := []byte(`{"choices":[{"text":"Hello!"}]}`)
resp, err := server.HandleResponseBody(ctx, newTestRequestContext(), responseBody)
if err != nil {
t.Fatalf("HandleResponseBody returned unexpected error: %v", err)
}

want := []*extProcPb.ProcessingResponse{
{
Response: &extProcPb.ProcessingResponse_ResponseBody{
ResponseBody: &extProcPb.BodyResponse{},
},
},
}

if diff := cmp.Diff(want, resp, protocmp.Transform()); diff != "" {
t.Errorf("HandleResponseBody returned unexpected response, diff(-want, +got): %v", diff)
}
}

func TestHandleResponseBody_SinglePlugin(t *testing.T) {
ctx := logutil.NewTestLoggerIntoContext(context.Background())

mutatePlugin := &fakeResponsePlugin{
name: "mutator",
mutateFn: func(_ context.Context, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) {
body["mutated"] = true
return headers, body, nil
},
}

server := NewServer(false, &fakeDatastore{}, []framework.PayloadProcessor{}, []framework.PayloadProcessor{mutatePlugin})
responseBody := []byte(`{"choices":[{"text":"Hello!"}]}`)
resp, err := server.HandleResponseBody(ctx, newTestRequestContext(), responseBody)
if err != nil {
t.Fatalf("HandleResponseBody returned unexpected error: %v", err)
}

// Plugins are executed but mutations are not yet applied to the response.
want := []*extProcPb.ProcessingResponse{
{
Response: &extProcPb.ProcessingResponse_ResponseBody{
ResponseBody: &extProcPb.BodyResponse{},
},
},
}
if diff := cmp.Diff(want, resp, protocmp.Transform()); diff != "" {
t.Errorf("HandleResponseBody returned unexpected response, diff(-want, +got): %v", diff)
}
}

func TestHandleResponseBody_MultiplePlugins(t *testing.T) {
ctx := logutil.NewTestLoggerIntoContext(context.Background())

plugin1 := &fakeResponsePlugin{
name: "plugin1",
mutateFn: func(_ context.Context, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) {
body["p1"] = testPluginValue
return headers, body, nil
},
}
plugin2 := &fakeResponsePlugin{
name: "plugin2",
mutateFn: func(_ context.Context, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) {
body["p2"] = testPluginValue
return headers, body, nil
},
}

server := NewServer(false, &fakeDatastore{}, []framework.PayloadProcessor{}, []framework.PayloadProcessor{plugin1, plugin2})
responseBody := []byte(`{"original":true}`)
resp, err := server.HandleResponseBody(ctx, newTestRequestContext(), responseBody)
if err != nil {
t.Fatalf("HandleResponseBody returned unexpected error: %v", err)
}

// Plugins are executed but mutations are not yet applied to the response.
want := []*extProcPb.ProcessingResponse{
{
Response: &extProcPb.ProcessingResponse_ResponseBody{
ResponseBody: &extProcPb.BodyResponse{},
},
},
}
if diff := cmp.Diff(want, resp, protocmp.Transform()); diff != "" {
t.Errorf("HandleResponseBody returned unexpected response, diff(-want, +got): %v", diff)
}
}

func TestHandleResponseBody_PluginError(t *testing.T) {
ctx := logutil.NewTestLoggerIntoContext(context.Background())

failingPlugin := &fakeResponsePlugin{
name: "failing",
mutateFn: func(_ context.Context, _ map[string]string, _ map[string]any) (map[string]string, map[string]any, error) {
return nil, nil, errors.New("failed to execute plugin")
},
}
Comment on lines +153 to +158
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

error in the plugin interface is not intended to be guardrail. that's confusing.
the intention in the error return value was for a plugin that fails to execute.
these are two different things - plugin execution failed and plugin executed successfully and should block the request.
guardrail is still not to be used. let's change the text here to avoid confusion for a reader

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems there was some mixed up in the tests with other PR's


server := NewServer(false, &fakeDatastore{}, []framework.PayloadProcessor{}, []framework.PayloadProcessor{failingPlugin})
responseBody := []byte(`{"choices":[{"text":"some response"}]}`)
_, err := server.HandleResponseBody(ctx, newTestRequestContext(), responseBody)
if err == nil {
t.Fatal("HandleResponseBody should have returned an error")
}

if got := err.Error(); got == "" {
t.Error("Expected non-empty error message")
}
}

func TestHandleResponseBody_StreamingWithPlugin(t *testing.T) {
ctx := logutil.NewTestLoggerIntoContext(context.Background())

noopPlugin := &fakeResponsePlugin{
name: "noop",
mutateFn: func(_ context.Context, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) {
return headers, body, nil
},
}
Comment on lines +175 to +180
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you explain what is this text with noop plugin? why is it needed?

Copy link
Copy Markdown
Contributor Author

@abdallahsamabd abdallahsamabd Mar 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The noop plugin is needed to bypass the len(s.responsePlugins) == 0 early return in HandleResponseBody. Without it, the test would only exercise the no-plugins shortcut path. The noop ensures we go through the full code path (JSON unmarshal → plugin execution → response construction) in streaming mode, validating it completes without errors.

"Noop" stands for no operation — it's a plugin that does nothing. It receives the headers and body, and returns them unchanged. It exists purely to satisfy the "has plugins" condition so the code doesn't take the early-return shortcut.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got it. maybe it would be good to reuse the mutating plugin, so we can check that plugin is actually executed correctly in streaming mode.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since mutations are not yet applied to the response in this PR, we can't verify the plugin's effect from the returned response. Once we add mutation in the follow-up we'll update this test to use a mutating plugin


server := NewServer(true, &fakeDatastore{}, []framework.PayloadProcessor{}, []framework.PayloadProcessor{noopPlugin})
responseBody := []byte(`{"choices":[{"text":"Hello!"}]}`)
resp, err := server.HandleResponseBody(ctx, newTestRequestContext(), responseBody)
if err != nil {
t.Fatalf("HandleResponseBody returned unexpected error: %v", err)
}

// Plugins are executed but mutations are not yet applied to the response.
want := []*extProcPb.ProcessingResponse{
{
Response: &extProcPb.ProcessingResponse_ResponseBody{
ResponseBody: &extProcPb.BodyResponse{},
},
},
}
if diff := cmp.Diff(want, resp, protocmp.Transform()); diff != "" {
t.Errorf("HandleResponseBody returned unexpected response, diff(-want, +got): %v", diff)
}
}

func TestProcessResponseBody_Streaming(t *testing.T) {
ctx := logutil.NewTestLoggerIntoContext(context.Background())

server := NewServer(true, &fakeDatastore{}, []framework.PayloadProcessor{}, []framework.PayloadProcessor{})

chunk1 := &extProcPb.HttpBody{
Body: []byte(`{"choices":[{"te`),
}
chunk2 := &extProcPb.HttpBody{
Body: []byte(`xt":"Hello!"}]}`),
EndOfStream: true,
}

reqCtx := newTestRequestContext()
respStreamedBody := &streamedBody{}

resp1, err := server.processResponseBody(ctx, reqCtx, chunk1, respStreamedBody)
if err != nil {
t.Fatalf("processResponseBody chunk1 returned unexpected error: %v", err)
}
if resp1 != nil {
t.Fatalf("processResponseBody chunk1 should return nil while buffering, got: %v", resp1)
}

resp2, err := server.processResponseBody(ctx, reqCtx, chunk2, respStreamedBody)
if err != nil {
t.Fatalf("processResponseBody chunk2 returned unexpected error: %v", err)
}
if resp2 == nil {
t.Fatal("processResponseBody chunk2 should return a response on EoS")
}
}
Loading