@@ -22,7 +22,10 @@ import (
2222 "strings"
2323
2424 configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
25+ filterPb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/ext_proc/v3"
2526 extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
27+ "github.com/go-logr/logr"
28+
2629 "sigs.k8s.io/controller-runtime/pkg/log"
2730
2831 "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
@@ -60,7 +63,7 @@ func (s *StreamingServer) HandleResponseBody(ctx context.Context, reqCtx *Reques
6063 // will add the processing for streaming case.
6164 reqCtx .ResponseComplete = true
6265
63- reqCtx .respBodyResp = generateResponseBodyResponses (responseBytes , true )
66+ reqCtx .respBodyResp = generateResponseBodyResponses (responseBytes , true , reqCtx , logger )
6467 return reqCtx , nil
6568}
6669
@@ -75,12 +78,11 @@ func (s *StreamingServer) HandleResponseBodyModelStreaming(ctx context.Context,
7578 s .director .HandleResponseBodyChunk (ctx , reqCtx )
7679}
7780
78-
7981// HandleResponseTrailers handles the response-trailers event for a request by delegating to the director's HandleResponseTrailers and returning its result unchanged.
8082func (s * StreamingServer ) HandleResponseTrailers (
8183 ctx context.Context ,
8284 reqCtx * RequestContext ,
83- ) (* RequestContext , error ) {
85+ ) (* RequestContext , error ) {
8486
8587 return s .director .HandleResponseTrailers (ctx , reqCtx )
8688}
@@ -110,6 +112,9 @@ func (s *StreamingServer) generateResponseHeaderResponse(reqCtx *RequestContext)
110112 },
111113 },
112114 },
115+ ModeOverride : & filterPb.ProcessingMode {
116+ ResponseTrailerMode : filterPb .ProcessingMode_SEND ,
117+ },
113118 }
114119}
115120
@@ -118,29 +123,95 @@ func (s *StreamingServer) generateResponseTrailerResponse(reqCtx *RequestContext
118123 return & extProcPb.ProcessingResponse {
119124 Response : & extProcPb.ProcessingResponse_ResponseTrailers {
120125 ResponseTrailers : & extProcPb.TrailersResponse {
121- HeaderMutation : & extProcPb.HeaderMutation {
122- // Correct field or remove if unnecessary
123- SetHeaders : s .generateResponseTrailers (reqCtx ),
124- },
126+ HeaderMutation : & extProcPb.HeaderMutation {
127+ // Correct field or remove if unnecessary
128+ SetHeaders : s .generateResponseTrailers (reqCtx ),
125129 },
126130 },
127- }
131+ },
128132 }
133+ }
134+
135+ func generateResponseBodyResponses (
136+ responseBodyBytes []byte ,
137+ setEoS bool ,
138+ reqCtx * RequestContext ,
139+ logger logr.Logger ,
140+ ) []* extProcPb.ProcessingResponse {
141+ if reqCtx != nil && reqCtx .ModelServerStreaming {
142+
143+ raw := string (responseBodyBytes )
144+ events := strings .Split (raw , "\n \n " )
129145
130- func generateResponseBodyResponses (responseBodyBytes []byte , setEoS bool ) []* extProcPb.ProcessingResponse {
131- commonResponses := buildCommonResponses (responseBodyBytes , bodyByteLimit , setEoS )
132- responses := []* extProcPb.ProcessingResponse {}
133- for _ , commonResp := range commonResponses {
134- resp := & extProcPb.ProcessingResponse {
135- Response : & extProcPb.ProcessingResponse_ResponseBody {
136- ResponseBody : & extProcPb.BodyResponse {
137- Response : commonResp ,
146+ var rebuilt strings.Builder
147+ for _ , ev := range events {
148+ if ! strings .HasPrefix (ev , "data: " ) {
149+ continue
150+ }
151+ payload := strings .TrimPrefix (ev , "data: " )
152+ if payload == "[DONE]" {
153+ rebuilt .WriteString ("data: [DONE]\n \n " )
154+ continue
155+ }
156+
157+ // Try to unmarshal only the JSON
158+ var obj map [string ]interface {}
159+ if err := json .Unmarshal ([]byte (payload ), & obj ); err != nil {
160+ logger .Error (err , "failed to unmarshal SSE payload" , "payload" , payload )
161+ } else {
162+ if usage , ok := obj ["usage" ].(map [string ]interface {}); ok && usage != nil {
163+ usage ["ttft_ms" ] = reqCtx .TTFT
164+ usage ["predicted_ttft_ms" ] = reqCtx .PredictedTTFT
165+ usage ["tpot_observations_ms" ] = reqCtx .TPOTObservations
166+ usage ["predicted_tpot_observations_ms" ] = reqCtx .PredictedTPOTObservations
167+ usage ["avg_tpot_ms" ] = reqCtx .AvgTPOT
168+ usage ["avg_predicted_tpot_ms" ] = reqCtx .AvgPredictedTPOT
169+ }
170+ if mod , err := json .Marshal (obj ); err != nil {
171+ logger .Error (err , "failed to re-marshal modified JSON" , "obj" , obj )
172+ } else {
173+ payload = string (mod )
174+ }
175+ }
176+
177+ // Re-attach SSE prefix
178+ rebuilt .WriteString ("data: " )
179+ rebuilt .WriteString (payload )
180+ rebuilt .WriteString ("\n \n " )
181+ }
182+
183+ // Feed into your existing chunker
184+ modified := []byte (rebuilt .String ())
185+ commonResponses := buildCommonResponses (modified , bodyByteLimit , setEoS )
186+
187+ // Wrap as ProcessingResponses
188+ out := make ([]* extProcPb.ProcessingResponse , 0 , len (commonResponses ))
189+ for _ , cr := range commonResponses {
190+ out = append (out , & extProcPb.ProcessingResponse {
191+ Response : & extProcPb.ProcessingResponse_ResponseBody {
192+ ResponseBody : & extProcPb.BodyResponse {
193+ Response : cr ,
194+ },
138195 },
139- },
196+ })
140197 }
141- responses = append (responses , resp )
198+ return out
199+ } else {
200+ commonResponses := buildCommonResponses (responseBodyBytes , bodyByteLimit , setEoS )
201+ responses := []* extProcPb.ProcessingResponse {}
202+ for _ , commonResp := range commonResponses {
203+ resp := & extProcPb.ProcessingResponse {
204+ Response : & extProcPb.ProcessingResponse_ResponseBody {
205+ ResponseBody : & extProcPb.BodyResponse {
206+ Response : commonResp ,
207+ },
208+ },
209+ }
210+ responses = append (responses , resp )
211+ }
212+ return responses
142213 }
143- return responses
214+
144215}
145216
146217func (s * StreamingServer ) generateResponseHeaders (reqCtx * RequestContext ) []* configPb.HeaderValueOption {
@@ -180,7 +251,7 @@ func (s *StreamingServer) generateResponseTrailers(reqCtx *RequestContext) []*co
180251 }
181252
182253 // include all trailers
183- for key , value := range reqCtx .Response .Trailers {
254+ for key , value := range reqCtx .Response .Trailers {
184255 trailers = append (trailers , & configPb.HeaderValueOption {
185256 Header : & configPb.HeaderValue {
186257 Key : key ,
0 commit comments