Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion sdk/python/agentfield/harness/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,15 @@
import asyncio
import json
import os
import re
from typing import Any, Dict, List, Optional, Tuple

# Matches one ANSI CSI sequence: ESC "[" followed by any number of parameter
# bytes ("0".."?"), any number of intermediate bytes (" ".."/"), and exactly
# one final byte ("@".."~"). Other escape forms are not matched.
_ANSI_RE = re.compile(r"\x1B\[[0-?]*[ -/]*[@-~]")


def strip_ansi(text: str) -> str:
    """Return *text* with every match of ``_ANSI_RE`` removed.

    Args:
        text: String that may contain ANSI CSI escape sequences (for
            example colour codes such as ``ESC[31m``).

    Returns:
        A copy of ``text`` with each matched sequence replaced by the empty
        string. Text that contains no match is returned unchanged.
    """
    cleaned = _ANSI_RE.sub("", text)
    return cleaned


async def run_cli(
cmd: List[str],
Expand Down Expand Up @@ -40,7 +47,7 @@ async def run_cli(
return (
stdout_bytes.decode("utf-8", errors="replace"),
stderr_bytes.decode("utf-8", errors="replace"),
proc.returncode or 0,
proc.returncode if proc.returncode is not None else -1,
)


Expand Down
24 changes: 24 additions & 0 deletions sdk/python/agentfield/harness/_result.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,30 @@
from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional


class FailureType(str, Enum):
    """Category of failure for a single harness invocation.

    A provider records one of these values on its ``RawResult`` so that the
    runner can pick a retry strategy. Because the enum mixes in ``str``,
    each member compares equal to its plain string value.

    Members:
        NONE: The invocation did not fail.
        CRASH: The process was killed by a signal, or exited non-zero
            without producing any output.
        TIMEOUT: Execution ran past the time limit.
        API_ERROR: A transient API-level error such as a rate limit or a
            5xx response.
        NO_OUTPUT: The process exited OK but no output file was produced.
        SCHEMA: An output file exists but does not pass schema validation.
    """

    # No failure.
    NONE = "none"
    # Killed by signal, or non-zero exit with no output.
    CRASH = "crash"
    # Time limit exceeded.
    TIMEOUT = "timeout"
    # Transient API-level error (rate limit, 5xx, etc.).
    API_ERROR = "api_error"
    # Clean exit, but the output file is missing.
    NO_OUTPUT = "no_output"
    # Output file present but invalid against the schema.
    SCHEMA = "schema"


@dataclass
class Metrics:
duration_ms: int = 0
Expand All @@ -21,6 +42,8 @@ class RawResult:
metrics: Metrics = field(default_factory=Metrics)
is_error: bool = False
error_message: Optional[str] = None
failure_type: FailureType = FailureType.NONE
returncode: Optional[int] = None


@dataclass
Expand All @@ -29,6 +52,7 @@ class HarnessResult:
parsed: Any = None
is_error: bool = False
error_message: Optional[str] = None
failure_type: FailureType = FailureType.NONE
cost_usd: Optional[float] = None
num_turns: int = 0
duration_ms: int = 0
Expand Down
202 changes: 181 additions & 21 deletions sdk/python/agentfield/harness/_runner.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,26 @@
from __future__ import annotations

import asyncio
import logging
import os
import random
import time
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional

from agentfield.harness._result import HarnessResult, RawResult
from agentfield.harness._result import FailureType, HarnessResult, RawResult
from agentfield.harness._schema import (
build_followup_prompt,
build_prompt_suffix,
cleanup_temp_files,
diagnose_output_failure,
get_output_path,
parse_and_validate,
)
from agentfield.harness.providers._base import HarnessProvider
from agentfield.harness.providers._factory import build_provider

logger = logging.getLogger(__name__)

TRANSIENT_PATTERNS = {
"rate limit",
"rate_limit",
Expand All @@ -32,6 +38,8 @@
"500",
}

DEFAULT_SCHEMA_RETRIES = 2


def _is_transient(error_str: str) -> bool:
lower = error_str.lower()
Expand All @@ -57,9 +65,12 @@ def _resolve_options(
"system_prompt",
"env",
"cwd",
"project_dir",
"codex_bin",
"gemini_bin",
"opencode_bin",
"opencode_server",
"schema_max_retries",
]:
val = getattr(config, field_name, None)
if val is not None:
Expand All @@ -71,6 +82,25 @@ def _resolve_options(
return options


def _accumulate_metrics(
all_raws: List[RawResult],
) -> tuple[Optional[float], int, str, List[Dict[str, Any]]]:
total_cost: Optional[float] = None
total_turns = 0
session_id = ""
all_messages: List[Dict[str, Any]] = []

for raw in all_raws:
if raw.metrics.total_cost_usd is not None:
total_cost = (total_cost or 0.0) + raw.metrics.total_cost_usd
total_turns += raw.metrics.num_turns
if raw.metrics.session_id:
session_id = raw.metrics.session_id
all_messages.extend(raw.messages)

return total_cost, total_turns, session_id, all_messages


class HarnessRunner:
def __init__(self, config: Optional[Any] = None):
self._config = config
Expand Down Expand Up @@ -115,9 +145,22 @@ async def run(
resolved_cwd = str(options.get("cwd", "."))
provider_instance = self._build_provider(str(resolved_provider), options)

# When project_dir is set (opencode provider), place the output file
# inside project_dir so the coding agent's Write tool can reach it.
# Use a unique subdir to avoid collisions from parallel calls.
project_dir = options.get("project_dir")
output_dir = resolved_cwd
_temp_output_dir: Optional[str] = None
if isinstance(project_dir, str) and project_dir:
import tempfile as _tempfile

_temp_output_dir = _tempfile.mkdtemp(prefix=".secaf-out-", dir=project_dir)
output_dir = _temp_output_dir

effective_prompt = prompt
if schema is not None:
effective_prompt = prompt + build_prompt_suffix(schema, resolved_cwd)
effective_prompt = prompt + build_prompt_suffix(schema, output_dir)
options["_original_prompt"] = effective_prompt

start_time = time.monotonic()
try:
Expand All @@ -126,11 +169,13 @@ async def run(
)

if schema is not None:
return self._handle_schema_output(
return await self._handle_schema_with_retry(
raw,
schema,
resolved_cwd,
output_dir,
start_time,
provider_instance,
options,
)

elapsed = int((time.monotonic() - start_time) * 1000)
Expand All @@ -139,6 +184,7 @@ async def run(
parsed=None,
is_error=raw.is_error,
error_message=raw.error_message,
failure_type=raw.failure_type,
cost_usd=raw.metrics.total_cost_usd,
num_turns=raw.metrics.num_turns,
duration_ms=elapsed,
Expand All @@ -147,7 +193,11 @@ async def run(
)
finally:
if schema is not None:
cleanup_temp_files(resolved_cwd)
cleanup_temp_files(output_dir)
if _temp_output_dir:
import shutil as _shutil

_shutil.rmtree(_temp_output_dir, ignore_errors=True)

def _build_provider(
self, provider_name: str, options: Dict[str, Any]
Expand Down Expand Up @@ -199,37 +249,147 @@ async def _execute_with_retry(
raise last_error
return RawResult(is_error=True, error_message="Max retries exceeded")

async def _handle_schema_with_retry(
    self,
    initial_raw: RawResult,
    schema: Any,
    cwd: str,
    start_time: float,
    provider: HarnessProvider,
    options: Dict[str, Any],
) -> HarnessResult:
    """Validate the schema output file, re-prompting the provider on failure.

    The output file located via ``get_output_path(cwd)`` is validated
    first. If it is valid the result is returned immediately. Otherwise the
    provider is re-run up to ``schema_max_retries`` times (taken from
    ``options``, default ``DEFAULT_SCHEMA_RETRIES``), and the file is
    re-validated after each attempt that did not report a provider error.

    Args:
        initial_raw: Result of the first provider execution.
        schema: Schema passed through to ``parse_and_validate`` and the
            prompt/diagnosis helpers.
        cwd: Directory used to locate the output file and to build
            follow-up prompts.
        start_time: ``time.monotonic()`` value captured when the run
            started; used to compute ``duration_ms``.
        provider: Provider used for the retry executions.
        options: Resolved run options. ``schema_max_retries`` and
            ``_original_prompt`` are read here; a shallow copy is handed to
            each retry execution.

    Returns:
        A ``HarnessResult`` whose cost, turns, session id and messages are
        accumulated over the initial attempt and all retries. On success
        ``parsed`` holds the validated value. On failure ``is_error`` is
        true and ``failure_type`` is either the initial provider failure
        type (early exit) or ``FailureType.SCHEMA`` (retries exhausted).

    Raises:
        Whatever ``self._execute_with_retry`` raises during a retry; such
        exceptions are not caught here.
    """
    output_path = get_output_path(cwd)
    # Clamp so a negative setting behaves exactly like "no retries" instead
    # of bypassing the zero-retry early exit below.
    schema_max_retries = max(
        0, int(options.get("schema_max_retries", DEFAULT_SCHEMA_RETRIES))
    )

    all_raws: List[RawResult] = [initial_raw]

    def _finish(
        result: Any,
        *,
        parsed: Any = None,
        is_error: bool = False,
        error_message: Optional[str] = None,
        failure_type: FailureType = FailureType.NONE,
    ) -> HarnessResult:
        # Single place that stamps timing and accumulated metrics so every
        # exit path reports them the same way.
        elapsed = int((time.monotonic() - start_time) * 1000)
        cost, turns, sid, msgs = _accumulate_metrics(all_raws)
        return HarnessResult(
            result=result,
            parsed=parsed,
            is_error=is_error,
            error_message=error_message,
            failure_type=failure_type,
            cost_usd=cost,
            num_turns=turns,
            duration_ms=elapsed,
            session_id=sid,
            messages=msgs,
        )

    validated = parse_and_validate(output_path, schema)
    if validated is not None:
        return _finish(initial_raw.result, parsed=validated)

    # Give up immediately when the provider itself failed, left no output
    # file, and either the failure type is not one we retry or retries are
    # disabled.
    _retryable = {FailureType.CRASH, FailureType.NO_OUTPUT, FailureType.NONE}
    if (
        initial_raw.is_error
        and not os.path.exists(output_path)
        and (
            initial_raw.failure_type not in _retryable
            or schema_max_retries == 0
        )
    ):
        provider_error = initial_raw.error_message or "Provider execution failed."
        return _finish(
            initial_raw.result,
            is_error=True,
            error_message=(
                f"{provider_error} Output file was not created at {output_path}."
            ),
            failure_type=initial_raw.failure_type,
        )

    last_session_id = initial_raw.metrics.session_id

    for retry_num in range(schema_max_retries):
        if retry_num > 0:
            # Exponential backoff between retries: 0.5s, 1s, 2s, ... capped
            # at 5s. The first retry runs without delay.
            await asyncio.sleep(min(0.5 * (2 ** (retry_num - 1)), 5.0))

        # A crash that left no output file is retried from scratch with the
        # original prompt; anything else gets a follow-up prompt describing
        # what is wrong with the current output.
        is_crash = (
            all_raws[-1].failure_type == FailureType.CRASH
            and not os.path.exists(output_path)
        )
        # Diagnose once per attempt and reuse it for both the prompt and
        # the log line (assumes the helper only inspects the file).
        error_detail = diagnose_output_failure(output_path, schema)
        original_prompt = options.get("_original_prompt", "")
        if is_crash and original_prompt:
            retry_prompt = original_prompt
        else:
            retry_prompt = build_followup_prompt(error_detail, cwd, schema)

        logger.info(
            "Schema validation retry %d/%d: %s",
            retry_num + 1,
            schema_max_retries,
            error_detail[:200],
        )

        retry_options = dict(options)
        # Only resume the previous session for follow-up prompts; a crash
        # retry starts a fresh session.
        if last_session_id and not is_crash:
            retry_options["resume_session_id"] = last_session_id

        retry_raw = await self._execute_with_retry(
            provider, retry_prompt, retry_options
        )
        all_raws.append(retry_raw)

        if retry_raw.metrics.session_id:
            last_session_id = retry_raw.metrics.session_id

        if retry_raw.is_error:
            # A provider-level error skips validation for this attempt and
            # moves on to the next retry (if any).
            logger.warning(
                "Schema retry %d provider error: %s",
                retry_num + 1,
                retry_raw.error_message,
            )
            continue

        validated = parse_and_validate(output_path, schema)
        if validated is not None:
            logger.info("Schema validation succeeded on retry %d", retry_num + 1)
            return _finish(retry_raw.result, parsed=validated)

    final_diagnosis = diagnose_output_failure(output_path, schema)
    return _finish(
        all_raws[-1].result,
        is_error=True,
        error_message=(
            f"Schema validation failed after {schema_max_retries} "
            f"retry attempt(s). Last error: {final_diagnosis}"
        ),
        failure_type=FailureType.SCHEMA,
    )
Loading