Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions src/agents/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@
from .run_state import RunState
from .tool import dispose_resolved_computers
from .tool_guardrails import ToolInputGuardrailResult, ToolOutputGuardrailResult
from .tracing import Span, SpanError, agent_span, get_current_trace, trace
from .tracing.context import TraceCtxManager
from .tracing import Span, SpanError, agent_span, get_current_trace
from .tracing.context import TraceCtxManager, create_trace_for_run
from .tracing.span_data import AgentSpanData
from .util import _error_tracing

Expand Down Expand Up @@ -549,6 +549,8 @@ async def run(
metadata=trace_metadata,
tracing=trace_config,
disabled=run_config.tracing_disabled,
trace_state=run_state._trace_state if run_state is not None else None,
reattach_resumed_trace=is_resumed_state,
):
if is_resumed_state and run_state is not None:
run_state.set_trace(get_current_trace())
Expand Down Expand Up @@ -1519,17 +1521,15 @@ def run_streamed(
# If there's already a trace, we don't create a new one. In addition, we can't end the
# trace here, because the actual work is done in `stream_events` and this method ends
# before that.
new_trace = (
None
if get_current_trace()
else trace(
workflow_name=trace_workflow_name,
trace_id=trace_id,
group_id=trace_group_id,
metadata=trace_metadata,
tracing=trace_config,
disabled=run_config.tracing_disabled,
)
new_trace = create_trace_for_run(
workflow_name=trace_workflow_name,
trace_id=trace_id,
group_id=trace_group_id,
metadata=trace_metadata,
tracing=trace_config,
disabled=run_config.tracing_disabled,
trace_state=run_state._trace_state if run_state is not None else None,
reattach_resumed_trace=is_resumed_state,
)
if run_state is not None:
run_state.set_trace(new_trace or get_current_trace())
Expand Down
4 changes: 2 additions & 2 deletions src/agents/run_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@
# 2. Keep older readable versions in SUPPORTED_SCHEMA_VERSIONS for backward reads.
# 3. to_json() always emits CURRENT_SCHEMA_VERSION.
# 4. Forward compatibility is intentionally fail-fast (older SDKs reject newer versions).
CURRENT_SCHEMA_VERSION = "1.2"
SUPPORTED_SCHEMA_VERSIONS = frozenset({"1.0", "1.1", CURRENT_SCHEMA_VERSION})
CURRENT_SCHEMA_VERSION = "1.3"
SUPPORTED_SCHEMA_VERSIONS = frozenset({"1.0", "1.1", "1.2", CURRENT_SCHEMA_VERSION})

_FUNCTION_OUTPUT_ADAPTER: TypeAdapter[FunctionCallOutput] = TypeAdapter(FunctionCallOutput)
_COMPUTER_OUTPUT_ADAPTER: TypeAdapter[ComputerCallOutput] = TypeAdapter(ComputerCallOutput)
Expand Down
101 changes: 90 additions & 11 deletions src/agents/tracing/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,81 @@

from .config import TracingConfig
from .create import get_current_trace, trace
from .traces import Trace
from .traces import Trace, TraceState, _hash_tracing_api_key, reattach_trace


def _get_tracing_api_key(tracing: TracingConfig | None) -> str | None:
    """Return the ``"api_key"`` entry of ``tracing``.

    Returns ``None`` when no tracing config is supplied, or when the config
    has no ``"api_key"`` entry.
    """
    if tracing is None:
        return None
    return tracing.get("api_key")


def _trace_state_matches_effective_settings(
    *,
    trace_state: TraceState,
    workflow_name: str,
    trace_id: str | None,
    group_id: str | None,
    metadata: dict[str, Any] | None,
    tracing: TracingConfig | None,
) -> bool:
    """Report whether persisted trace state agrees with the run's trace settings.

    The state matches only when it carries a trace id equal to ``trace_id``,
    its workflow name, group id and metadata equal the given values, and the
    tracing API key is consistent with ``tracing``:

    * if the state holds the raw key, it must equal the configured key;
    * otherwise, if the state holds a key hash, it must equal the hash of the
      configured key;
    * otherwise, no key may be configured.
    """
    persisted_id = trace_state.trace_id
    if persisted_id is None or persisted_id != trace_id:
        return False

    # Checked in this order; the first mismatch short-circuits the rest.
    if (
        trace_state.workflow_name != workflow_name
        or trace_state.group_id != group_id
        or trace_state.metadata != metadata
    ):
        return False

    live_key = _get_tracing_api_key(tracing)

    persisted_key = trace_state.tracing_api_key
    if persisted_key is not None:
        return persisted_key == live_key

    persisted_key_hash = trace_state.tracing_api_key_hash
    if persisted_key_hash is not None:
        # The fingerprint lets a snapshot without the raw key still prove
        # that the caller re-supplied the same explicit key.
        return persisted_key_hash == _hash_tracing_api_key(live_key)

    return live_key is None


def create_trace_for_run(
    *,
    workflow_name: str,
    trace_id: str | None,
    group_id: str | None,
    metadata: dict[str, Any] | None,
    tracing: TracingConfig | None,
    disabled: bool,
    trace_state: TraceState | None = None,
    reattach_resumed_trace: bool = False,
) -> Trace | None:
    """Return the trace to use for this run, or ``None`` if one is already current.

    When ``reattach_resumed_trace`` is set, tracing is not disabled, and
    ``trace_state`` matches the effective settings, the result comes from
    ``reattach_trace``. In every other case a new trace is built with
    ``trace``. The returned trace is not started here.
    """
    if get_current_trace():
        return None

    wants_reattach = reattach_resumed_trace and not disabled
    if (
        wants_reattach
        and trace_state is not None
        and _trace_state_matches_effective_settings(
            trace_state=trace_state,
            workflow_name=workflow_name,
            trace_id=trace_id,
            group_id=group_id,
            metadata=metadata,
            tracing=tracing,
        )
    ):
        # Pass the live key along: a secure snapshot may carry only the
        # fingerprint of the key, not the secret itself.
        live_key = _get_tracing_api_key(tracing)
        return reattach_trace(trace_state, tracing_api_key=live_key)

    return trace(
        workflow_name=workflow_name,
        trace_id=trace_id,
        group_id=group_id,
        metadata=metadata,
        tracing=tracing,
        disabled=disabled,
    )


class TraceCtxManager:
Expand All @@ -18,6 +92,8 @@ def __init__(
metadata: dict[str, Any] | None,
tracing: TracingConfig | None,
disabled: bool,
trace_state: TraceState | None = None,
reattach_resumed_trace: bool = False,
):
self.trace: Trace | None = None
self.workflow_name = workflow_name
Expand All @@ -26,18 +102,21 @@ def __init__(
self.metadata = metadata
self.tracing = tracing
self.disabled = disabled
self.trace_state = trace_state
self.reattach_resumed_trace = reattach_resumed_trace

def __enter__(self) -> TraceCtxManager:
current_trace = get_current_trace()
if not current_trace:
self.trace = trace(
workflow_name=self.workflow_name,
trace_id=self.trace_id,
group_id=self.group_id,
metadata=self.metadata,
tracing=self.tracing,
disabled=self.disabled,
)
self.trace = create_trace_for_run(
workflow_name=self.workflow_name,
trace_id=self.trace_id,
group_id=self.group_id,
metadata=self.metadata,
tracing=self.tracing,
disabled=self.disabled,
trace_state=self.trace_state,
reattach_resumed_trace=self.reattach_resumed_trace,
)
if self.trace:
assert self.trace is not None
self.trace.start(mark_as_current=True)
return self
Expand Down
144 changes: 143 additions & 1 deletion src/agents/tracing/traces.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import abc
import contextvars
import hashlib
import threading
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Any
Expand Down Expand Up @@ -148,6 +150,14 @@ def to_json(self, *, include_tracing_api_key: bool = False) -> dict[str, Any] |
return payload


def _hash_tracing_api_key(tracing_api_key: str | None) -> str | None:
    """Return the SHA-256 hex digest of ``tracing_api_key`` (UTF-8 encoded).

    ``None`` maps to ``None``. Only this fingerprint is meant to be persisted,
    so a resumed run can verify that the same explicit tracing key was
    supplied without the secret itself being stored.
    """
    if tracing_api_key is not None:
        key_bytes = tracing_api_key.encode("utf-8")
        return hashlib.sha256(key_bytes).hexdigest()
    return None


@dataclass
class TraceState:
"""Serializable trace metadata for run state persistence."""
Expand All @@ -157,6 +167,7 @@ class TraceState:
group_id: str | None = None
metadata: dict[str, Any] | None = None
tracing_api_key: str | None = None
tracing_api_key_hash: str | None = None
object_type: str | None = None
extra: dict[str, Any] = field(default_factory=dict)

Expand All @@ -179,12 +190,20 @@ def from_json(cls, payload: Mapping[str, Any] | None) -> TraceState | None:
metadata_value = data.pop("metadata", None)
metadata = metadata_value if isinstance(metadata_value, dict) else None
tracing_api_key = data.pop("tracing_api_key", None)
tracing_api_key_hash = data.pop("tracing_api_key_hash", None)
resolved_tracing_api_key = tracing_api_key if isinstance(tracing_api_key, str) else None
resolved_tracing_api_key_hash = _hash_tracing_api_key(resolved_tracing_api_key)
# Secure snapshots may strip the raw key, so keep the stored
# fingerprint for resume-time matching.
if resolved_tracing_api_key_hash is None and isinstance(tracing_api_key_hash, str):
resolved_tracing_api_key_hash = tracing_api_key_hash
return cls(
trace_id=trace_id if isinstance(trace_id, str) else None,
workflow_name=workflow_name if isinstance(workflow_name, str) else None,
group_id=group_id if isinstance(group_id, str) else None,
metadata=metadata,
tracing_api_key=tracing_api_key if isinstance(tracing_api_key, str) else None,
tracing_api_key=resolved_tracing_api_key,
tracing_api_key_hash=resolved_tracing_api_key_hash,
object_type=object_type if isinstance(object_type, str) else None,
extra=data,
)
Expand All @@ -196,6 +215,7 @@ def to_json(self, *, include_tracing_api_key: bool = False) -> dict[str, Any] |
and self.group_id is None
and self.metadata is None
and self.tracing_api_key is None
and self.tracing_api_key_hash is None
and self.object_type is None
and not self.extra
):
Expand All @@ -213,12 +233,133 @@ def to_json(self, *, include_tracing_api_key: bool = False) -> dict[str, Any] |
payload["metadata"] = dict(self.metadata)
if include_tracing_api_key and self.tracing_api_key:
payload["tracing_api_key"] = self.tracing_api_key
if self.tracing_api_key_hash:
# Always persist the fingerprint so default RunState snapshots
# can still validate explicit resume keys.
payload["tracing_api_key_hash"] = self.tracing_api_key_hash
for key, value in self.extra.items():
if key not in payload:
payload[key] = value
return payload


# Module-level registry of trace ids that have been started, guarded by a lock.
# NOTE(review): nothing visible here ever removes an id, so the set grows with
# every distinct trace id started - confirm that is acceptable for
# long-running processes.
_started_trace_ids: set[str] = set()
_started_trace_ids_lock = threading.Lock()


def _mark_trace_id_started(trace_id: str | None) -> None:
    """Record ``trace_id`` as started; empty ids and ``"no-op"`` are ignored."""
    if trace_id and trace_id != "no-op":
        with _started_trace_ids_lock:
            _started_trace_ids.add(trace_id)


def _trace_id_was_started(trace_id: str | None) -> bool:
    """Return whether ``trace_id`` was recorded by ``_mark_trace_id_started``.

    Empty ids and ``"no-op"`` always yield ``False``.
    """
    if trace_id and trace_id != "no-op":
        with _started_trace_ids_lock:
            return trace_id in _started_trace_ids
    return False


class ReattachedTrace(Trace):
    """Trace handle rebuilt from persisted state.

    Starting it records the trace id as started and can install the trace as
    the current one, but no processor is called, so trace start events are
    not emitted again.
    """

    __slots__ = (
        "_name",
        "_trace_id",
        "_tracing_api_key",
        "group_id",
        "metadata",
        "_prev_context_token",
        "_started",
    )

    def __init__(
        self,
        *,
        name: str,
        trace_id: str,
        group_id: str | None,
        metadata: dict[str, Any] | None,
        tracing_api_key: str | None,
    ) -> None:
        # Lifecycle bookkeeping: not started, and no context token held yet.
        self._started = False
        self._prev_context_token: contextvars.Token[Trace | None] | None = None
        # Identity and settings carried over from the persisted trace.
        self._trace_id = trace_id
        self._name = name
        self.group_id = group_id
        self.metadata = metadata
        self._tracing_api_key = tracing_api_key

    @property
    def trace_id(self) -> str:
        """The id of the reattached trace."""
        return self._trace_id

    @property
    def name(self) -> str:
        """The workflow name of the reattached trace."""
        return self._name

    @property
    def tracing_api_key(self) -> str | None:
        """The tracing API key associated with this trace, if any."""
        return self._tracing_api_key

    def start(self, mark_as_current: bool = False):
        """Mark the trace as started; repeated calls are no-ops.

        With ``mark_as_current`` the trace is also set as the current trace
        and the returned context token is kept for ``finish``.
        """
        if not self._started:
            self._started = True
            _mark_trace_id_started(self.trace_id)

            if mark_as_current:
                self._prev_context_token = Scope.set_current_trace(self)

    def finish(self, reset_current: bool = False):
        """Undo ``start``'s current-trace change when asked to.

        Does nothing unless the trace was started. The context token is only
        reset when ``reset_current`` is true and a token is held.
        """
        token = self._prev_context_token
        if self._started and reset_current and token is not None:
            Scope.reset_current_trace(token)
            self._prev_context_token = None

    def __enter__(self) -> Trace:
        if not self._started:
            self.start(mark_as_current=True)
        elif not self._prev_context_token:
            # Already started elsewhere without being made current.
            logger.error("Trace already started but no context token set")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # The current trace is left untouched when exiting on GeneratorExit.
        exiting_on_generator_close = exc_type is GeneratorExit
        self.finish(reset_current=not exiting_on_generator_close)

    def export(self) -> dict[str, Any] | None:
        """Return the trace as a plain dict with ``object`` set to ``"trace"``."""
        exported: dict[str, Any] = {
            "object": "trace",
            "id": self.trace_id,
            "workflow_name": self.name,
            "group_id": self.group_id,
            "metadata": self.metadata,
        }
        return exported


def reattach_trace(trace_state: TraceState, *, tracing_api_key: str | None = None) -> Trace | None:
    """Rebuild a live trace context from persisted state without notifying processors.

    Returns ``None`` when the state has no trace id. Otherwise returns a
    ``ReattachedTrace`` carrying the persisted id, group id and a shallow
    copy of the metadata. The workflow name falls back to
    ``"Agent workflow"`` when the state has none. The key stored in the state
    wins over ``tracing_api_key``, which is used only when the state has no
    key.
    """
    persisted_id = trace_state.trace_id
    if persisted_id is None:
        return None

    effective_key = trace_state.tracing_api_key
    if effective_key is None:
        effective_key = tracing_api_key

    # Copy so the reattached trace does not share the state's metadata dict.
    metadata_copy = None if trace_state.metadata is None else dict(trace_state.metadata)

    return ReattachedTrace(
        name=trace_state.workflow_name or "Agent workflow",
        trace_id=persisted_id,
        group_id=trace_state.group_id,
        metadata=metadata_copy,
        tracing_api_key=effective_key,
    )


class NoOpTrace(Trace):
"""A no-op implementation of Trace that doesn't record any data.

Expand Down Expand Up @@ -347,6 +488,7 @@ def start(self, mark_as_current: bool = False):

self._started = True
self._processor.on_trace_start(self)
_mark_trace_id_started(self.trace_id)

if mark_as_current:
self._prev_context_token = Scope.set_current_trace(self)
Expand Down
Loading