Skip to content

Commit 4cf7e5a

Browse files
ref(anthropic): Factor out streamed result handling
1 parent a4e4c57 commit 4cf7e5a

File tree

1 file changed

+98
-86
lines changed

1 file changed

+98
-86
lines changed

sentry_sdk/integrations/anthropic.py

Lines changed: 98 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@
4949
from sentry_sdk.tracing import Span
5050
from sentry_sdk._types import TextPart
5151

52+
from anthropic import AsyncStream
53+
from anthropic.types import RawMessageStreamEvent
54+
5255

5356
class _RecordedUsage:
5457
output_tokens: int = 0
@@ -389,6 +392,96 @@ def _set_output_data(
389392
span.__exit__(None, None, None)
390393

391394

395+
def _set_streaming_output_data(
    result: "AsyncStream[RawMessageStreamEvent]",
    span: "sentry_sdk.tracing.Span",
):
    """Instrument a streaming Anthropic response so *span* is populated
    and finished once the stream has been fully consumed.

    Replaces ``result._iterator`` with a wrapper (sync or async, chosen by
    duck-typing the underlying iterator) that forwards every event
    unchanged while accumulating the model name, token usage and text
    content via ``_collect_ai_data``. When the stream is exhausted, the
    totals are reported through ``_set_output_data``, which also closes
    the span (``finish_span=True``).
    """
    integration = sentry_sdk.get_client().get_integration(AnthropicIntegration)

    old_iterator = result._iterator

    def _finalize(model, usage, content_blocks):
        """Report the accumulated streaming data on the span and finish it."""
        # Anthropic's input_tokens excludes cached/cache_write tokens.
        # Normalize to total input tokens for correct cost calculations.
        total_input = (
            usage.input_tokens
            + (usage.cache_read_input_tokens or 0)
            + (usage.cache_write_input_tokens or 0)
        )

        _set_output_data(
            span=span,
            integration=integration,
            model=model,
            input_tokens=total_input,
            output_tokens=usage.output_tokens,
            cache_read_input_tokens=usage.cache_read_input_tokens,
            cache_write_input_tokens=usage.cache_write_input_tokens,
            content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
            finish_span=True,
        )

    def new_iterator() -> "Iterator[MessageStreamEvent]":
        model = None
        usage = _RecordedUsage()
        content_blocks: "list[str]" = []

        for event in old_iterator:
            (
                model,
                usage,
                content_blocks,
            ) = _collect_ai_data(
                event,
                model,
                usage,
                content_blocks,
            )
            yield event

        _finalize(model, usage, content_blocks)

    async def new_iterator_async() -> "AsyncIterator[MessageStreamEvent]":
        model = None
        usage = _RecordedUsage()
        content_blocks: "list[str]" = []

        async for event in old_iterator:
            (
                model,
                usage,
                content_blocks,
            ) = _collect_ai_data(
                event,
                model,
                usage,
                content_blocks,
            )
            yield event

        _finalize(model, usage, content_blocks)

    # Duck-type the iterator: anything implementing the async-iteration
    # protocol (__anext__) gets the async wrapper. This is more robust than
    # comparing str(type(...)) against "<class 'async_generator'>", which
    # only matched raw async generators and broke for other async iterables.
    if hasattr(old_iterator, "__anext__"):
        result._iterator = new_iterator_async()
    else:
        result._iterator = new_iterator()
483+
484+
392485
def _sentry_patched_create_common(f: "Any", *args: "Any", **kwargs: "Any") -> "Any":
393486
integration = kwargs.pop("integration")
394487
if integration is None:
@@ -415,6 +508,11 @@ def _sentry_patched_create_common(f: "Any", *args: "Any", **kwargs: "Any") -> "A
415508

416509
result = yield f, args, kwargs
417510

511+
is_streaming_response = kwargs.get("stream", False)
512+
if is_streaming_response:
513+
_set_streaming_output_data(result, span)
514+
return result
515+
418516
with capture_internal_exceptions():
419517
if hasattr(result, "content"):
420518
(
@@ -444,92 +542,6 @@ def _sentry_patched_create_common(f: "Any", *args: "Any", **kwargs: "Any") -> "A
444542
content_blocks=content_blocks,
445543
finish_span=True,
446544
)
447-
448-
# Streaming response
449-
elif hasattr(result, "_iterator"):
450-
old_iterator = result._iterator
451-
452-
def new_iterator() -> "Iterator[MessageStreamEvent]":
453-
model = None
454-
usage = _RecordedUsage()
455-
content_blocks: "list[str]" = []
456-
457-
for event in old_iterator:
458-
(
459-
model,
460-
usage,
461-
content_blocks,
462-
) = _collect_ai_data(
463-
event,
464-
model,
465-
usage,
466-
content_blocks,
467-
)
468-
yield event
469-
470-
# Anthropic's input_tokens excludes cached/cache_write tokens.
471-
# Normalize to total input tokens for correct cost calculations.
472-
total_input = (
473-
usage.input_tokens
474-
+ (usage.cache_read_input_tokens or 0)
475-
+ (usage.cache_write_input_tokens or 0)
476-
)
477-
478-
_set_output_data(
479-
span=span,
480-
integration=integration,
481-
model=model,
482-
input_tokens=total_input,
483-
output_tokens=usage.output_tokens,
484-
cache_read_input_tokens=usage.cache_read_input_tokens,
485-
cache_write_input_tokens=usage.cache_write_input_tokens,
486-
content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
487-
finish_span=True,
488-
)
489-
490-
async def new_iterator_async() -> "AsyncIterator[MessageStreamEvent]":
491-
model = None
492-
usage = _RecordedUsage()
493-
content_blocks: "list[str]" = []
494-
495-
async for event in old_iterator:
496-
(
497-
model,
498-
usage,
499-
content_blocks,
500-
) = _collect_ai_data(
501-
event,
502-
model,
503-
usage,
504-
content_blocks,
505-
)
506-
yield event
507-
508-
# Anthropic's input_tokens excludes cached/cache_write tokens.
509-
# Normalize to total input tokens for correct cost calculations.
510-
total_input = (
511-
usage.input_tokens
512-
+ (usage.cache_read_input_tokens or 0)
513-
+ (usage.cache_write_input_tokens or 0)
514-
)
515-
516-
_set_output_data(
517-
span=span,
518-
integration=integration,
519-
model=model,
520-
input_tokens=total_input,
521-
output_tokens=usage.output_tokens,
522-
cache_read_input_tokens=usage.cache_read_input_tokens,
523-
cache_write_input_tokens=usage.cache_write_input_tokens,
524-
content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
525-
finish_span=True,
526-
)
527-
528-
if str(type(result._iterator)) == "<class 'async_generator'>":
529-
result._iterator = new_iterator_async()
530-
else:
531-
result._iterator = new_iterator()
532-
533545
else:
534546
span.set_data("unknown_response", True)
535547
span.__exit__(None, None, None)

0 commit comments

Comments
 (0)