Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 26 additions & 15 deletions codex-rs/core/src/tools/handlers/apply_patch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use async_trait::async_trait;
use codex_apply_patch::ApplyPatchAction;
use codex_apply_patch::ApplyPatchFileChange;
use codex_utils_absolute_path::AbsolutePathBuf;
use std::sync::Arc;

pub struct ApplyPatchHandler;

Expand Down Expand Up @@ -139,8 +140,8 @@ impl ToolHandler for ApplyPatchHandler {
let mut orchestrator = ToolOrchestrator::new();
let mut runtime = ApplyPatchRuntime::new();
let tool_ctx = ToolCtx {
session: session.as_ref(),
turn: turn.as_ref(),
session: session.clone(),
turn: turn.clone(),
call_id: call_id.clone(),
tool_name: tool_name.to_string(),
};
Expand All @@ -149,7 +150,7 @@ impl ToolHandler for ApplyPatchHandler {
&mut runtime,
&req,
&tool_ctx,
&turn,
turn.as_ref(),
turn.approval_policy.value(),
)
.await
Expand Down Expand Up @@ -193,8 +194,8 @@ pub(crate) async fn intercept_apply_patch(
command: &[String],
cwd: &Path,
timeout_ms: Option<u64>,
session: &Session,
turn: &TurnContext,
session: Arc<Session>,
turn: Arc<TurnContext>,
tracker: Option<&SharedTurnDiffTracker>,
call_id: &str,
tool_name: &str,
Expand All @@ -203,11 +204,13 @@ pub(crate) async fn intercept_apply_patch(
codex_apply_patch::MaybeApplyPatchVerified::Body(changes) => {
session
.record_model_warning(
format!("apply_patch was requested via {tool_name}. Use the apply_patch tool instead of exec_command."),
turn,
format!(
"apply_patch was requested via {tool_name}. Use the apply_patch tool instead of exec_command."
),
turn.as_ref(),
)
.await;
match apply_patch::apply_patch(turn, changes).await {
match apply_patch::apply_patch(turn.as_ref(), changes).await {
InternalApplyPatchInvocation::Output(item) => {
let content = item?;
Ok(Some(ToolOutput::Function {
Expand All @@ -219,8 +222,12 @@ pub(crate) async fn intercept_apply_patch(
let changes = convert_apply_patch_to_protocol(&apply.action);
let approval_keys = file_paths_for_action(&apply.action);
let emitter = ToolEmitter::apply_patch(changes.clone(), apply.auto_approved);
let event_ctx =
ToolEventCtx::new(session, turn, call_id, tracker.as_ref().copied());
let event_ctx = ToolEventCtx::new(
session.as_ref(),
turn.as_ref(),
call_id,
tracker.as_ref().copied(),
);
emitter.begin(event_ctx).await;

let req = ApplyPatchRequest {
Expand All @@ -235,8 +242,8 @@ pub(crate) async fn intercept_apply_patch(
let mut orchestrator = ToolOrchestrator::new();
let mut runtime = ApplyPatchRuntime::new();
let tool_ctx = ToolCtx {
session,
turn,
session: session.clone(),
turn: turn.clone(),
call_id: call_id.to_string(),
tool_name: tool_name.to_string(),
};
Expand All @@ -245,13 +252,17 @@ pub(crate) async fn intercept_apply_patch(
&mut runtime,
&req,
&tool_ctx,
turn,
turn.as_ref(),
turn.approval_policy.value(),
)
.await
.map(|result| result.output);
let event_ctx =
ToolEventCtx::new(session, turn, call_id, tracker.as_ref().copied());
let event_ctx = ToolEventCtx::new(
session.as_ref(),
turn.as_ref(),
call_id,
tracker.as_ref().copied(),
);
let content = emitter.finish(event_ctx, out).await?;
Ok(Some(ToolOutput::Function {
body: FunctionCallOutputBody::Text(content),
Expand Down
8 changes: 4 additions & 4 deletions codex-rs/core/src/tools/handlers/shell.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,8 +296,8 @@ impl ShellHandler {
&exec_params.command,
&exec_params.cwd,
exec_params.expiration.timeout_ms(),
session.as_ref(),
turn.as_ref(),
session.clone(),
turn.clone(),
Some(&tracker),
&call_id,
tool_name.as_str(),
Expand Down Expand Up @@ -343,8 +343,8 @@ impl ShellHandler {
let mut orchestrator = ToolOrchestrator::new();
let mut runtime = ShellRuntime::new();
let tool_ctx = ToolCtx {
session: session.as_ref(),
turn: turn.as_ref(),
session: session.clone(),
turn: turn.clone(),
call_id: call_id.clone(),
tool_name,
};
Expand Down
4 changes: 2 additions & 2 deletions codex-rs/core/src/tools/handlers/unified_exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,8 @@ impl ToolHandler for UnifiedExecHandler {
&command,
&cwd,
Some(yield_time_ms),
context.session.as_ref(),
context.turn.as_ref(),
context.session.clone(),
context.turn.clone(),
Some(&tracker),
&context.call_id,
tool_name.as_str(),
Expand Down
18 changes: 9 additions & 9 deletions codex-rs/core/src/tools/orchestrator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@ impl ToolOrchestrator {
async fn run_attempt<Rq, Out, T>(
tool: &mut T,
req: &Rq,
tool_ctx: &ToolCtx<'_>,
tool_ctx: &ToolCtx,
attempt: &SandboxAttempt<'_>,
has_managed_network_requirements: bool,
) -> (Result<Out, ToolError>, Option<DeferredNetworkApproval>)
where
T: ToolRuntime<Rq, Out>,
{
let network_approval = begin_network_approval(
tool_ctx.session,
&tool_ctx.session,
&tool_ctx.turn.sub_id,
&tool_ctx.call_id,
has_managed_network_requirements,
Expand All @@ -65,8 +65,8 @@ impl ToolOrchestrator {
.await;

let attempt_tool_ctx = ToolCtx {
session: tool_ctx.session,
turn: tool_ctx.turn,
session: tool_ctx.session.clone(),
turn: tool_ctx.turn.clone(),
call_id: tool_ctx.call_id.clone(),
tool_name: tool_ctx.tool_name.clone(),
};
Expand All @@ -79,7 +79,7 @@ impl ToolOrchestrator {
match network_approval.mode() {
NetworkApprovalMode::Immediate => {
let finalize_result =
finish_immediate_network_approval(tool_ctx.session, network_approval).await;
finish_immediate_network_approval(&tool_ctx.session, network_approval).await;
if let Err(err) = finalize_result {
return (Err(err), None);
}
Expand All @@ -88,7 +88,7 @@ impl ToolOrchestrator {
NetworkApprovalMode::Deferred => {
let deferred = network_approval.into_deferred();
if run_result.is_err() {
finish_deferred_network_approval(tool_ctx.session, deferred).await;
finish_deferred_network_approval(&tool_ctx.session, deferred).await;
return (run_result, None);
}
(run_result, deferred)
Expand All @@ -100,7 +100,7 @@ impl ToolOrchestrator {
&mut self,
tool: &mut T,
req: &Rq,
tool_ctx: &ToolCtx<'_>,
tool_ctx: &ToolCtx,
turn_ctx: &crate::codex::TurnContext,
approval_policy: AskForApproval,
) -> Result<OrchestratorRunResult<Out>, ToolError>
Expand Down Expand Up @@ -128,7 +128,7 @@ impl ToolOrchestrator {
}
ExecApprovalRequirement::NeedsApproval { reason, .. } => {
let approval_ctx = ApprovalCtx {
session: tool_ctx.session,
session: &tool_ctx.session,
turn: turn_ctx,
call_id: &tool_ctx.call_id,
retry_reason: reason,
Expand Down Expand Up @@ -256,7 +256,7 @@ impl ToolOrchestrator {
&& network_approval_context.is_none();
if !bypass_retry_approval {
let approval_ctx = ApprovalCtx {
session: tool_ctx.session,
session: &tool_ctx.session,
turn: turn_ctx,
call_id: &tool_ctx.call_id,
retry_reason: Some(retry_reason),
Expand Down
4 changes: 2 additions & 2 deletions codex-rs/core/src/tools/runtimes/apply_patch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ impl ApplyPatchRuntime {
})
}

fn stdout_stream(ctx: &ToolCtx<'_>) -> Option<crate::exec::StdoutStream> {
fn stdout_stream(ctx: &ToolCtx) -> Option<crate::exec::StdoutStream> {
Some(crate::exec::StdoutStream {
sub_id: ctx.turn.sub_id.clone(),
call_id: ctx.call_id.clone(),
Expand Down Expand Up @@ -156,7 +156,7 @@ impl ToolRuntime<ApplyPatchRequest, ExecToolCallOutput> for ApplyPatchRuntime {
&mut self,
req: &ApplyPatchRequest,
attempt: &SandboxAttempt<'_>,
ctx: &ToolCtx<'_>,
ctx: &ToolCtx,
) -> Result<ExecToolCallOutput, ToolError> {
let spec = Self::build_command_spec(req)?;
let env = attempt
Expand Down
8 changes: 4 additions & 4 deletions codex-rs/core/src/tools/runtimes/shell.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ impl ShellRuntime {
Self
}

fn stdout_stream(ctx: &ToolCtx<'_>) -> Option<crate::exec::StdoutStream> {
fn stdout_stream(ctx: &ToolCtx) -> Option<crate::exec::StdoutStream> {
Some(crate::exec::StdoutStream {
sub_id: ctx.turn.sub_id.clone(),
call_id: ctx.call_id.clone(),
Expand Down Expand Up @@ -150,7 +150,7 @@ impl ToolRuntime<ShellRequest, ExecToolCallOutput> for ShellRuntime {
fn network_approval_spec(
&self,
req: &ShellRequest,
_ctx: &ToolCtx<'_>,
_ctx: &ToolCtx,
) -> Option<NetworkApprovalSpec> {
req.network.as_ref()?;
Some(NetworkApprovalSpec {
Expand All @@ -163,7 +163,7 @@ impl ToolRuntime<ShellRequest, ExecToolCallOutput> for ShellRuntime {
&mut self,
req: &ShellRequest,
attempt: &SandboxAttempt<'_>,
ctx: &ToolCtx<'_>,
ctx: &ToolCtx,
) -> Result<ExecToolCallOutput, ToolError> {
let base_command = &req.command;
let session_shell = ctx.session.user_shell();
Expand Down Expand Up @@ -207,7 +207,7 @@ impl ToolRuntime<ShellRequest, ExecToolCallOutput> for ShellRuntime {
.session
.services
.zsh_exec_bridge
.execute_shell_request(&env, ctx.session, ctx.turn, &ctx.call_id)
.execute_shell_request(&env, &ctx.session, &ctx.turn, &ctx.call_id)
.await;
}

Expand Down
4 changes: 2 additions & 2 deletions codex-rs/core/src/tools/runtimes/unified_exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ impl<'a> ToolRuntime<UnifiedExecRequest, UnifiedExecProcess> for UnifiedExecRunt
fn network_approval_spec(
&self,
req: &UnifiedExecRequest,
_ctx: &ToolCtx<'_>,
_ctx: &ToolCtx,
) -> Option<NetworkApprovalSpec> {
req.network.as_ref()?;
Some(NetworkApprovalSpec {
Expand All @@ -166,7 +166,7 @@ impl<'a> ToolRuntime<UnifiedExecRequest, UnifiedExecProcess> for UnifiedExecRunt
&mut self,
req: &UnifiedExecRequest,
attempt: &SandboxAttempt<'_>,
ctx: &ToolCtx<'_>,
ctx: &ToolCtx,
) -> Result<UnifiedExecProcess, ToolError> {
let base_command = &req.command;
let session_shell = ctx.session.user_shell();
Expand Down
16 changes: 8 additions & 8 deletions codex-rs/core/src/tools/sandboxing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ use codex_protocol::approvals::ExecPolicyAmendment;
use codex_protocol::approvals::NetworkApprovalContext;
use codex_protocol::protocol::AskForApproval;
use codex_protocol::protocol::ReviewDecision;
use futures::Future;
use futures::future::BoxFuture;
use serde::Serialize;
use std::collections::HashMap;
use std::fmt::Debug;
use std::hash::Hash;
use std::path::Path;

use futures::Future;
use futures::future::BoxFuture;
use serde::Serialize;
use std::sync::Arc;

#[derive(Clone, Default, Debug)]
pub(crate) struct ApprovalStore {
Expand Down Expand Up @@ -267,9 +267,9 @@ pub(crate) trait Sandboxable {
}
}

pub(crate) struct ToolCtx<'a> {
pub session: &'a Session,
pub turn: &'a TurnContext,
pub(crate) struct ToolCtx {
pub session: Arc<Session>,
pub turn: Arc<TurnContext>,
pub call_id: String,
pub tool_name: String,
}
Expand All @@ -281,7 +281,7 @@ pub(crate) enum ToolError {
}

pub(crate) trait ToolRuntime<Req, Out>: Approvable<Req> + Sandboxable {
fn network_approval_spec(&self, _req: &Req, _ctx: &ToolCtx<'_>) -> Option<NetworkApprovalSpec> {
fn network_approval_spec(&self, _req: &Req, _ctx: &ToolCtx) -> Option<NetworkApprovalSpec> {
None
}

Expand Down
6 changes: 3 additions & 3 deletions codex-rs/core/src/unified_exec/process_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -594,8 +594,8 @@ impl UnifiedExecProcessManager {
exec_approval_requirement,
};
let tool_ctx = ToolCtx {
session: context.session.as_ref(),
turn: context.turn.as_ref(),
session: context.session.clone(),
turn: context.turn.clone(),
call_id: context.call_id.clone(),
tool_name: "exec_command".to_string(),
};
Expand All @@ -604,7 +604,7 @@ impl UnifiedExecProcessManager {
&mut runtime,
&req,
&tool_ctx,
context.turn.as_ref(),
&context.turn,
context.turn.approval_policy.value(),
)
.await
Expand Down
Loading