Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions arbos/tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import logging
import math
import os
import secrets
import subprocess
import tempfile
import time
Expand Down Expand Up @@ -176,6 +177,7 @@ async def _evaluate_async(self, code: str) -> EvalResult:
logger.info(f" TTL: {self._ttl}s ({self._ttl / 60:.0f} min)")
logger.info(f" Name: {deploy_name}")

auth_token = secrets.token_urlsafe(32)
client = BasilicaClient()
deploy_kwargs = {
"name": deploy_name,
Expand All @@ -188,6 +190,7 @@ async def _evaluate_async(self, code: str) -> EvalResult:
"cpu": bc.get("cpu", "24"),
"memory": bc.get("memory", "480Gi"),
"timeout": 1800,
"env": {"EVAL_AUTH_TOKEN": auth_token},
}
if bc.get("interconnect"):
deploy_kwargs["interconnect"] = bc["interconnect"]
Expand Down Expand Up @@ -219,7 +222,8 @@ async def _evaluate_async(self, code: str) -> EvalResult:
logger.info(f" Timeout budget: {http_timeout}s")
eval_start = time.time()

async with httpx.AsyncClient(timeout=60) as submit_client:
auth_headers = {"Authorization": f"Bearer {auth_token}"}
async with httpx.AsyncClient(timeout=60, headers=auth_headers) as submit_client:
submit_resp = await submit_client.post(f"{deployment.url}/evaluate", json=payload)

if submit_resp.status_code not in (200, 202):
Expand All @@ -245,7 +249,7 @@ async def _evaluate_async(self, code: str) -> EvalResult:
consecutive_errors = 0
poll_timeout = httpx.Timeout(connect=10, read=60, write=10, pool=10)

async with httpx.AsyncClient(timeout=poll_timeout) as poll_client:
async with httpx.AsyncClient(timeout=poll_timeout, headers=auth_headers) as poll_client:
while True:
elapsed = time.time() - eval_start
if elapsed >= http_timeout:
Expand Down
40 changes: 33 additions & 7 deletions environments/templar/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

import torch
import torch.nn.functional as F
from fastapi import FastAPI
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field

Expand Down Expand Up @@ -2961,12 +2961,37 @@ def _timer_divergence(a: float, b: float) -> float:

app = FastAPI(title="Templar MFU Evaluation", version="2.0.0")

# ---------------------------------------------------------------------------
# Bearer-token auth — blocks rogue requests from other Basilica tenants
# ---------------------------------------------------------------------------
_AUTH_TOKEN: str | None = os.environ.get("EVAL_AUTH_TOKEN")
if _AUTH_TOKEN:
logger.info("Auth token configured — /evaluate and /eval-status require Bearer token")
else:
logger.warning("No EVAL_AUTH_TOKEN set — endpoints are open (local/dev mode)")


async def _verify_auth(request: Request) -> None:
"""FastAPI dependency — rejects unauthenticated requests before body parsing."""
if not _AUTH_TOKEN:
return
auth = request.headers.get("authorization", "")
if auth == f"Bearer {_AUTH_TOKEN}":
return
logger.warning(
f"Rejected unauthenticated request: {request.method} {request.url.path} "
f"from {request.client.host if request.client else 'unknown'}"
)
raise HTTPException(status_code=401, detail="unauthorized")


# ---------------------------------------------------------------------------
# In-memory jobs table for async evaluation
# ---------------------------------------------------------------------------
# Maps job_id -> {"status": "pending"|"done"|"failed", "result": dict|None}
_jobs: dict[str, dict] = {}
_jobs_lock = asyncio.Lock()
_background_tasks: set[asyncio.Task] = set()

# Global actor instance (reused for efficiency)
_actor: Actor | None = None
Expand Down Expand Up @@ -3287,8 +3312,8 @@ async def _evaluation_background(job_id: str, request: EvaluateRequest) -> None:
logger.error(f"[JOB {job_id}] Evaluation failed: {exc}")


@app.post("/evaluate")
async def evaluate(request: EvaluateRequest):
@app.post("/evaluate", dependencies=[Depends(_verify_auth)])
async def evaluate(body: EvaluateRequest):
"""Accept evaluation request, start in background, return job_id immediately.

Returns HTTP 202 with {"job_id": "..."} so the caller can poll
Expand All @@ -3298,15 +3323,16 @@ async def evaluate(request: EvaluateRequest):
job_id = _uuid.uuid4().hex
async with _jobs_lock:
_jobs[job_id] = {"status": "pending", "result": None}
asyncio.create_task(_evaluation_background(job_id, request))
task = asyncio.create_task(_evaluation_background(job_id, body))
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
logger.info(
f"[JOB {job_id}] Evaluation accepted (task_id={request.task_id}, "
f"num_gpus={request.num_gpus})"
f"[JOB {job_id}] Evaluation accepted (task_id={body.task_id}, num_gpus={body.num_gpus})"
)
return JSONResponse(status_code=202, content={"job_id": job_id})


@app.get("/eval-status/{job_id}")
@app.get("/eval-status/{job_id}", dependencies=[Depends(_verify_auth)])
async def eval_status(job_id: str):
"""Poll for evaluation result.

Expand Down
7 changes: 5 additions & 2 deletions src/crusades/affinetes/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,7 @@ async def _evaluate_basilica(
)
logger.info(f" Deployment name: {deploy_name}")

auth_token = secrets.token_urlsafe(32)
client = BasilicaClient()
deploy_kwargs = {
"name": deploy_name,
Expand All @@ -770,6 +771,7 @@ async def _evaluate_basilica(
"cpu": self.basilica_cpu,
"memory": self.basilica_memory,
"timeout": 1800,
"env": {"EVAL_AUTH_TOKEN": auth_token},
}
if self.basilica_interconnect:
deploy_kwargs["interconnect"] = self.basilica_interconnect
Expand Down Expand Up @@ -824,7 +826,8 @@ async def _evaluate_basilica(
post_start = time.time()

# ── Step 1: Submit job (returns 202 + job_id immediately) ──
async with httpx.AsyncClient(timeout=60) as submit_client:
auth_headers = {"Authorization": f"Bearer {auth_token}"}
async with httpx.AsyncClient(timeout=60, headers=auth_headers) as submit_client:
submit_resp = await submit_client.post(f"{deployment.url}/evaluate", json=payload)

if submit_resp.status_code not in (200, 202):
Expand Down Expand Up @@ -852,7 +855,7 @@ async def _evaluate_basilica(
consecutive_errors = 0
poll_timeout = httpx.Timeout(connect=10, read=60, write=10, pool=10)

async with httpx.AsyncClient(timeout=poll_timeout) as poll_client:
async with httpx.AsyncClient(timeout=poll_timeout, headers=auth_headers) as poll_client:
while True:
elapsed = time.time() - post_start
if elapsed >= http_timeout:
Expand Down