1111import tempfile
1212import threading
1313import time
14- import weakref
1514from typing import Any , Awaitable , Callable , Dict , Iterable , List , Optional , Sequence , TypedDict , Union , cast
1615
1716import litellm
@@ -162,13 +161,10 @@ class LightningSpanExporter(SpanExporter):
162161 * Buffer access is protected by a re-entrant lock.
163162 * Export is synchronous to the caller yet schedules an async flush on the
164163 internal loop, then waits for completion.
165-
166- Args:
167- store: Optional explicit LightningStore. If None, uses `get_global_store()`.
168164 """
169165
170- def __init__ (self , store : LightningStore ):
171- self ._store = store
166+ def __init__ (self , _store : Optional [ LightningStore ] = None ):
167+ self ._store : Optional [ LightningStore ] = _store # this is only for testing purposes
172168 self ._buffer : List [ReadableSpan ] = []
173169 self ._lock : Optional [threading .RLock ] = None
174170
@@ -178,9 +174,6 @@ def __init__(self, store: LightningStore):
178174 self ._loop : Optional [asyncio .AbstractEventLoop ] = None
179175 self ._loop_thread : Optional [threading .Thread ] = None
180176
181- def set_store (self , store : LightningStore ) -> None :
182- self ._store = store
183-
184177 def _ensure_loop (self ) -> asyncio .AbstractEventLoop :
185178 """Lazily initialize the event loop and thread on first use.
186179
@@ -286,6 +279,11 @@ async def _maybe_flush(self):
286279 if not subtree_spans :
287280 continue
288281
282+ store = self ._store or get_active_llm_proxy ().get_store ()
283+ if store is None :
284+ logger .warning ("Store is not set in LLMProxy. Cannot log spans to store." )
285+ continue
286+
289287 # Merge all custom headers found in the subtree.
290288 headers_merged : Dict [str , Any ] = {}
291289
@@ -340,7 +338,7 @@ async def _maybe_flush(self):
340338
341339 # Persist each span in the subtree with the resolved identifiers.
342340 for span in subtree_spans :
343- await self . _store .add_otel_span (
341+ await store .add_otel_span (
344342 rollout_id = rollout_id , attempt_id = attempt_id , sequence_id = sequence_id_decimal , readable_span = span
345343 )
346344
@@ -410,14 +408,10 @@ class LightningOpenTelemetry(OpenTelemetry):
410408 * Ensures each request is annotated with a per-attempt sequence id so spans
411409 are ordered deterministically even with clock skew across nodes.
412410 * Uses [`LightningSpanExporter`][agentlightning.llm_proxy.LightningSpanExporter] to persist spans for analytics and training.
413-
414- Args:
415- store: Optional explicit LightningStore for the exporter.
416411 """
417412
418- def __init__ (self , store : LightningStore , llm_proxy : LLMProxy ):
419- self .llm_proxy = weakref .ref (llm_proxy )
420- self .exporter = LightningSpanExporter (store )
413+ def __init__ (self ):
414+ self .exporter = LightningSpanExporter ()
421415 config = OpenTelemetryConfig (exporter = self .exporter )
422416
423417 # Check for tracer initialization
@@ -429,18 +423,6 @@ def __init__(self, store: LightningStore, llm_proxy: LLMProxy):
429423
430424 super ().__init__ (config = config ) # pyright: ignore[reportUnknownMemberType]
431425
432- def set_store (self , store : LightningStore ) -> None :
433- self .exporter .set_store (store )
434-
435- def owned_by (self , llm_proxy : LLMProxy ) -> bool :
436- """Check whether the `llm_proxy` is the one associated with this tracer.
437-
438- Args:
439- llm_proxy: The current LLMProxy instance.
440- """
441- current_proxy = self .llm_proxy ()
442- return current_proxy is not None and current_proxy is llm_proxy
443-
444426
445427class RolloutAttemptMiddleware (BaseHTTPMiddleware ):
446428 """
@@ -450,13 +432,6 @@ class RolloutAttemptMiddleware(BaseHTTPMiddleware):
450432 LLMProxy can update store later without rebuilding middleware.
451433 """
452434
453- def __init__ (self , * args : Any , ** kwargs : Any ) -> None :
454- store = cast (weakref .ReferenceType [LightningStore ], kwargs .pop ("store" ))
455- llm_proxy = cast (weakref .ReferenceType [LLMProxy ], kwargs .pop ("llm_proxy" ))
456- self .store = store
457- self .llm_proxy = llm_proxy
458- super ().__init__ (* args , ** kwargs )
459-
460435 async def dispatch (self , request : Request , call_next : Callable [[Request ], Awaitable [Response ]]) -> Response :
461436 # Decode rollout and attempt from the URL prefix. Example:
462437 # /rollout/r123/attempt/a456/v1/chat/completions
@@ -475,7 +450,7 @@ async def dispatch(self, request: Request, call_next: Callable[[Request], Awaita
475450 request .scope ["path" ] = new_path
476451 request .scope ["raw_path" ] = new_path .encode ()
477452
478- store = self . store ()
453+ store = get_active_llm_proxy (). get_store ()
479454 if store is not None :
480455 # Allocate a monotonic sequence id per (rollout, attempt).
481456 sequence_id = await store .get_next_span_sequence_id (rollout_id , attempt_id )
@@ -551,6 +526,14 @@ def __init__(
551526 self ._ready_event = threading .Event ()
552527 self ._callbacks_initialized_copy : Optional [List [Any ]] = None
553528
def get_store(self) -> Optional[LightningStore]:
    """Return the store currently attached to this proxy.

    Returns:
        The value of `self.store`; `None` when no store has been set.
    """
    current_store = self.store
    return current_store
536+
554537 def set_store (self , store : LightningStore ) -> None :
555538 """Set the store for the proxy.
556539
@@ -615,85 +598,32 @@ def initialize(self):
615598 if self .store is None :
616599 raise ValueError ("Store is not set. Please set the store before initializing the LLMProxy." )
617600
601+ if _global_llm_proxy is not None :
602+ logger .warning ("A global LLMProxy is already set. Overwriting it with the new instance." )
603+
604+ # Set the global LLMProxy reference for middleware/exporter access.
605+ set_active_llm_proxy (self )
606+
618607 # Install middleware if it's not already installed.
619608 installed : bool = False
620609 for mw in app .user_middleware :
621610 if mw .cls is RolloutAttemptMiddleware :
622- # Check whether the middleware is installed by myself
623- llm_proxy_ref = mw .kwargs .get ("llm_proxy" )
624- if llm_proxy_ref is None :
625- logger .warning (
626- "Found existing RolloutAttemptMiddleware without llm_proxy reference. "
627- "We will reuse this middleware and modify its proxy reference and store."
628- )
629- mw .kwargs ["llm_proxy" ] = weakref .ref (self )
630- mw .kwargs ["store" ] = weakref .ref (self .store )
631- installed = True
632- logger .info ("Updated llm_proxy and store references in existing RolloutAttemptMiddleware." )
633- elif llm_proxy_ref () is self : # type: ignore
634- # Upgrade the store if needed
635- logger .info (
636- "Found existing RolloutAttemptMiddleware owned by this LLMProxy instance. Updating store reference."
637- )
638- mw .kwargs ["store" ] = weakref .ref (self .store )
639- logger .info ("Updated store reference in existing RolloutAttemptMiddleware." )
640- installed = True
641- else :
642- # NOTE: We need to do this because middleware cannot be added once the application has started
643- # So we have to rewrite the existing one.
644- logger .error (
645- "Found existing RolloutAttemptMiddleware not owned by this LLMProxy instance. "
646- "We are going to rewrite the middleware. This may have unintended consequences to other LLMProxy instances."
647- )
648- mw .kwargs ["store" ] = weakref .ref (self .store )
649- mw .kwargs ["llm_proxy" ] = weakref .ref (self )
650- installed = True
651- logger .info (
652- "llm_proxy and store references in existing RolloutAttemptMiddleware have been rewritten."
653- )
611+ # Check whether the middleware is installed.
612+ # It could be installed by other LLM Proxy instances, but it doesn't matter.
613+ logger .info ("Found existing RolloutAttemptMiddleware installed. Will not install a new one." )
614+ installed = True
615+ break
654616
655617 if not installed :
656618 # Fallback to adding a new middleware
657619 logger .info ("Adding a new middleware to the FastAPI app." )
658- app .add_middleware (
659- RolloutAttemptMiddleware ,
660- store = weakref .ref (self .store ),
661- llm_proxy = weakref .ref (self ),
662- )
663-
664- # Register callbacks once on the global LiteLLM callback list.
665- if self ._callbacks_initialized_copy is None :
666- logger .info ("Callbacks are not initialized. Initializing them." )
667- self ._callbacks_initialized_copy = cast (List [Any ], litellm .callbacks ) + [ # type: ignore
668- AddReturnTokenIds (),
669- LightningOpenTelemetry (self .store , self ),
670- ]
620+ app .add_middleware (RolloutAttemptMiddleware )
671621
672- else :
673- logger .warning ("Callbacks are already initialized. Augmenting the initialized copy with latest store." )
674-
675- opentelemetry_callbacks = [cb for cb in litellm .callbacks if isinstance (cb , LightningOpenTelemetry ) and cb .owned_by (self )] # type: ignore
676- if len (opentelemetry_callbacks ) > 1 :
677- raise RuntimeError (
678- "Found multiple LightningOpenTelemetry callbacks for this LLMProxy instance. This is unsupported."
679- )
680- elif len (opentelemetry_callbacks ) == 0 :
681- logger .error (
682- "LightningOpenTelemetry callback not found in litellm.callbacks but the proxy has been initialized. This should not happen."
683- )
684- self ._callbacks_initialized_copy .append (LightningOpenTelemetry (self .store , self ))
685- else :
686- opentelemetry_callbacks [0 ].set_store (self .store )
687- # The updated callback should also be reflected in _callbacks_initialized_copy
688-
689- # Hacks to avoid issues on restart within the same process.
690- _reset_litellm_logging_callback_manager ()
691- # Reset LiteLLM's logging worker so its asyncio.Queue binds to the new loop.
622+ if not initialize_llm_callbacks ():
623+ # If it's not the first time to initialize the callbacks, also
624+ # reset LiteLLM's logging worker so its asyncio.Queue binds to the new loop.
692625 _reset_litellm_logging_worker ()
693626
694- litellm .callbacks .clear () # type: ignore
695- litellm .callbacks .extend (self ._callbacks_initialized_copy ) # type: ignore
696-
697627 def start (self ):
698628 """Start the proxy server thread and initialize global wiring.
699629
@@ -711,7 +641,7 @@ def start(self):
711641 if not self .store :
712642 raise ValueError ("Store is not set. Please set the store before starting the LLMProxy." )
713643
714- # Initialize global middleware and callbacks once .
644+ # Initialize global middleware and callbacks.
715645 self .initialize ()
716646
717647 # Persist a temp worker config for LiteLLM and point the proxy at it.
@@ -848,6 +778,60 @@ def as_resource(
848778 raise ValueError ("Either rollout_id and attempt_id must be provided, or neither." )
849779
850780
# Module-wide reference to the active LLMProxy. Written by
# `set_active_llm_proxy()` and read by `get_active_llm_proxy()`, which raises
# `ValueError` while this is still `None`.
_global_llm_proxy: Optional[LLMProxy] = None
# Copy of `litellm.callbacks` taken by `initialize_llm_callbacks()` right after
# it registers its callbacks for the first time; later calls restore
# `litellm.callbacks` from this copy. `None` means no snapshot has been taken.
_callbacks_before_litellm_start: Optional[List[Any]] = None
783+
784+
def get_active_llm_proxy() -> LLMProxy:
    """Return the LLMProxy currently registered as the module-wide active proxy.

    Returns:
        The active `LLMProxy` instance. This function never returns `None`.

    Raises:
        ValueError: If no proxy has been registered via
            `set_active_llm_proxy()` yet.
    """
    # Read the module global once so the check and the return see the same object.
    active_proxy = _global_llm_proxy
    if active_proxy is None:
        raise ValueError("Global LLMProxy is not set. Please call llm_proxy.start() first.")
    return active_proxy
794+
795+
def set_active_llm_proxy(proxy: LLMProxy) -> None:
    """Register `proxy` as the module-wide active LLMProxy.

    Any previously registered proxy reference is replaced.

    Args:
        proxy: The LLMProxy instance that `get_active_llm_proxy()` should
            return from now on.
    """
    global _global_llm_proxy

    _global_llm_proxy = proxy
804+
805+
def initialize_llm_callbacks() -> bool:
    """Bring `litellm.callbacks` to the state agent-lightning set up on first use.

    On the first call, the agent-lightning callbacks are appended to
    `litellm.callbacks` and the resulting list is remembered. On every later
    call, the LiteLLM logging callback manager is reset and `litellm.callbacks`
    is cleared and refilled from that remembered list, so repeated restarts in
    one process do not keep growing the callback list (which may otherwise
    exceed the MAX_CALLBACKS limit).

    Returns:
        `True` if this call registered the callbacks for the first time,
        `False` if it restored the previously remembered list.
    """
    global _callbacks_before_litellm_start

    if _callbacks_before_litellm_start is None:
        # First initialization: register our callbacks, then snapshot the list.
        agent_lightning_callbacks = [AddReturnTokenIds(), LightningOpenTelemetry()]
        litellm.callbacks.extend(agent_lightning_callbacks)  # type: ignore
        _callbacks_before_litellm_start = list(litellm.callbacks)  # type: ignore
        first_time = True
    else:
        # Re-initialization: reset the callback manager, then restore the snapshot.
        _reset_litellm_logging_callback_manager()
        litellm.callbacks.clear()  # type: ignore
        litellm.callbacks.extend(_callbacks_before_litellm_start)  # type: ignore
        first_time = False

    return first_time
833+
834+
851835def _get_default_ipv4_address () -> str :
852836 """Determine the default outbound IPv4 address for this machine.
853837
0 commit comments