Skip to content

Commit 57f8753

Browse files
committed
fix llm_proxy
1 parent 5caae60 commit 57f8753

1 file changed

Lines changed: 89 additions & 105 deletions

File tree

agentlightning/llm_proxy.py

Lines changed: 89 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import tempfile
1212
import threading
1313
import time
14-
import weakref
1514
from typing import Any, Awaitable, Callable, Dict, Iterable, List, Optional, Sequence, TypedDict, Union, cast
1615

1716
import litellm
@@ -162,13 +161,10 @@ class LightningSpanExporter(SpanExporter):
162161
* Buffer access is protected by a re-entrant lock.
163162
* Export is synchronous to the caller yet schedules an async flush on the
164163
internal loop, then waits for completion.
165-
166-
Args:
167-
store: Optional explicit LightningStore. If None, uses `get_global_store()`.
168164
"""
169165

170-
def __init__(self, store: LightningStore):
171-
self._store = store
166+
def __init__(self, _store: Optional[LightningStore] = None):
167+
self._store: Optional[LightningStore] = _store # this is only for testing purposes
172168
self._buffer: List[ReadableSpan] = []
173169
self._lock: Optional[threading.RLock] = None
174170

@@ -178,9 +174,6 @@ def __init__(self, store: LightningStore):
178174
self._loop: Optional[asyncio.AbstractEventLoop] = None
179175
self._loop_thread: Optional[threading.Thread] = None
180176

181-
def set_store(self, store: LightningStore) -> None:
182-
self._store = store
183-
184177
def _ensure_loop(self) -> asyncio.AbstractEventLoop:
185178
"""Lazily initialize the event loop and thread on first use.
186179
@@ -286,6 +279,11 @@ async def _maybe_flush(self):
286279
if not subtree_spans:
287280
continue
288281

282+
store = self._store or get_active_llm_proxy().get_store()
283+
if store is None:
284+
logger.warning("Store is not set in LLMProxy. Cannot log spans to store.")
285+
continue
286+
289287
# Merge all custom headers found in the subtree.
290288
headers_merged: Dict[str, Any] = {}
291289

@@ -340,7 +338,7 @@ async def _maybe_flush(self):
340338

