diff --git a/codex-rs/core/src/model_family.rs b/codex-rs/core/src/model_family.rs index 2c66d9ad12..80b3f27986 100644 --- a/codex-rs/core/src/model_family.rs +++ b/codex-rs/core/src/model_family.rs @@ -119,9 +119,10 @@ pub fn find_family_for_model(mut slug: &str) -> Option { reasoning_summary_format: ReasoningSummaryFormat::Experimental, base_instructions: GPT_5_CODEX_INSTRUCTIONS.to_string(), experimental_supported_tools: vec![ - "read_file".to_string(), + "grep_files".to_string(), "list_dir".to_string(), - "test_sync_tool".to_string() + "read_file".to_string(), + "test_sync_tool".to_string(), ], supports_parallel_tool_calls: true, ) @@ -134,7 +135,11 @@ pub fn find_family_for_model(mut slug: &str) -> Option { reasoning_summary_format: ReasoningSummaryFormat::Experimental, base_instructions: GPT_5_CODEX_INSTRUCTIONS.to_string(), apply_patch_tool_type: Some(ApplyPatchToolType::Freeform), - experimental_supported_tools: vec!["read_file".to_string(), "list_dir".to_string()], + experimental_supported_tools: vec![ + "grep_files".to_string(), + "list_dir".to_string(), + "read_file".to_string(), + ], supports_parallel_tool_calls: true, ) diff --git a/codex-rs/core/src/tools/handlers/grep_files.rs b/codex-rs/core/src/tools/handlers/grep_files.rs new file mode 100644 index 0000000000..de3cd3411c --- /dev/null +++ b/codex-rs/core/src/tools/handlers/grep_files.rs @@ -0,0 +1,272 @@ +use std::path::Path; +use std::time::Duration; + +use async_trait::async_trait; +use serde::Deserialize; +use tokio::process::Command; +use tokio::time::timeout; + +use crate::function_tool::FunctionCallError; +use crate::tools::context::ToolInvocation; +use crate::tools::context::ToolOutput; +use crate::tools::context::ToolPayload; +use crate::tools::registry::ToolHandler; +use crate::tools::registry::ToolKind; + +pub struct GrepFilesHandler; + +const DEFAULT_LIMIT: usize = 100; +const MAX_LIMIT: usize = 2000; +const COMMAND_TIMEOUT: Duration = Duration::from_secs(30); + +fn default_limit() -> usize { + DEFAULT_LIMIT +} + +#[derive(Deserialize)] +struct GrepFilesArgs { + pattern: String, + #[serde(default)] + include: Option, + #[serde(default)] + path: Option, + #[serde(default = "default_limit")] + limit: usize, +} + +#[async_trait] +impl ToolHandler for GrepFilesHandler { + fn kind(&self) -> ToolKind { + ToolKind::Function + } + + async fn handle(&self, invocation: ToolInvocation) -> Result { + let ToolInvocation { payload, turn, .. } = invocation; + + let arguments = match payload { + ToolPayload::Function { arguments } => arguments, + _ => { + return Err(FunctionCallError::RespondToModel( + "grep_files handler received unsupported payload".to_string(), + )); + } + }; + + let args: GrepFilesArgs = serde_json::from_str(&arguments).map_err(|err| { + FunctionCallError::RespondToModel(format!( + "failed to parse function arguments: {err:?}" + )) + })?; + + let pattern = args.pattern.trim(); + if pattern.is_empty() { + return Err(FunctionCallError::RespondToModel( + "pattern must not be empty".to_string(), + )); + } + + if args.limit == 0 { + return Err(FunctionCallError::RespondToModel( + "limit must be greater than zero".to_string(), + )); + } + + let limit = args.limit.min(MAX_LIMIT); + let search_path = turn.resolve_path(args.path.clone()); + + verify_path_exists(&search_path).await?; + + let include = args.include.as_deref().map(str::trim).and_then(|val| { + if val.is_empty() { + None + } else { + Some(val.to_string()) + } + }); + + let search_results = + run_rg_search(pattern, include.as_deref(), &search_path, limit, &turn.cwd).await?; + + if search_results.is_empty() { + Ok(ToolOutput::Function { + content: "No matches found.".to_string(), + success: Some(false), + }) + } else { + Ok(ToolOutput::Function { + content: search_results.join("\n"), + success: Some(true), + }) + } + } +} + +async fn verify_path_exists(path: &Path) -> Result<(), FunctionCallError> { + tokio::fs::metadata(path).await.map_err(|err| { + FunctionCallError::RespondToModel(format!("unable to access `{}`: {err}", path.display())) + })?; + Ok(()) +} + +async fn run_rg_search( + pattern: &str, + include: Option<&str>, + search_path: &Path, + limit: usize, + cwd: &Path, +) -> Result, FunctionCallError> { + let mut command = Command::new("rg"); + command + .current_dir(cwd) + .arg("--files-with-matches") + .arg("--sortr=modified") + .arg("--regexp") + .arg(pattern) + .arg("--no-messages"); + + if let Some(glob) = include { + command.arg("--glob").arg(glob); + } + + command.arg("--").arg(search_path); + + let output = timeout(COMMAND_TIMEOUT, command.output()) + .await + .map_err(|_| { + FunctionCallError::RespondToModel("rg timed out after 30 seconds".to_string()) + })? + .map_err(|err| { + FunctionCallError::RespondToModel(format!( + "failed to launch rg: {err}. Ensure ripgrep is installed and on PATH." + )) + })?; + + match output.status.code() { + Some(0) => Ok(parse_results(&output.stdout, limit)), + Some(1) => Ok(Vec::new()), + _ => { + let stderr = String::from_utf8_lossy(&output.stderr); + Err(FunctionCallError::RespondToModel(format!( + "rg failed: {stderr}" + ))) + } + } +} + +fn parse_results(stdout: &[u8], limit: usize) -> Vec { + let mut results = Vec::new(); + for line in stdout.split(|byte| *byte == b'\n') { + if line.is_empty() { + continue; + } + if let Ok(text) = std::str::from_utf8(line) { + if text.is_empty() { + continue; + } + results.push(text.to_string()); + if results.len() == limit { + break; + } + } + } + results +} + +#[cfg(test)] +mod tests { + use super::*; + use std::process::Command as StdCommand; + use tempfile::tempdir; + + #[test] + fn parses_basic_results() { + let stdout = b"/tmp/file_a.rs\n/tmp/file_b.rs\n"; + let parsed = parse_results(stdout, 10); + assert_eq!( + parsed, + vec!["/tmp/file_a.rs".to_string(), "/tmp/file_b.rs".to_string()] + ); + } + + #[test] + fn parse_truncates_after_limit() { + let stdout = b"/tmp/file_a.rs\n/tmp/file_b.rs\n/tmp/file_c.rs\n"; + let parsed = parse_results(stdout, 2); + assert_eq!( + parsed, + vec!["/tmp/file_a.rs".to_string(), "/tmp/file_b.rs".to_string()] + ); + } + + #[tokio::test] + async fn run_search_returns_results() -> anyhow::Result<()> { + if !rg_available() { + return Ok(()); + } + let temp = tempdir().expect("create temp dir"); + let dir = temp.path(); + std::fs::write(dir.join("match_one.txt"), "alpha beta gamma").unwrap(); + std::fs::write(dir.join("match_two.txt"), "alpha delta").unwrap(); + std::fs::write(dir.join("other.txt"), "omega").unwrap(); + + let results = run_rg_search("alpha", None, dir, 10, dir).await?; + assert_eq!(results.len(), 2); + assert!(results.iter().any(|path| path.ends_with("match_one.txt"))); + assert!(results.iter().any(|path| path.ends_with("match_two.txt"))); + Ok(()) + } + + #[tokio::test] + async fn run_search_with_glob_filter() -> anyhow::Result<()> { + if !rg_available() { + return Ok(()); + } + let temp = tempdir().expect("create temp dir"); + let dir = temp.path(); + std::fs::write(dir.join("match_one.rs"), "alpha beta gamma").unwrap(); + std::fs::write(dir.join("match_two.txt"), "alpha delta").unwrap(); + + let results = run_rg_search("alpha", Some("*.rs"), dir, 10, dir).await?; + assert_eq!(results.len(), 1); + assert!(results.iter().all(|path| path.ends_with("match_one.rs"))); + Ok(()) + } + + #[tokio::test] + async fn run_search_respects_limit() -> anyhow::Result<()> { + if !rg_available() { + return Ok(()); + } + let temp = tempdir().expect("create temp dir"); + let dir = temp.path(); + std::fs::write(dir.join("one.txt"), "alpha one").unwrap(); + std::fs::write(dir.join("two.txt"), "alpha two").unwrap(); + std::fs::write(dir.join("three.txt"), "alpha three").unwrap(); + + let results = run_rg_search("alpha", None, dir, 2, dir).await?; + assert_eq!(results.len(), 2); + Ok(()) + } + + #[tokio::test] + async fn run_search_handles_no_matches() -> anyhow::Result<()> { + if !rg_available() { + return Ok(()); + } + let temp = tempdir().expect("create temp dir"); + let dir = temp.path(); + std::fs::write(dir.join("one.txt"), "omega").unwrap(); + + let results = run_rg_search("alpha", None, dir, 5, dir).await?; + assert!(results.is_empty()); + Ok(()) + } + + fn rg_available() -> bool { + StdCommand::new("rg") + .arg("--version") + .output() + .map(|output| output.status.success()) + .unwrap_or(false) + } +} diff --git a/codex-rs/core/src/tools/handlers/mod.rs b/codex-rs/core/src/tools/handlers/mod.rs index d8cf29be72..9bff9fd5e7 100644 --- a/codex-rs/core/src/tools/handlers/mod.rs +++ b/codex-rs/core/src/tools/handlers/mod.rs @@ -1,5 +1,6 @@ pub mod apply_patch; mod exec_stream; +mod grep_files; mod list_dir; mod mcp; mod plan; @@ -13,6 +14,7 @@ pub use plan::PLAN_TOOL; pub use apply_patch::ApplyPatchHandler; pub use exec_stream::ExecStreamHandler; +pub use grep_files::GrepFilesHandler; pub use list_dir::ListDirHandler; pub use mcp::McpHandler; pub use plan::PlanHandler; diff --git a/codex-rs/core/src/tools/spec.rs b/codex-rs/core/src/tools/spec.rs index 08c3f60f27..e38095ecc9 100644 --- a/codex-rs/core/src/tools/spec.rs +++ b/codex-rs/core/src/tools/spec.rs @@ -320,6 +320,56 @@ fn create_test_sync_tool() -> ToolSpec { }) } +fn create_grep_files_tool() -> ToolSpec { + let mut properties = BTreeMap::new(); + properties.insert( + "pattern".to_string(), + JsonSchema::String { + description: Some("Regular expression pattern to search for.".to_string()), + }, + ); + properties.insert( + "include".to_string(), + JsonSchema::String { + description: Some( + "Optional glob that limits which files are searched (e.g. \"*.rs\" or \ + \"*.{ts,tsx}\")." + .to_string(), + ), + }, + ); + properties.insert( + "path".to_string(), + JsonSchema::String { + description: Some( + "Directory or file path to search. Defaults to the session's working directory." + .to_string(), + ), + }, + ); + properties.insert( + "limit".to_string(), + JsonSchema::Number { + description: Some( + "Maximum number of file paths to return (defaults to 100).".to_string(), + ), + }, + ); + + ToolSpec::Function(ResponsesApiTool { + name: "grep_files".to_string(), + description: "Finds files whose contents match the pattern and lists them by modification \ + time." + .to_string(), + strict: false, + parameters: JsonSchema::Object { + properties, + required: Some(vec!["pattern".to_string()]), + additional_properties: Some(false.into()), + }, + }) +} + fn create_read_file_tool() -> ToolSpec { let mut properties = BTreeMap::new(); properties.insert( @@ -610,6 +660,7 @@ pub(crate) fn build_specs( use crate::exec_command::create_write_stdin_tool_for_responses_api; use crate::tools::handlers::ApplyPatchHandler; use crate::tools::handlers::ExecStreamHandler; + use crate::tools::handlers::GrepFilesHandler; use crate::tools::handlers::ListDirHandler; use crate::tools::handlers::McpHandler; use crate::tools::handlers::PlanHandler; @@ -678,8 +729,16 @@ pub(crate) fn build_specs( if config .experimental_supported_tools - .iter() - .any(|tool| tool == "read_file") + .contains(&"grep_files".to_string()) + { + let grep_files_handler = Arc::new(GrepFilesHandler); + builder.push_spec_with_parallel_support(create_grep_files_tool(), true); + builder.register_handler("grep_files", grep_files_handler); + } + + if config + .experimental_supported_tools + .contains(&"read_file".to_string()) { let read_file_handler = Arc::new(ReadFileHandler); builder.push_spec_with_parallel_support(create_read_file_tool(), true); @@ -698,8 +757,7 @@ pub(crate) fn build_specs( if config .experimental_supported_tools - .iter() - .any(|tool| tool == "test_sync_tool") + .contains(&"test_sync_tool".to_string()) { let test_sync_handler = Arc::new(TestSyncHandler); builder.push_spec_with_parallel_support(create_test_sync_tool(), true); @@ -841,8 +899,9 @@ mod tests { let (tools, _) = build_specs(&config, None).build(); assert!(!find_tool(&tools, "unified_exec").supports_parallel_tool_calls); - assert!(find_tool(&tools, "read_file").supports_parallel_tool_calls); + assert!(find_tool(&tools, "grep_files").supports_parallel_tool_calls); assert!(find_tool(&tools, "list_dir").supports_parallel_tool_calls); + assert!(find_tool(&tools, "read_file").supports_parallel_tool_calls); } #[test] @@ -870,6 +929,11 @@ mod tests { .iter() .any(|tool| tool_name(&tool.spec) == "read_file") ); + assert!( + tools + .iter() + .any(|tool| tool_name(&tool.spec) == "grep_files") + ); assert!(tools.iter().any(|tool| tool_name(&tool.spec) == "list_dir")); } diff --git a/codex-rs/core/tests/suite/grep_files.rs b/codex-rs/core/tests/suite/grep_files.rs new file mode 100644 index 0000000000..31195f7e3b --- /dev/null +++ b/codex-rs/core/tests/suite/grep_files.rs @@ -0,0 +1,237 @@ +#![cfg(not(target_os = "windows"))] + +use anyhow::Result; +use codex_core::model_family::find_family_for_model; +use codex_core::protocol::AskForApproval; +use codex_core::protocol::EventMsg; +use codex_core::protocol::InputItem; +use codex_core::protocol::Op; +use codex_core::protocol::SandboxPolicy; +use codex_protocol::config_types::ReasoningSummary; +use core_test_support::responses; +use core_test_support::responses::ev_assistant_message; +use core_test_support::responses::ev_completed; +use core_test_support::responses::ev_function_call; +use core_test_support::responses::ev_response_created; +use core_test_support::responses::sse; +use core_test_support::responses::start_mock_server; +use core_test_support::skip_if_no_network; +use core_test_support::test_codex::TestCodex; +use core_test_support::test_codex::test_codex; +use core_test_support::wait_for_event; +use serde_json::Value; +use std::collections::HashSet; +use std::path::Path; +use std::process::Command as StdCommand; +use wiremock::matchers::any; + +const MODEL_WITH_TOOL: &str = "test-gpt-5-codex"; + +fn ripgrep_available() -> bool { + StdCommand::new("rg") + .arg("--version") + .output() + .map(|output| output.status.success()) + .unwrap_or(false) +} + +macro_rules! skip_if_ripgrep_missing { + ($ret:expr $(,)?) => {{ + if !ripgrep_available() { + eprintln!("rg not available in PATH; skipping test"); + return $ret; + } + }}; +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn grep_files_tool_collects_matches() -> Result<()> { + skip_if_no_network!(Ok(())); + skip_if_ripgrep_missing!(Ok(())); + + let server = start_mock_server().await; + let test = build_test_codex(&server).await?; + + let search_dir = test.cwd.path().join("src"); + std::fs::create_dir_all(&search_dir)?; + let alpha = search_dir.join("alpha.rs"); + let beta = search_dir.join("beta.rs"); + let gamma = search_dir.join("gamma.txt"); + std::fs::write(&alpha, "alpha needle\n")?; + std::fs::write(&beta, "beta needle\n")?; + std::fs::write(&gamma, "needle in text but excluded\n")?; + + let call_id = "grep-files-collect"; + let arguments = serde_json::json!({ + "pattern": "needle", + "path": search_dir.to_string_lossy(), + "include": "*.rs", + }) + .to_string(); + + mount_tool_sequence(&server, call_id, &arguments, "grep_files").await; + submit_turn(&test, "please find uses of needle").await?; + + let bodies = recorded_bodies(&server).await?; + let tool_output = find_tool_output(&bodies, call_id).expect("tool output present"); + let payload = tool_output.get("output").expect("output field present"); + let (content_opt, success_opt) = extract_content_and_success(payload); + let content = content_opt.expect("content present"); + let success = success_opt.unwrap_or(true); + assert!(success, "expected success for matches, got {payload:?}"); + + let entries = collect_file_names(content); + assert_eq!(entries.len(), 2, "content: {content}"); + assert!( + entries.contains("alpha.rs"), + "missing alpha.rs in {entries:?}" + ); + assert!( + entries.contains("beta.rs"), + "missing beta.rs in {entries:?}" + ); + assert!( + !entries.contains("gamma.txt"), + "txt file should be filtered out: {entries:?}" + ); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn grep_files_tool_reports_empty_results() -> Result<()> { + skip_if_no_network!(Ok(())); + skip_if_ripgrep_missing!(Ok(())); + + let server = start_mock_server().await; + let test = build_test_codex(&server).await?; + + let search_dir = test.cwd.path().join("logs"); + std::fs::create_dir_all(&search_dir)?; + std::fs::write(search_dir.join("output.txt"), "no hits here")?; + + let call_id = "grep-files-empty"; + let arguments = serde_json::json!({ + "pattern": "needle", + "path": search_dir.to_string_lossy(), + "limit": 5, + }) + .to_string(); + + mount_tool_sequence(&server, call_id, &arguments, "grep_files").await; + submit_turn(&test, "search again").await?; + + let bodies = recorded_bodies(&server).await?; + let tool_output = find_tool_output(&bodies, call_id).expect("tool output present"); + let payload = tool_output.get("output").expect("output field present"); + let (content_opt, success_opt) = extract_content_and_success(payload); + let content = content_opt.expect("content present"); + if let Some(success) = success_opt { + assert!(!success, "expected success=false payload: {payload:?}"); + } + assert_eq!(content, "No matches found."); + + Ok(()) +} + +#[allow(clippy::expect_used)] +async fn build_test_codex(server: &wiremock::MockServer) -> Result { + let mut builder = test_codex().with_config(|config| { + config.model = MODEL_WITH_TOOL.to_string(); + config.model_family = + find_family_for_model(MODEL_WITH_TOOL).expect("model family for test model"); + }); + builder.build(server).await +} + +async fn submit_turn(test: &TestCodex, prompt: &str) -> Result<()> { + let session_model = test.session_configured.model.clone(); + + test.codex + .submit(Op::UserTurn { + items: vec![InputItem::Text { + text: prompt.into(), + }], + final_output_json_schema: None, + cwd: test.cwd.path().to_path_buf(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::DangerFullAccess, + model: session_model, + effort: None, + summary: ReasoningSummary::Auto, + }) + .await?; + + wait_for_event(&test.codex, |event| { + matches!(event, EventMsg::TaskComplete(_)) + }) + .await; + Ok(()) +} + +async fn mount_tool_sequence( + server: &wiremock::MockServer, + call_id: &str, + arguments: &str, + tool_name: &str, +) { + let first_response = sse(vec![ + ev_response_created("resp-1"), + ev_function_call(call_id, tool_name, arguments), + ev_completed("resp-1"), + ]); + responses::mount_sse_once_match(server, any(), first_response).await; + + let second_response = sse(vec![ + ev_assistant_message("msg-1", "done"), + ev_completed("resp-2"), + ]); + responses::mount_sse_once_match(server, any(), second_response).await; +} + +#[allow(clippy::expect_used)] +async fn recorded_bodies(server: &wiremock::MockServer) -> Result> { + let requests = server.received_requests().await.expect("requests recorded"); + Ok(requests + .iter() + .map(|req| req.body_json::().expect("request json")) + .collect()) +} + +fn find_tool_output<'a>(requests: &'a [Value], call_id: &str) -> Option<&'a Value> { + requests.iter().find_map(|body| { + body.get("input") + .and_then(Value::as_array) + .and_then(|items| { + items.iter().find(|item| { + item.get("type").and_then(Value::as_str) == Some("function_call_output") + && item.get("call_id").and_then(Value::as_str) == Some(call_id) + }) + }) + }) +} + +fn collect_file_names(content: &str) -> HashSet { + content + .lines() + .filter_map(|line| { + if line.trim().is_empty() { + return None; + } + Path::new(line) + .file_name() + .map(|name| name.to_string_lossy().into_owned()) + }) + .collect() +} + +fn extract_content_and_success(value: &Value) -> (Option<&str>, Option) { + match value { + Value::String(text) => (Some(text.as_str()), None), + Value::Object(obj) => ( + obj.get("content").and_then(Value::as_str), + obj.get("success").and_then(Value::as_bool), + ), + _ => (None, None), + } +} diff --git a/codex-rs/core/tests/suite/mod.rs b/codex-rs/core/tests/suite/mod.rs index 04df9423f5..6008811dea 100644 --- a/codex-rs/core/tests/suite/mod.rs +++ b/codex-rs/core/tests/suite/mod.rs @@ -9,6 +9,7 @@ mod compact_resume_fork; mod exec; mod exec_stream_events; mod fork_conversation; +mod grep_files; mod json_result; mod list_dir; mod live_cli; diff --git a/codex-rs/core/tests/suite/tool_parallelism.rs b/codex-rs/core/tests/suite/tool_parallelism.rs index 0923f2f068..b4e3d1c9ad 100644 --- a/codex-rs/core/tests/suite/tool_parallelism.rs +++ b/codex-rs/core/tests/suite/tool_parallelism.rs @@ -63,8 +63,9 @@ async fn build_codex_with_test_tool(server: &wiremock::MockServer) -> anyhow::Re } fn assert_parallel_duration(actual: Duration) { + // Allow headroom for runtime overhead while still differentiating from serial execution. assert!( - actual < Duration::from_millis(500), + actual < Duration::from_millis(750), "expected parallel execution to finish quickly, got {actual:?}" ); }