|
56 | 56 | ErrorFrame, |
57 | 57 | Frame, |
58 | 58 | InterruptionFrame, |
| 59 | + LLMFullResponseStartFrame, |
59 | 60 | StartFrame, |
60 | 61 | TTSAudioRawFrame, |
61 | 62 | TTSStartedFrame, |
@@ -653,6 +654,11 @@ def __init__( |
653 | 654 | # Track the end time of the last word in the current generation |
654 | 655 | self._generation_end_time = 0.0 |
655 | 656 |
|
| 657 | + # Context ID that was pre-opened on the server during process_frame |
| 658 | + # (LLMFullResponseStartFrame) to avoid context creation latency when |
| 659 | + # the first text arrives. |
| 660 | + self._prewarmed_context_id: Optional[str] = None |
| 661 | + |
656 | 662 | # Init-only config (not runtime-updatable). |
657 | 663 | self._audio_encoding = encoding |
658 | 664 | self._audio_sample_rate = 0 # Set in start() |
@@ -726,6 +732,29 @@ async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirect |
726 | 732 | if isinstance(frame, TTSStoppedFrame): |
727 | 733 | await self.add_word_timestamps([("Reset", 0)]) |
728 | 734 |
|
| 735 | + async def process_frame(self, frame: Frame, direction: FrameDirection): |
| 736 | + """Process incoming frames and pre-open context on LLM response start. |
| 737 | +
|
| 738 | + Eagerly sends the context configuration to the server when |
| 739 | + LLMFullResponseStartFrame arrives, so the context is ready by the time |
| 740 | + text starts flowing. The base class assigns ``_turn_context_id`` before |
| 741 | + this runs; that ID is then reused for all ``run_tts`` calls within the turn. |
| 742 | + """ |
| 743 | + await super().process_frame(frame, direction) |
| 744 | + |
| 745 | + if isinstance(frame, LLMFullResponseStartFrame): |
| 746 | + if self._prewarmed_context_id: |
| 747 | + try: |
| 748 | + await self._send_close_context(self._prewarmed_context_id) |
| 749 | + except Exception as e: |
| 750 | + logger.warning(f"{self}: Failed to close previous prewarmed context: {e}") |
| 751 | + self._prewarmed_context_id = None |
| 752 | + try: |
| 753 | + await self._send_context(self._turn_context_id) |
| 754 | + self._prewarmed_context_id = self._turn_context_id |
| 755 | + except Exception as e: |
| 756 | + logger.warning(f"{self}: Failed to pre-open context: {e}") |
| 757 | + |
729 | 758 | def _calculate_word_times(self, timestamp_info: Dict[str, Any]) -> List[Tuple[str, float]]: |
730 | 759 | """Calculate word timestamps from Inworld WebSocket API response. |
731 | 760 |
|
@@ -887,6 +916,7 @@ async def _disconnect_websocket(self): |
887 | 916 | finally: |
888 | 917 | await self.remove_active_audio_context() |
889 | 918 | self._websocket = None |
| 919 | + self._prewarmed_context_id = None |
890 | 920 | self._cumulative_time = 0.0 |
891 | 921 | self._generation_end_time = 0.0 |
892 | 922 | await self._call_event_handler("on_disconnected") |
@@ -1001,9 +1031,16 @@ async def _keepalive_task_handler(self): |
1001 | 1031 | async def _send_context(self, context_id: str): |
1002 | 1032 | """Send a context to the Inworld WebSocket TTS service. |
1003 | 1033 |
|
| 1034 | + Skips the send if this context was already pre-opened on the server |
| 1035 | + (prewarmed during process_frame). |
| 1036 | +
|
1004 | 1037 | Args: |
1005 | 1038 | context_id: The context ID. |
1006 | 1039 | """ |
| 1040 | + if context_id == self._prewarmed_context_id: |
| 1041 | + self._prewarmed_context_id = None |
| 1042 | + return |
| 1043 | + |
1007 | 1044 | audio_config = { |
1008 | 1045 | "audioEncoding": self._audio_encoding, |
1009 | 1046 | "sampleRateHertz": self._audio_sample_rate, |
|
0 commit comments