341339
# Persist each span in the subtree with the resolved identifiers.
342340
for span in subtree_spans:
343-
await self._store.add_otel_span(
341+
await store.add_otel_span(
344342
rollout_id=rollout_id, attempt_id=attempt_id, sequence_id=sequence_id_decimal, readable_span=span
345343
)
346344

@@ -410,14 +408,10 @@ class LightningOpenTelemetry(OpenTelemetry):
410408
* Ensures each request is annotated with a per-attempt sequence id so spans
411409
are ordered deterministically even with clock skew across nodes.
412410
* Uses [`LightningSpanExporter`][agentlightning.llm_proxy.LightningSpanExporter] to persist spans for analytics and training.
413-
414-
Args:
415-
store: Optional explicit LightningStore for the exporter.
416411
"""
417412

418-
def __init__(self, store: LightningStore, llm_proxy: LLMProxy):
419-
self.llm_proxy = weakref.ref(llm_proxy)
420-
self.exporter = LightningSpanExporter(store)
413+
def __init__(self):
414+
self.exporter = LightningSpanExporter()
421415
config = OpenTelemetryConfig(exporter=self.exporter)
422416

423417
# Check for tracer initialization
@@ -429,18 +423,6 @@ def __init__(self, store: LightningStore, llm_proxy: LLMProxy):
429423

430424
super().__init__(config=config) # pyright: ignore[reportUnknownMemberType]
431425

432-
def set_store(self, store: LightningStore) -> None:
433-
self.exporter.set_store(store)
434-
435-
def owned_by(self, llm_proxy: LLMProxy) -> bool:
436-
"""Check whether the `llm_proxy` is the one associated with this tracer.
437-
438-
Args:
439-
llm_proxy: The current LLMProxy instance.
440-
"""
441-
current_proxy = self.llm_proxy()
442-
return current_proxy is not None and current_proxy is llm_proxy
443-
444426

445427
class RolloutAttemptMiddleware(BaseHTTPMiddleware):
446428
"""
@@ -450,13 +432,6 @@ class RolloutAttemptMiddleware(BaseHTTPMiddleware):
450432
LLMProxy can update store later without rebuilding middleware.
451433
"""
452434

453-
def __init__(self, *args: Any, **kwargs: Any) -> None:
454-
store = cast(weakref.ReferenceType[LightningStore], kwargs.pop("store"))
455-
llm_proxy = cast(weakref.ReferenceType[LLMProxy], kwargs.pop("llm_proxy"))
456-
self.store = store
457-
self.llm_proxy = llm_proxy
458-
super().__init__(*args, **kwargs)
459-
460435
async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
461436
# Decode rollout and attempt from the URL prefix. Example:
462437
# /rollout/r123/attempt/a456/v1/chat/completions
@@ -475,7 +450,7 @@ async def dispatch(self, request: Request, call_next: Callable[[Request], Awaita
475450
request.scope["path"] = new_path
476451
request.scope["raw_path"] = new_path.encode()
477452

478-
store = self.store()
453+
store = get_active_llm_proxy().get_store()
479454
if store is not None:
480455
# Allocate a monotonic sequence id per (rollout, attempt).
481456
sequence_id = await store.get_next_span_sequence_id(rollout_id, attempt_id)
@@ -551,6 +526,14 @@ def __init__(
551526
self._ready_event = threading.Event()
552527
self._callbacks_initialized_copy: Optional[List[Any]] = None
553528

529+
def get_store(self) -> Optional[LightningStore]:
530+
"""Get the store used by the proxy.
531+
532+
Returns:
533+
The store used by the proxy.
534+
"""
535+
return self.store
536+
554537
def set_store(self, store: LightningStore) -> None:
555538
"""Set the store for the proxy.
556539
@@ -615,85 +598,32 @@ def initialize(self):
615598
if self.store is None:
616599
raise ValueError("Store is not set. Please set the store before initializing the LLMProxy.")
617600

601+
if _global_llm_proxy is not None:
602+
logger.warning("A global LLMProxy is already set. Overwriting it with the new instance.")
603+
604+
# Set the global LLMProxy reference for middleware/exporter access.
605+
set_active_llm_proxy(self)
606+
618607
# Install middleware if it's not already installed.
619608
installed: bool = False
620609
for mw in app.user_middleware:
621610
if mw.cls is RolloutAttemptMiddleware:
622-
# Check whether the middleware is installed by myself
623-
llm_proxy_ref = mw.kwargs.get("llm_proxy")
624-
if llm_proxy_ref is None:
625-
logger.warning(
626-
"Found existing RolloutAttemptMiddleware without llm_proxy reference. "
627-
"We will reuse this middleware and modify its proxy reference and store."
628-
)
629-
mw.kwargs["llm_proxy"] = weakref.ref(self)
630-
mw.kwargs["store"] = weakref.ref(self.store)
631-
installed = True
632-
logger.info("Updated llm_proxy and store references in existing RolloutAttemptMiddleware.")
633-
elif llm_proxy_ref() is self: # type: ignore
634-
# Upgrade the store if needed
635-
logger.info(
636-
"Found existing RolloutAttemptMiddleware owned by this LLMProxy instance. Updating store reference."
637-
)
638-
mw.kwargs["store"] = weakref.ref(self.store)
639-
logger.info("Updated store reference in existing RolloutAttemptMiddleware.")
640-
installed = True
641-
else:
642-
# NOTE: We need to do this because middleware cannot be added once the application has started
643-
# So we have to rewrite the existing one.
644-
logger.error(
645-
"Found existing RolloutAttemptMiddleware not owned by this LLMProxy instance. "
646-
"We are going to rewrite the middleware. This may have unintended consequences to other LLMProxy instances."
647-
)
648-
mw.kwargs["store"] = weakref.ref(self.store)
649-
mw.kwargs["llm_proxy"] = weakref.ref(self)
650-
installed = True
651-
logger.info(
652-
"llm_proxy and store references in existing RolloutAttemptMiddleware have been rewritten."
653-
)
611+
# Check whether the middleware is installed.
612+
# It could be installed by other LLM Proxy instances, but it doesn't matter.
613+
logger.info("Found existing RolloutAttemptMiddleware installed. Will not install a new one.")
614+
installed = True
615+
break
654616

655617
if not installed:
656618
# Fallback to adding a new middleware
657619
logger.info("Adding a new middleware to the FastAPI app.")
658-
app.add_middleware(
659-
RolloutAttemptMiddleware,
660-
store=weakref.ref(self.store),
661-
llm_proxy=weakref.ref(self),
662-
)
663-
664-
# Register callbacks once on the global LiteLLM callback list.
665-
if self._callbacks_initialized_copy is None:
666-
logger.info("Callbacks are not initialized. Initializing them.")
667-
self._callbacks_initialized_copy = cast(List[Any], litellm.callbacks) + [ # type: ignore
668-
AddReturnTokenIds(),
669-
LightningOpenTelemetry(self.store, self),
670-
]
620+
app.add_middleware(RolloutAttemptMiddleware)
671621

672-
else:
673-
logger.warning("Callbacks are already initialized. Augmenting the initialized copy with latest store.")
674-
675-
opentelemetry_callbacks = [cb for cb in litellm.callbacks if isinstance(cb, LightningOpenTelemetry) and cb.owned_by(self)] # type: ignore
676-
if len(opentelemetry_callbacks) > 1:
677-
raise RuntimeError(
678-
"Found multiple LightningOpenTelemetry callbacks for this LLMProxy instance. This is unsupported."
679-
)
680-
elif len(opentelemetry_callbacks) == 0:
681-
logger.error(
682-
"LightningOpenTelemetry callback not found in litellm.callbacks but the proxy has been initialized. This should not happen."
683-
)
684-
self._callbacks_initialized_copy.append(LightningOpenTelemetry(self.store, self))
685-
else:
686-
opentelemetry_callbacks[0].set_store(self.store)
687-
# The updated callback should also be reflected in _callbacks_initialized_copy
688-
689-
# Hacks to avoid issues on restart within the same process.
690-
_reset_litellm_logging_callback_manager()
691-
# Reset LiteLLM's logging worker so its asyncio.Queue binds to the new loop.
622+
if not initialize_llm_callbacks():
623+
# If it's not the first time to initialize the callbacks, also
624+
# reset LiteLLM's logging worker so its asyncio.Queue binds to the new loop.
692625
_reset_litellm_logging_worker()
693626

694-
litellm.callbacks.clear() # type: ignore
695-
litellm.callbacks.extend(self._callbacks_initialized_copy) # type: ignore
696-
697627
def start(self):
698628
"""Start the proxy server thread and initialize global wiring.
699629
@@ -711,7 +641,7 @@ def start(self):
711641
if not self.store:
712642
raise ValueError("Store is not set. Please set the store before starting the LLMProxy.")
713643

714-
# Initialize global middleware and callbacks once.
644+
# Initialize global middleware and callbacks.
715645
self.initialize()
716646

717647
# Persist a temp worker config for LiteLLM and point the proxy at it.
@@ -848,6 +778,60 @@ def as_resource(
848778
raise ValueError("Either rollout_id and attempt_id must be provided, or neither.")
849779

850780

781+
_global_llm_proxy: Optional[LLMProxy] = None
782+
_callbacks_before_litellm_start: Optional[List[Any]] = None
783+
784+
785+
def get_active_llm_proxy() -> LLMProxy:
786+
"""Get the current global LLMProxy instance.
787+
788+
Returns:
789+
Optional[LLMProxy]: The current LLMProxy if set, else None.
790+
"""
791+
if _global_llm_proxy is None:
792+
raise ValueError("Global LLMProxy is not set. Please call llm_proxy.start() first.")
793+
return _global_llm_proxy
794+
795+
796+
def set_active_llm_proxy(proxy: LLMProxy) -> None:
797+
"""Set the current global LLMProxy instance.
798+
799+
Args:
800+
proxy: The LLMProxy instance to set as global.
801+
"""
802+
global _global_llm_proxy
803+
_global_llm_proxy = proxy
804+
805+
806+
def initialize_llm_callbacks() -> bool:
807+
"""Restore `litellm.callbacks` to a state that is just initialized by agent-lightning.
808+
809+
When litellm is restarted multiple times in the same process, more and more callbacks
810+
will be appended to `litellm.callbacks`, which may exceed the MAX_CALLBACKS limit.
811+
This function remembers the initial state of `litellm.callbacks` and always restore to that state.
812+
813+
Returns:
814+
Whether the callbacks are initialized for the first time.
815+
"""
816+
global _callbacks_before_litellm_start
817+
818+
if _callbacks_before_litellm_start is None:
819+
litellm.callbacks.extend( # type: ignore
820+
[
821+
AddReturnTokenIds(),
822+
LightningOpenTelemetry(),
823+
]
824+
)
825+
_callbacks_before_litellm_start = [*litellm.callbacks] # type: ignore
826+
return True
827+
828+
_reset_litellm_logging_callback_manager()
829+
830+
litellm.callbacks.clear() # type: ignore
831+
litellm.callbacks.extend(_callbacks_before_litellm_start) # type: ignore
832+
return False
833+
834+
851835
def _get_default_ipv4_address() -> str:
852836
"""Determine the default outbound IPv4 address for this machine.
853837

0 commit comments

Comments
 (0)