Skip to content

Commit d1e3916

Browse files
committed
feat(azure-stt): add cancellation tracing and session guards
1 parent cae27a5 commit d1e3916

4 files changed

Lines changed: 234 additions & 35 deletions

File tree

changelog/3884.changed.md

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1 @@
1+
- Added Azure STT cancellation tracing attributes and session termination guards, so that canceled recognition sessions surface structured observability data and no longer accept audio as if the session were still healthy.

src/pipecat/services/azure/stt.py

Lines changed: 129 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@
3131
from pipecat.services.stt_service import STTService
3232
from pipecat.transcriptions.language import Language
3333
from pipecat.utils.time import time_now_iso8601
34-
from pipecat.utils.tracing.service_decorators import traced_stt
34+
from pipecat.utils.tracing.service_decorators import trace_stt_cancellation, traced_stt
3535

3636
try:
3737
from azure.cognitiveservices.speech import (
@@ -155,6 +155,10 @@ def __init__(
155155

156156
self._audio_stream = None
157157
self._speech_recognizer = None
158+
self._audio_sent = False
159+
self._recognition_active = False
160+
self._recognition_terminated = False
161+
self._shutdown_requested = False
158162

159163
def can_generate_metrics(self) -> bool:
160164
"""Check if this service can generate performance metrics.
@@ -204,7 +208,12 @@ async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
204208
try:
205209
await self.start_processing_metrics()
206210
if self._audio_stream:
211+
if self._recognition_terminated and not self._shutdown_requested:
212+
logger.warning("Azure STT recognition terminated, dropping audio chunk")
213+
yield None
214+
return
207215
self._audio_stream.write(audio)
216+
self._audio_sent = True
208217
yield None
209218
except Exception as e:
210219
yield ErrorFrame(error=f"Unknown error occurred: {e}")
@@ -241,6 +250,11 @@ async def _connect(self):
241250
if self._audio_stream:
242251
return
243252

253+
self._audio_sent = False
254+
self._recognition_active = False
255+
self._recognition_terminated = False
256+
self._shutdown_requested = False
257+
244258
try:
245259
stream_format = AudioStreamFormat(samples_per_second=self.sample_rate, channels=1)
246260
self._audio_stream = PushAudioInputStream(stream_format)
@@ -263,6 +277,36 @@ async def _connect(self):
263277
error_msg=f"Uncaught exception during initialization: {e}", exception=e
264278
)
265279

280+
async def stop(self, frame: EndFrame):
281+
"""Stop the speech recognition service.
282+
283+
Cleanly shuts down the Azure speech recognizer and closes audio streams.
284+
285+
Args:
286+
frame: Frame indicating the end of processing.
287+
"""
288+
await super().stop(frame)
289+
290+
self._shutdown_requested = True
291+
self._recognition_active = False
292+
self._recognition_terminated = True
293+
await self._disconnect()
294+
295+
async def cancel(self, frame: CancelFrame):
296+
"""Cancel the speech recognition service.
297+
298+
Immediately stops recognition and closes resources.
299+
300+
Args:
301+
frame: Frame indicating cancellation.
302+
"""
303+
await super().cancel(frame)
304+
305+
self._shutdown_requested = True
306+
self._recognition_active = False
307+
self._recognition_terminated = True
308+
await self._disconnect()
309+
266310
async def _disconnect(self):
267311
"""Stop recognition and close audio streams."""
268312
if self._speech_recognizer:
@@ -280,6 +324,25 @@ async def _handle_transcription(
280324
"""Handle a transcription result with tracing."""
281325
await self.stop_processing_metrics()
282326

327+
async def _trace_cancellation(
328+
self,
329+
*,
330+
reason: str,
331+
code: str,
332+
recoverable: bool,
333+
phase: str,
334+
):
335+
"""Record a trace span for a canceled Azure STT recognition."""
336+
trace_stt_cancellation(
337+
self,
338+
error_type="azure.stt.canceled",
339+
cancel_reason=reason,
340+
cancel_code=code,
341+
recoverable=recoverable,
342+
phase=phase,
343+
region=self._settings.region if isinstance(self._settings.region, str) else None,
344+
)
345+
283346
def _on_handle_recognized(self, event):
284347
if event.result.reason == ResultReason.RecognizedSpeech and len(event.result.text) > 0:
285348
language = getattr(event.result, "language", None) or self._settings.language
@@ -309,30 +372,87 @@ def _on_handle_recognizing(self, event):
309372

310373
def _on_handle_canceled(self, event):
311374
details = getattr(event, "cancellation_details", None)
312-
reason = getattr(details, "reason", "UNKNOWN")
313-
code = getattr(details, "code", "UNKNOWN")
375+
reason = self._normalize_cancellation_value(getattr(details, "reason", "UNKNOWN"))
376+
code = self._normalize_cancellation_value(getattr(details, "code", "UNKNOWN"))
314377
error_details = getattr(details, "error_details", "")
378+
phase = self._get_cancellation_phase()
379+
recoverable = self._is_cancellation_recoverable(reason, code)
380+
381+
self._recognition_active = False
382+
self._recognition_terminated = True
315383

316384
logger.error(
317-
"Azure STT recognition canceled: reason={}, code={}, details={}",
385+
"Azure STT recognition canceled: reason={}, code={}, phase={}, recoverable={}, details={}",
318386
reason,
319387
code,
388+
phase,
389+
recoverable,
320390
error_details,
321391
)
322392

323-
error_message = f"Azure STT recognition canceled: {code} - {error_details}"
393+
asyncio.run_coroutine_threadsafe(
394+
self._trace_cancellation(
395+
reason=reason,
396+
code=code,
397+
recoverable=recoverable,
398+
phase=phase,
399+
),
400+
self.get_event_loop(),
401+
)
402+
403+
error_message = f"Azure STT recognition canceled: {reason} ({code})"
324404
asyncio.run_coroutine_threadsafe(
325405
self.push_error(error_msg=error_message), self.get_event_loop()
326406
)
327407

328408
def _on_handle_session_started(self, event):
409+
self._recognition_active = True
410+
self._recognition_terminated = False
329411
logger.info(
330412
"Azure STT session started: session_id={}",
331413
getattr(event, "session_id", "unknown"),
332414
)
333415

334416
def _on_handle_session_stopped(self, event):
335-
logger.warning(
336-
"Azure STT session stopped: session_id={}",
337-
getattr(event, "session_id", "unknown"),
338-
)
417+
self._recognition_active = False
418+
self._recognition_terminated = True
419+
if self._shutdown_requested:
420+
logger.info(
421+
"Azure STT session stopped during shutdown: session_id={}",
422+
getattr(event, "session_id", "unknown"),
423+
)
424+
else:
425+
logger.warning(
426+
"Azure STT session stopped: session_id={}",
427+
getattr(event, "session_id", "unknown"),
428+
)
429+
430+
@staticmethod
431+
def _normalize_cancellation_value(value: Any) -> str:
432+
normalized = getattr(value, "name", None)
433+
if normalized:
434+
return normalized
435+
return str(value)
436+
437+
def _get_cancellation_phase(self) -> str:
438+
if self._shutdown_requested:
439+
return "shutdown"
440+
if not self._recognition_active and not self._audio_sent:
441+
return "startup"
442+
return "streaming"
443+
444+
@staticmethod
445+
def _is_cancellation_recoverable(reason: str, code: str) -> bool:
446+
if reason == "CancelledByUser":
447+
return True
448+
if reason != "Error":
449+
return False
450+
451+
return code in {
452+
"ConnectionFailure",
453+
"ServiceRedirectPermanent",
454+
"ServiceRedirectTemporary",
455+
"ServiceTimeout",
456+
"ServiceUnavailable",
457+
"TooManyRequests",
458+
}

src/pipecat/utils/tracing/service_decorators.py

Lines changed: 45 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -373,6 +373,51 @@ async def wrapper(self, transcript, is_final, language=None):
373373
return decorator
374374

375375

376+
def trace_stt_cancellation(
377+
service,
378+
*,
379+
error_type: str,
380+
cancel_reason: str,
381+
cancel_code: str,
382+
recoverable: bool,
383+
phase: str,
384+
region: Optional[str] = None,
385+
) -> None:
386+
"""Create a trace span for STT cancellation events.
387+
388+
Args:
389+
service: STT service instance generating the cancellation.
390+
error_type: Stable error classification.
391+
cancel_reason: Provider cancellation reason.
392+
cancel_code: Provider cancellation code.
393+
recoverable: Whether the application should attempt recovery.
394+
phase: Service lifecycle phase where cancellation happened.
395+
region: Cloud region associated with the service, if known.
396+
"""
397+
if not is_tracing_available() or not getattr(service, "_tracing_enabled", False):
398+
return
399+
400+
service_class_name = service.__class__.__name__
401+
parent_context = _get_turn_context(service) or _get_parent_service_context(service)
402+
403+
tracer = trace.get_tracer("pipecat")
404+
with tracer.start_as_current_span("stt.cancel", context=parent_context) as current_span:
405+
current_span.set_attribute(
406+
"gen_ai.system", service_class_name.replace("STTService", "").lower()
407+
)
408+
current_span.set_attribute("gen_ai.operation.name", "stt.cancel")
409+
current_span.set_attribute("error.type", error_type)
410+
current_span.set_attribute("stt.cancel.reason", cancel_reason)
411+
current_span.set_attribute("stt.cancel.code", cancel_code)
412+
current_span.set_attribute("stt.cancel.recoverable", recoverable)
413+
current_span.set_attribute("stt.cancel.phase", phase)
414+
if region:
415+
current_span.set_attribute("cloud.region", region)
416+
417+
if cancel_reason == "Error":
418+
current_span.set_status(trace.Status(trace.StatusCode.ERROR, cancel_code))
419+
420+
376421
def traced_llm(func: Optional[Callable] = None, *, name: Optional[str] = None) -> Callable:
377422
"""Trace LLM service methods with LLM-specific attributes.
378423

0 commit comments

Comments
 (0)