Skip to content

Commit e9b5acf

Browse files
committed
feat: enable streaming usage metrics for OpenAI-compatible providers
Inject `stream_options={"include_usage": True}` when a request is streaming and OpenTelemetry tracing is active (i.e. the current span is recording). When telemetry is active, it always overrides any caller-supplied `include_usage` preference to ensure complete and consistent observability metrics.

Changes:

- Add `get_stream_options_for_telemetry()` utility in openai_compat.py
- Integrate telemetry-driven stream_options injection in OpenAIMixin (benefits OpenAI, Bedrock, Runpod, vLLM, TGI, and 12+ other providers)
- Integrate telemetry-driven stream_options injection in LiteLLMOpenAIMixin (benefits WatsonX and other LiteLLM-based providers)
- Add `_litellm_extra_request_params()` hook for provider-specific params
- Remove duplicated stream_options logic from Bedrock, Runpod, WatsonX
- Add comprehensive unit tests for injection behavior

Fixes #3981
1 parent 7006630 commit e9b5acf

9 files changed

Lines changed: 476 additions & 130 deletions

File tree

src/llama_stack/providers/remote/inference/bedrock/bedrock.py

Lines changed: 1 addition & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -81,14 +81,7 @@ async def openai_chat_completion(
8181
self,
8282
params: OpenAIChatCompletionRequestWithExtraBody,
8383
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
84-
"""Override to enable streaming usage metrics and handle authentication errors."""
85-
# Enable streaming usage metrics when telemetry is active
86-
if params.stream:
87-
if params.stream_options is None:
88-
params.stream_options = {"include_usage": True}
89-
elif "include_usage" not in params.stream_options:
90-
params.stream_options = {**params.stream_options, "include_usage": True}
91-
84+
"""Override to handle authentication errors and null responses."""
9285
try:
9386
logger.debug(f"Calling Bedrock OpenAI API with model={params.model}, stream={params.stream}")
9487
result = await super().openai_chat_completion(params=params)

src/llama_stack/providers/remote/inference/runpod/runpod.py

Lines changed: 0 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -4,14 +4,7 @@
44
# This source code is licensed under the terms described in the LICENSE file in
55
# the root directory of this source tree.
66

7-
from collections.abc import AsyncIterator
8-
97
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
10-
from llama_stack_api import (
11-
OpenAIChatCompletion,
12-
OpenAIChatCompletionChunk,
13-
OpenAIChatCompletionRequestWithExtraBody,
14-
)
158

169
from .config import RunpodImplConfig
1710

@@ -29,15 +22,3 @@ class RunpodInferenceAdapter(OpenAIMixin):
2922
def get_base_url(self) -> str:
3023
"""Get base URL for OpenAI client."""
3124
return str(self.config.base_url)
32-
33-
async def openai_chat_completion(
34-
self,
35-
params: OpenAIChatCompletionRequestWithExtraBody,
36-
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
37-
"""Override to add RunPod-specific stream_options requirement."""
38-
params = params.model_copy()
39-
40-
if params.stream and not params.stream_options:
41-
params.stream_options = {"include_usage": True}
42-
43-
return await super().openai_chat_completion(params)

src/llama_stack/providers/remote/inference/watsonx/watsonx.py

Lines changed: 12 additions & 90 deletions
Original file line number | Diff line number | Diff line change
@@ -13,16 +13,13 @@
1313
from llama_stack.log import get_logger
1414
from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
1515
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
16-
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
17-
from llama_stack.providers.utils.inference.stream_utils import wrap_async_stream
1816
from llama_stack_api import (
1917
Model,
2018
ModelType,
2119
OpenAIChatCompletion,
2220
OpenAIChatCompletionChunk,
2321
OpenAIChatCompletionRequestWithExtraBody,
2422
OpenAIChatCompletionUsage,
25-
OpenAICompletion,
2623
OpenAICompletionRequestWithExtraBody,
2724
OpenAIEmbeddingsRequestWithExtraBody,
2825
OpenAIEmbeddingsResponse,
@@ -48,57 +45,25 @@ def __init__(self, config: WatsonXConfig):
4845
openai_compat_api_base=self.get_base_url(),
4946
)
5047

48+
def _litellm_extra_request_params(
49+
self,
50+
params: OpenAIChatCompletionRequestWithExtraBody | OpenAICompletionRequestWithExtraBody,
51+
) -> dict[str, Any]:
52+
# These are watsonx-specific parameters used by LiteLLM.
53+
return {"timeout": self.config.timeout, "project_id": self.config.project_id}
54+
5155
async def openai_chat_completion(
5256
self,
5357
params: OpenAIChatCompletionRequestWithExtraBody,
5458
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
5559
"""
56-
Override parent method to add timeout and inject usage object when missing.
60+
Override parent method to inject usage object when missing.
61+
5762
This works around a LiteLLM defect where usage block is sometimes dropped.
63+
Note: request parameter construction (including telemetry-driven stream_options injection)
64+
is handled by LiteLLMOpenAIMixin via _litellm_extra_request_params().
5865
"""
59-
60-
# Add usage tracking for streaming when telemetry is active
61-
stream_options = params.stream_options
62-
if params.stream:
63-
if stream_options is None:
64-
stream_options = {"include_usage": True}
65-
elif "include_usage" not in stream_options:
66-
stream_options = {**stream_options, "include_usage": True}
67-
68-
model_obj = await self.model_store.get_model(params.model)
69-
70-
request_params = await prepare_openai_completion_params(
71-
model=self.get_litellm_model_name(model_obj.provider_resource_id),
72-
messages=params.messages,
73-
frequency_penalty=params.frequency_penalty,
74-
function_call=params.function_call,
75-
functions=params.functions,
76-
logit_bias=params.logit_bias,
77-
logprobs=params.logprobs,
78-
max_completion_tokens=params.max_completion_tokens,
79-
max_tokens=params.max_tokens,
80-
n=params.n,
81-
parallel_tool_calls=params.parallel_tool_calls,
82-
presence_penalty=params.presence_penalty,
83-
response_format=params.response_format,
84-
seed=params.seed,
85-
stop=params.stop,
86-
stream=params.stream,
87-
stream_options=stream_options,
88-
temperature=params.temperature,
89-
tool_choice=params.tool_choice,
90-
tools=params.tools,
91-
top_logprobs=params.top_logprobs,
92-
top_p=params.top_p,
93-
user=params.user,
94-
api_key=self.get_api_key(),
95-
api_base=self.api_base,
96-
# These are watsonx-specific parameters
97-
timeout=self.config.timeout,
98-
project_id=self.config.project_id,
99-
)
100-
101-
result = await litellm.acompletion(**request_params)
66+
result = await super().openai_chat_completion(params)
10267

10368
# If not streaming, check and inject usage if missing
10469
if not params.stream:
@@ -175,49 +140,6 @@ async def _normalize_stream(
175140
logger.error(f"Error normalizing stream: {e}", exc_info=True)
176141
raise
177142

178-
async def openai_completion(
179-
self,
180-
params: OpenAICompletionRequestWithExtraBody,
181-
) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
182-
"""
183-
Override parent method to add watsonx-specific parameters.
184-
"""
185-
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
186-
187-
model_obj = await self.model_store.get_model(params.model)
188-
189-
request_params = await prepare_openai_completion_params(
190-
model=self.get_litellm_model_name(model_obj.provider_resource_id),
191-
prompt=params.prompt,
192-
best_of=params.best_of,
193-
echo=params.echo,
194-
frequency_penalty=params.frequency_penalty,
195-
logit_bias=params.logit_bias,
196-
logprobs=params.logprobs,
197-
max_tokens=params.max_tokens,
198-
n=params.n,
199-
presence_penalty=params.presence_penalty,
200-
seed=params.seed,
201-
stop=params.stop,
202-
stream=params.stream,
203-
stream_options=params.stream_options,
204-
temperature=params.temperature,
205-
top_p=params.top_p,
206-
user=params.user,
207-
suffix=params.suffix,
208-
api_key=self.get_api_key(),
209-
api_base=self.api_base,
210-
# These are watsonx-specific parameters
211-
timeout=self.config.timeout,
212-
project_id=self.config.project_id,
213-
)
214-
result = await litellm.atext_completion(**request_params)
215-
216-
if params.stream:
217-
return wrap_async_stream(result) # type: ignore[arg-type] # LiteLLM streaming types
218-
219-
return result # type: ignore[return-value] # external lib lacks type stubs
220-
221143
async def openai_embeddings(
222144
self,
223145
params: OpenAIEmbeddingsRequestWithExtraBody,

src/llama_stack/providers/utils/inference/litellm_openai_mixin.py

Lines changed: 24 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -7,13 +7,15 @@
77
import base64
88
import struct
99
from collections.abc import AsyncIterator
10+
from typing import Any
1011

1112
import litellm
1213

1314
from llama_stack.core.request_headers import NeedsRequestProviderData
1415
from llama_stack.log import get_logger
1516
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry
1617
from llama_stack.providers.utils.inference.openai_compat import (
18+
get_stream_options_for_telemetry,
1719
prepare_openai_completion_params,
1820
)
1921
from llama_stack.providers.utils.inference.stream_utils import wrap_async_stream
@@ -180,6 +182,9 @@ async def openai_completion(
180182
self,
181183
params: OpenAICompletionRequestWithExtraBody,
182184
) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
185+
# Inject stream_options when streaming and telemetry is active
186+
stream_options = get_stream_options_for_telemetry(params.stream_options, params.stream)
187+
183188
if not self.model_store:
184189
raise ValueError("Model store is not initialized")
185190

@@ -202,13 +207,14 @@ async def openai_completion(
202207
seed=params.seed,
203208
stop=params.stop,
204209
stream=params.stream,
205-
stream_options=params.stream_options,
210+
stream_options=stream_options,
206211
temperature=params.temperature,
207212
top_p=params.top_p,
208213
user=params.user,
209214
suffix=params.suffix,
210215
api_key=self.get_api_key(),
211216
api_base=self.api_base,
217+
**self._litellm_extra_request_params(params),
212218
)
213219
# LiteLLM returns compatible type but mypy can't verify external library
214220
result = await litellm.atext_completion(**request_params)
@@ -222,14 +228,8 @@ async def openai_chat_completion(
222228
self,
223229
params: OpenAIChatCompletionRequestWithExtraBody,
224230
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
225-
# Add usage tracking for streaming when telemetry is active
226-
227-
stream_options = params.stream_options
228-
if params.stream:
229-
if stream_options is None:
230-
stream_options = {"include_usage": True}
231-
elif "include_usage" not in stream_options:
232-
stream_options = {**stream_options, "include_usage": True}
231+
# Inject stream_options when streaming and telemetry is active
232+
stream_options = get_stream_options_for_telemetry(params.stream_options, params.stream)
233233

234234
if not self.model_store:
235235
raise ValueError("Model store is not initialized")
@@ -265,6 +265,7 @@ async def openai_chat_completion(
265265
user=params.user,
266266
api_key=self.get_api_key(),
267267
api_base=self.api_base,
268+
**self._litellm_extra_request_params(params),
268269
)
269270
# LiteLLM returns compatible type but mypy can't verify external library
270271
result = await litellm.acompletion(**request_params)
@@ -288,6 +289,20 @@ async def check_model_availability(self, model: str) -> bool:
288289

289290
return model in litellm.models_by_provider[self.litellm_provider_name]
290291

292+
def _litellm_extra_request_params(
293+
self,
294+
params: OpenAIChatCompletionRequestWithExtraBody | OpenAICompletionRequestWithExtraBody,
295+
) -> dict[str, Any]:
296+
"""
297+
Provider hook for extra LiteLLM/OpenAI-compat request params.
298+
299+
This is intentionally a narrow hook so provider adapters (e.g. WatsonX)
300+
can add provider-specific kwargs (timeouts, project IDs, etc.) while the
301+
mixin remains the single source of truth for telemetry-driven
302+
stream_options injection.
303+
"""
304+
return {}
305+
291306

292307
def b64_encode_openai_embeddings_response(
293308
response_data: list[dict], encoding_format: str | None = "float"

src/llama_stack/providers/utils/inference/openai_compat.py

Lines changed: 25 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -235,3 +235,28 @@ def prepare_openai_embeddings_params(
235235
params["user"] = user
236236

237237
return params
238+
239+
240+
def get_stream_options_for_telemetry(
241+
stream_options: dict[str, Any] | None,
242+
is_streaming: bool,
243+
) -> dict[str, Any] | None:
244+
"""
245+
Inject stream_options when streaming and telemetry is active.
246+
247+
Active telemetry takes precedence over caller preference to ensure
248+
complete and consistent observability metrics.
249+
"""
250+
if not is_streaming:
251+
return stream_options
252+
253+
from opentelemetry import trace
254+
255+
span = trace.get_current_span()
256+
if not span or not span.is_recording():
257+
return stream_options
258+
259+
if stream_options is None:
260+
return {"include_usage": True}
261+
262+
return {**stream_options, "include_usage": True}

src/llama_stack/providers/utils/inference/openai_mixin.py

Lines changed: 12 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,10 @@
1616
from llama_stack.core.request_headers import NeedsRequestProviderData
1717
from llama_stack.log import get_logger
1818
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
19-
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
19+
from llama_stack.providers.utils.inference.openai_compat import (
20+
get_stream_options_for_telemetry,
21+
prepare_openai_completion_params,
22+
)
2023
from llama_stack.providers.utils.inference.prompt_adapter import localize_image_content
2124
from llama_stack_api import (
2225
Model,
@@ -270,6 +273,9 @@ async def openai_completion(
270273
"""
271274
Direct OpenAI completion API call.
272275
"""
276+
# Inject stream_options when streaming and telemetry is active
277+
stream_options = get_stream_options_for_telemetry(params.stream_options, params.stream or False)
278+
273279
provider_model_id = await self._get_provider_model_id(params.model)
274280
self._validate_model_allowed(provider_model_id)
275281

@@ -287,7 +293,7 @@ async def openai_completion(
287293
seed=params.seed,
288294
stop=params.stop,
289295
stream=params.stream,
290-
stream_options=params.stream_options,
296+
stream_options=stream_options,
291297
temperature=params.temperature,
292298
top_p=params.top_p,
293299
user=params.user,
@@ -306,6 +312,9 @@ async def openai_chat_completion(
306312
"""
307313
Direct OpenAI chat completion API call.
308314
"""
315+
# Inject stream_options when streaming and telemetry is active
316+
stream_options = get_stream_options_for_telemetry(params.stream_options, params.stream or False)
317+
309318
provider_model_id = await self._get_provider_model_id(params.model)
310319
self._validate_model_allowed(provider_model_id)
311320

@@ -346,7 +355,7 @@ async def _localize_image_url(m: OpenAIMessageParam) -> OpenAIMessageParam:
346355
seed=params.seed,
347356
stop=params.stop,
348357
stream=params.stream,
349-
stream_options=params.stream_options,
358+
stream_options=stream_options,
350359
temperature=params.temperature,
351360
tool_choice=params.tool_choice,
352361
tools=params.tools,

0 commit comments

Comments
 (0)