Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@
print_banner, print_run_summary, print_clinical_report,
)
from rmoe.ui import (
BOLD, CYAN, DIM, GREEN, MAGENTA, RED, RESET, YELLOW, WHITE,
_rule, _kv,
BOLD, DIM, GREEN, RED, RESET, YELLOW,
_rule,
)
from rmoe.charts import (
sc_progression_chart, ddx_evolution_chart, uncertainty_heatmap,
Expand Down Expand Up @@ -175,7 +175,7 @@ def main(argv=None) -> int:
from rmoe.eval import BenchmarkRunner, BenchmarkDataset
if not args.quiet:
print_banner()
print(f"\n Loading benchmark dataset …")
print("\n Loading benchmark dataset …")
dataset = BenchmarkDataset(args.benchmark_dataset)

sm = WannaStateMachine(hard_limit=args.max_iter or 3,
Expand Down
22 changes: 19 additions & 3 deletions rmoe/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,17 +91,33 @@ def _load_prompt(path: str, fallback: str) -> str:
"iteration", "phase ", "protocol",
"metric", "wanna", "shows a primary",
"it seems", "attn score", " roi ",
# Additional patterns produced when the model outputs reasoning prose
# instead of a medical condition name (e.g. attention-map references,
# pipeline component names, or probability-related sentences).
"attention map", " attn ", "arll", " probability",
"how to", "the model", "approach this",
)

# Minimum character length for a string to be considered a medical diagnosis.
_MIN_DIAGNOSIS_LENGTH = 4


def _is_clinical_hypothesis(name: str) -> bool:
"""Return True only if *name* looks like a real medical diagnosis."""
if len(name.strip()) < _MIN_DIAGNOSIS_LENGTH:
"""Return True only if *name* looks like a real medical diagnosis.

Medical condition names (e.g. "Rib fracture", "Pulmonary adenocarcinoma")
always start with an uppercase letter. Strings that start with a lowercase
letter are partial sentences captured by the regex fallback and must be
rejected.
"""
stripped = name.strip()
if len(stripped) < _MIN_DIAGNOSIS_LENGTH:
return False
# All legitimate medical diagnoses start with an uppercase letter.
# Lowercase-initial strings are regex-captured sentence fragments.
if not stripped[0].isupper():
return False
low = name.lower()
low = stripped.lower()
return not any(s in low for s in _NON_CLINICAL_SUBSTRINGS)


Expand Down
175 changes: 175 additions & 0 deletions tests/test_agents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
"""
tests/test_agents.py — Unit tests for rmoe.agents parsing helpers.

Covers _is_clinical_hypothesis and _parse_arll_output to ensure garbage
text fragments produced by the model (when it outputs prose instead of JSON)
are rejected and the fallback ensemble is used instead.
"""
import sys, os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import json
import pytest
from rmoe.agents import _is_clinical_hypothesis, _parse_arll_output


# ═══════════════════════════════════════════════════════════════════════════════
# _is_clinical_hypothesis
# ═══════════════════════════════════════════════════════════════════════════════

class TestIsClinicalHypothesis:
    """Accept/reject checks for ``_is_clinical_hypothesis``.

    Capitalised condition names are expected to be accepted, while sentence
    fragments and pipeline meta-text are expected to be rejected.
    """

    # ---- Genuine diagnoses: each one is expected to be accepted -------------

    def test_rib_fracture(self):
        """A two-word fracture name is accepted."""
        verdict = _is_clinical_hypothesis("Rib fracture")
        assert verdict is True

    def test_pulmonary_adenocarcinoma(self):
        """A two-word oncology diagnosis is accepted."""
        verdict = _is_clinical_hypothesis("Pulmonary adenocarcinoma")
        assert verdict is True

    def test_pneumothorax(self):
        """A single-word diagnosis is accepted."""
        verdict = _is_clinical_hypothesis("Pneumothorax")
        assert verdict is True

    def test_community_acquired_pneumonia(self):
        """A hyphenated diagnosis is accepted."""
        verdict = _is_clinical_hypothesis("Community-acquired pneumonia")
        assert verdict is True

    def test_wrist_fracture(self):
        """Another fracture name is accepted."""
        verdict = _is_clinical_hypothesis("Wrist fracture")
        assert verdict is True

    def test_tb_reactivation(self):
        """The spelled-out tuberculosis diagnosis is accepted."""
        verdict = _is_clinical_hypothesis("Tuberculosis reactivation")
        assert verdict is True

    def test_pleural_effusion(self):
        """A pleural finding is accepted."""
        verdict = _is_clinical_hypothesis("Pleural effusion")
        assert verdict is True

    # ---- Sentence fragments: each one is expected to be rejected ------------

    def test_garbage_with_probability(self):
        """A lowercase-initial fragment such as 'with probability' is rejected."""
        verdict = _is_clinical_hypothesis("with probability")
        assert verdict is False

    def test_garbage_spatial_attention_map(self):
        """A lowercase-initial fragment mentioning the attention map is rejected."""
        fragment = "t the spatial attention map has a attn of"
        verdict = _is_clinical_hypothesis(fragment)
        assert verdict is False

    def test_garbage_trying_to_understand(self):
        """A lowercase-initial fragment mentioning ARLL is rejected."""
        fragment = "m trying to understand how the ARLL Phase"
        verdict = _is_clinical_hypothesis(fragment)
        assert verdict is False

    def test_garbage_attention_map_uppercase(self):
        """A fragment containing 'Spatial Attention Map' is rejected."""
        fragment = "e a Spatial Attention Map with an attn of"
        verdict = _is_clinical_hypothesis(fragment)
        assert verdict is False

    def test_garbage_approach_arll(self):
        """A truncated sentence that refers to ARLL is rejected."""
        fragment = "igure out how to approach this ARLL Phase"
        verdict = _is_clinical_hypothesis(fragment)
        assert verdict is False

    # ---- Boundary inputs ----------------------------------------------------

    def test_too_short(self):
        """A three-character name falls under the four-character minimum."""
        verdict = _is_clinical_hypothesis("Flu")
        assert verdict is False

    def test_empty_string(self):
        """The empty string is rejected."""
        verdict = _is_clinical_hypothesis("")
        assert verdict is False

    def test_single_char(self):
        """A single character is rejected."""
        verdict = _is_clinical_hypothesis("X")
        assert verdict is False

    def test_arll_in_name(self):
        """A capitalised name is still rejected when it contains 'arll'."""
        verdict = _is_clinical_hypothesis("ARLL phase output")
        assert verdict is False

    def test_attention_map_in_name(self):
        """A capitalised name is still rejected when it contains 'attention map'."""
        verdict = _is_clinical_hypothesis("Attention map analysis")
        assert verdict is False


# ═══════════════════════════════════════════════════════════════════════════════
# _parse_arll_output — fallback triggered when model outputs prose
# ═══════════════════════════════════════════════════════════════════════════════

class TestParseArllOutput:
    """Checks on what ``_parse_arll_output`` extracts from JSON and prose.

    When the raw text is prose rather than valid JSON, no garbage hypothesis
    names may end up in the returned ensemble.
    """

    def test_valid_json_parsed_correctly(self):
        """Well-formed JSON yields all three hypotheses in their given order."""
        raw = json.dumps({
            "cot": "Step 1 — fracture analysis …",
            "ddx": [
                {"diagnosis": "Rib fracture", "probability": 0.75, "evidence": "cortical break"},
                {"diagnosis": "Pneumothorax", "probability": 0.15, "evidence": "no lung marking"},
                {"diagnosis": "Pulmonary contusion", "probability": 0.10, "evidence": ""},
            ],
            "sigma2": 0.07,
            "sc": 0.93,
            "wanna": False,
            "feedback_request": None,
            "feedback_payload": None,
            "rag_references": [],
            "temporal_note": None,
        })
        out = _parse_arll_output(raw)
        hypotheses = out.ensemble.hypotheses
        assert len(hypotheses) == 3
        assert hypotheses[0].diagnosis == "Rib fracture"
        # Compare with a tolerance so the check does not hinge on bit-exact
        # float equality of the parsed probability.
        assert abs(hypotheses[0].probability - 0.75) < 1e-9

    def test_garbage_prose_yields_empty_ensemble(self):
        """Prose with no JSON and no valid diagnosis:number pair gives no hypotheses.

        NOTE(review): the original test notes say an empty ensemble is what
        triggers the fallback in ReasoningExpert.execute(); that code is not
        exercised here — confirm against the caller.
        """
        raw = (
            "I'm trying to understand how the ARLL Phase works. "
            "The spatial attention map has a attn of 0.100. "
            "m trying to parse with probability 0.950. "
            "arll phase approach 0.020."
        )
        out = _parse_arll_output(raw)
        # Every candidate in the text is garbage (lowercase start or a
        # non-clinical substring), so nothing may survive filtering.
        assert out.ensemble.hypotheses == []

    def test_mixed_prose_valid_diagnosis(self):
        """Only the real diagnosis survives when prose mixes it with garbage."""
        raw = (
            "Based on the analysis, with probability 0.90 we see findings. "
            "Rib fracture: 0.75 — cortical break visible. "
            "arll output: 0.10"
        )
        out = _parse_arll_output(raw)
        names = [h.diagnosis for h in out.ensemble.hypotheses]
        assert "Rib fracture" in names
        # Garbage entries must NOT be present.
        for name in names:
            # name[:1] instead of name[0]: an empty diagnosis then fails this
            # assertion with its message rather than raising IndexError.
            assert name[:1].isupper(), f"Non-uppercase diagnosis slipped through: {name!r}"

    def test_json_with_garbage_diagnosis_field_filtered(self):
        """A JSON entry whose 'diagnosis' field holds meta-text is dropped."""
        raw = json.dumps({
            "cot": "some cot",
            "ddx": [
                {"diagnosis": "with probability", "probability": 0.95, "evidence": ""},
                {"diagnosis": "Rib fracture", "probability": 0.05, "evidence": "break"},
            ],
            "sigma2": 0.10,
            "sc": 0.90,
            "wanna": False,
            "feedback_request": None,
            "feedback_payload": None,
            "rag_references": [],
            "temporal_note": None,
        })
        out = _parse_arll_output(raw)
        names = [h.diagnosis for h in out.ensemble.hypotheses]
        assert "with probability" not in names
        assert "Rib fracture" in names