Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions codex-rs/core/src/codex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ use crate::realtime_conversation::handle_start as handle_realtime_conversation_s
use crate::realtime_conversation::handle_text as handle_realtime_conversation_text;
use crate::rollout::session_index;
use crate::stream_events_utils::HandleOutputCtx;
use crate::stream_events_utils::default_image_generation_output_dir;
use crate::stream_events_utils::handle_non_tool_response_item;
use crate::stream_events_utils::handle_output_item_done;
use crate::stream_events_utils::last_assistant_message_from_item;
Expand Down Expand Up @@ -3327,6 +3328,14 @@ impl Session {
)
.into_text(),
);
if turn_context.features.enabled(Feature::ImageGeneration) {
let image_output_dir = default_image_generation_output_dir();
developer_sections.push(format!(
"Generated images are saved to {} as {} by default.",
image_output_dir.display(),
image_output_dir.join("<image_id>.png").display(),
));
}
if let Some(developer_instructions) = turn_context.developer_instructions.as_deref() {
developer_sections.push(developer_instructions.to_string());
}
Expand Down Expand Up @@ -6636,9 +6645,7 @@ async fn handle_assistant_item_done_in_plan_mode(
{
maybe_complete_plan_item_from_message(sess, turn_context, state, item).await;

if let Some(turn_item) =
handle_non_tool_response_item(item, true, Some(&turn_context.cwd)).await
{
if let Some(turn_item) = handle_non_tool_response_item(item, true).await {
emit_turn_item_in_plan_mode(
sess,
turn_context,
Expand Down Expand Up @@ -6818,9 +6825,7 @@ async fn try_run_sampling_request(
needs_follow_up |= output_result.needs_follow_up;
}
ResponseEvent::OutputItemAdded(item) => {
if let Some(turn_item) =
handle_non_tool_response_item(&item, plan_mode, Some(&turn_context.cwd)).await
{
if let Some(turn_item) = handle_non_tool_response_item(&item, plan_mode).await {
let mut turn_item = turn_item;
let mut seeded_parsed: Option<ParsedAssistantTextDelta> = None;
let mut seeded_item_id: Option<String> = None;
Expand Down
39 changes: 39 additions & 0 deletions codex-rs/core/src/codex_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3116,6 +3116,45 @@ async fn build_initial_context_uses_previous_realtime_state() {
);
}

#[tokio::test]
async fn build_initial_context_describes_default_image_save_location() {
let (session, mut turn_context) = make_session_and_context().await;
turn_context
.features
.enable(Feature::ImageGeneration)
.expect("enable image generation feature");

let initial_context = session.build_initial_context(&turn_context).await;
let developer_texts = developer_input_texts(&initial_context);
let image_output_dir = crate::stream_events_utils::default_image_generation_output_dir();
let expected_text = format!(
"Generated images are saved to {} as {} by default.",
image_output_dir.display(),
image_output_dir.join("<image_id>.png").display(),
);
assert!(
developer_texts
.iter()
.any(|text| text.contains(expected_text.as_str())),
"expected initial context to describe the default image save location, got {developer_texts:?}"
);
}

#[tokio::test]
async fn build_initial_context_omits_default_image_save_location_when_disabled() {
let (session, turn_context) = make_session_and_context().await;

let initial_context = session.build_initial_context(&turn_context).await;
let developer_texts = developer_input_texts(&initial_context);

assert!(
!developer_texts
.iter()
.any(|text| text.contains("Generated images are saved to")),
"expected initial context to omit image save instructions when image generation is disabled, got {developer_texts:?}"
);
}

#[tokio::test]
async fn build_initial_context_uses_previous_turn_settings_for_realtime_end() {
let (session, turn_context) = make_session_and_context().await;
Expand Down
6 changes: 0 additions & 6 deletions codex-rs/core/src/context_manager/history_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -434,9 +434,6 @@ fn for_prompt_rewrites_image_generation_calls_when_images_are_supported() {
ContentItem::InputImage {
image_url: "data:image/png;base64,Zm9v".to_string(),
},
ContentItem::InputText {
text: "Saved to: CWD".to_string(),
},
],
end_turn: None,
phase: None,
Expand Down Expand Up @@ -503,9 +500,6 @@ fn for_prompt_rewrites_image_generation_calls_when_images_are_unsupported() {
text: "image content omitted because you do not support image input"
.to_string(),
},
ContentItem::InputText {
text: "Saved to: CWD".to_string(),
},
],
end_turn: None,
phase: None,
Expand Down
3 changes: 0 additions & 3 deletions codex-rs/core/src/context_manager/normalize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,9 +242,6 @@ pub(crate) fn rewrite_image_generation_calls_for_stateless_input(items: &mut Vec
text: format!("Prompt: {revised_prompt}"),
},
ContentItem::InputImage { image_url },
ContentItem::InputText {
text: "Saved to: CWD".to_string(),
},
],
end_turn: None,
phase: None,
Expand Down
100 changes: 40 additions & 60 deletions codex-rs/core/src/stream_events_utils.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::path::Path;
use std::path::PathBuf;
use std::pin::Pin;
use std::sync::Arc;
Expand Down Expand Up @@ -54,11 +53,7 @@ pub(crate) fn raw_assistant_output_text_from_item(item: &ResponseItem) -> Option
None
}

async fn save_image_generation_result_to_cwd(
cwd: &Path,
call_id: &str,
result: &str,
) -> Result<PathBuf> {
async fn save_image_generation_result(call_id: &str, result: &str) -> Result<PathBuf> {
let bytes = BASE64_STANDARD
.decode(result.trim().as_bytes())
.map_err(|err| {
Expand All @@ -77,11 +72,15 @@ async fn save_image_generation_result_to_cwd(
if file_stem.is_empty() {
file_stem = "generated_image".to_string();
}
let path = cwd.join(format!("{file_stem}.png"));
let path = default_image_generation_output_dir().join(format!("{file_stem}.png"));
tokio::fs::write(&path, bytes).await?;
Ok(path)
}

pub(crate) fn default_image_generation_output_dir() -> PathBuf {
std::env::temp_dir()
}

/// Persist a completed model response item and record any cited memory usage.
pub(crate) async fn record_completed_response_item(
sess: &Session,
Expand Down Expand Up @@ -189,9 +188,7 @@ pub(crate) async fn handle_output_item_done(
}
// No tool call: convert messages/reasoning into turn items and mark them as complete.
Ok(None) => {
if let Some(turn_item) =
handle_non_tool_response_item(&item, plan_mode, Some(&ctx.turn_context.cwd)).await
{
if let Some(turn_item) = handle_non_tool_response_item(&item, plan_mode).await {
if previously_active_item.is_none() {
let mut started_item = turn_item.clone();
if let TurnItem::ImageGeneration(item) = &mut started_item {
Expand Down Expand Up @@ -278,7 +275,6 @@ pub(crate) async fn handle_output_item_done(
pub(crate) async fn handle_non_tool_response_item(
item: &ResponseItem,
plan_mode: bool,
image_output_cwd: Option<&Path>,
) -> Option<TurnItem> {
debug!(?item, "Output item");

Expand All @@ -300,19 +296,16 @@ pub(crate) async fn handle_non_tool_response_item(
agent_message.content =
vec![codex_protocol::items::AgentMessageContent::Text { text: stripped }];
}
if let TurnItem::ImageGeneration(image_item) = &mut turn_item
&& let Some(cwd) = image_output_cwd
{
match save_image_generation_result_to_cwd(cwd, &image_item.id, &image_item.result)
.await
{
if let TurnItem::ImageGeneration(image_item) = &mut turn_item {
match save_image_generation_result(&image_item.id, &image_item.result).await {
Ok(path) => {
image_item.saved_path = Some(path.to_string_lossy().into_owned());
}
Err(err) => {
let output_dir = default_image_generation_output_dir();
tracing::warn!(
call_id = %image_item.id,
cwd = %cwd.display(),
output_dir = %output_dir.display(),
"failed to save generated image: {err}"
);
}
Expand Down Expand Up @@ -378,15 +371,15 @@ pub(crate) fn response_input_to_response_item(input: &ResponseInputItem) -> Opti

#[cfg(test)]
mod tests {
use super::default_image_generation_output_dir;
use super::handle_non_tool_response_item;
use super::last_assistant_message_from_item;
use super::save_image_generation_result_to_cwd;
use super::save_image_generation_result;
use crate::error::CodexErr;
use codex_protocol::items::TurnItem;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ResponseItem;
use pretty_assertions::assert_eq;
use tempfile::tempdir;

fn assistant_output_text(text: &str) -> ResponseItem {
ResponseItem::Message {
Expand All @@ -404,10 +397,9 @@ mod tests {
async fn handle_non_tool_response_item_strips_citations_from_assistant_message() {
let item = assistant_output_text("hello<oai-mem-citation>doc1</oai-mem-citation> world");

let turn_item =
handle_non_tool_response_item(&item, false, Some(std::path::Path::new(".")))
.await
.expect("assistant message should parse");
let turn_item = handle_non_tool_response_item(&item, false)
.await
.expect("assistant message should parse");

let TurnItem::AgentMessage(agent_message) = turn_item else {
panic!("expected agent message");
Expand Down Expand Up @@ -449,82 +441,70 @@ mod tests {
}

#[tokio::test]
async fn save_image_generation_result_saves_base64_to_png_in_cwd() {
let dir = tempdir().expect("tempdir");
async fn save_image_generation_result_saves_base64_to_png_in_temp_dir() {
let expected_path = default_image_generation_output_dir().join("ig_save_base64.png");
let _ = std::fs::remove_file(&expected_path);

let saved_path = save_image_generation_result_to_cwd(dir.path(), "ig_123", "Zm9v")
let saved_path = save_image_generation_result("ig_save_base64", "Zm9v")
.await
.expect("image should be saved");

assert_eq!(
saved_path.file_name().and_then(|v| v.to_str()),
Some("ig_123.png")
);
assert_eq!(std::fs::read(saved_path).expect("saved file"), b"foo");
assert_eq!(saved_path, expected_path);
assert_eq!(std::fs::read(&saved_path).expect("saved file"), b"foo");
let _ = std::fs::remove_file(&saved_path);
}

#[tokio::test]
async fn save_image_generation_result_rejects_data_url_payload() {
let dir = tempdir().expect("tempdir");
let result = "data:image/jpeg;base64,Zm9v";

let err = save_image_generation_result_to_cwd(dir.path(), "ig_456", result)
let err = save_image_generation_result("ig_456", result)
.await
.expect_err("data url payload should error");
assert!(matches!(err, CodexErr::InvalidRequest(_)));
}

#[tokio::test]
async fn save_image_generation_result_overwrites_existing_file() {
let dir = tempdir().expect("tempdir");
let existing_path = dir.path().join("ig_123.png");
let existing_path = default_image_generation_output_dir().join("ig_overwrite.png");
std::fs::write(&existing_path, b"existing").expect("seed existing image");

let saved_path = save_image_generation_result_to_cwd(dir.path(), "ig_123", "Zm9v")
let saved_path = save_image_generation_result("ig_overwrite", "Zm9v")
.await
.expect("image should be saved");

assert_eq!(
saved_path.file_name().and_then(|v| v.to_str()),
Some("ig_123.png")
);
assert_eq!(std::fs::read(saved_path).expect("saved file"), b"foo");
assert_eq!(saved_path, existing_path);
assert_eq!(std::fs::read(&saved_path).expect("saved file"), b"foo");
let _ = std::fs::remove_file(&saved_path);
}

#[tokio::test]
async fn save_image_generation_result_sanitizes_call_id_for_output_path() {
let dir = tempdir().expect("tempdir");
async fn save_image_generation_result_sanitizes_call_id_for_temp_dir_output_path() {
let expected_path = default_image_generation_output_dir().join("___ig___.png");
let _ = std::fs::remove_file(&expected_path);

let saved_path = save_image_generation_result_to_cwd(dir.path(), "../ig/..", "Zm9v")
let saved_path = save_image_generation_result("../ig/..", "Zm9v")
.await
.expect("image should be saved");

assert_eq!(saved_path.parent(), Some(dir.path()));
assert_eq!(
saved_path.file_name().and_then(|v| v.to_str()),
Some("___ig___.png")
);
assert_eq!(std::fs::read(saved_path).expect("saved file"), b"foo");
assert_eq!(saved_path, expected_path);
assert_eq!(std::fs::read(&saved_path).expect("saved file"), b"foo");
let _ = std::fs::remove_file(&saved_path);
}

#[tokio::test]
async fn save_image_generation_result_rejects_non_standard_base64() {
let dir = tempdir().expect("tempdir");

let err = save_image_generation_result_to_cwd(dir.path(), "ig_urlsafe", "_-8")
let err = save_image_generation_result("ig_urlsafe", "_-8")
.await
.expect_err("non-standard base64 should error");
assert!(matches!(err, CodexErr::InvalidRequest(_)));
}

#[tokio::test]
async fn save_image_generation_result_rejects_non_base64_data_urls() {
let dir = tempdir().expect("tempdir");

let err =
save_image_generation_result_to_cwd(dir.path(), "ig_svg", "data:image/svg+xml,<svg/>")
.await
.expect_err("non-base64 data url should error");
let err = save_image_generation_result("ig_svg", "data:image/svg+xml,<svg/>")
.await
.expect_err("non-base64 data url should error");
assert!(matches!(err, CodexErr::InvalidRequest(_)));
}
}
21 changes: 13 additions & 8 deletions codex-rs/core/tests/suite/items.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,11 +269,14 @@ async fn image_generation_call_event_is_emitted() -> anyhow::Result<()> {

let server = start_mock_server().await;

let TestCodex { codex, cwd, .. } = test_codex().build(&server).await?;
let TestCodex { codex, .. } = test_codex().build(&server).await?;
let call_id = "ig_image_saved_to_temp_dir_default";
let expected_saved_path = std::env::temp_dir().join(format!("{call_id}.png"));
let _ = std::fs::remove_file(&expected_saved_path);

let first_response = sse(vec![
ev_response_created("resp-1"),
ev_image_generation_call("ig_123", "completed", "A tiny blue square", "Zm9v"),
ev_image_generation_call(call_id, "completed", "A tiny blue square", "Zm9v"),
ev_completed("resp-1"),
]);
mount_sse_once(&server, first_response).await;
Expand All @@ -299,17 +302,17 @@ async fn image_generation_call_event_is_emitted() -> anyhow::Result<()> {
})
.await;

assert_eq!(begin.call_id, "ig_123");
assert_eq!(end.call_id, "ig_123");
assert_eq!(begin.call_id, call_id);
assert_eq!(end.call_id, call_id);
assert_eq!(end.status, "completed");
assert_eq!(end.revised_prompt, Some("A tiny blue square".to_string()));
assert_eq!(end.result, "Zm9v");
let expected_saved_path = cwd.path().join("ig_123.png");
assert_eq!(
end.saved_path,
Some(expected_saved_path.to_string_lossy().into_owned())
);
assert_eq!(std::fs::read(expected_saved_path)?, b"foo");
assert_eq!(std::fs::read(&expected_saved_path)?, b"foo");
let _ = std::fs::remove_file(&expected_saved_path);

Ok(())
}
Expand All @@ -320,7 +323,9 @@ async fn image_generation_call_event_is_emitted_when_image_save_fails() -> anyho

let server = start_mock_server().await;

let TestCodex { codex, cwd, .. } = test_codex().build(&server).await?;
let TestCodex { codex, .. } = test_codex().build(&server).await?;
let expected_saved_path = std::env::temp_dir().join("ig_invalid.png");
let _ = std::fs::remove_file(&expected_saved_path);

let first_response = sse(vec![
ev_response_created("resp-1"),
Expand Down Expand Up @@ -356,7 +361,7 @@ async fn image_generation_call_event_is_emitted_when_image_save_fails() -> anyho
assert_eq!(end.revised_prompt, Some("broken payload".to_string()));
assert_eq!(end.result, "_-8");
assert_eq!(end.saved_path, None);
assert!(!cwd.path().join("ig_invalid.png").exists());
assert!(!expected_saved_path.exists());

Ok(())
}
Expand Down
Loading
Loading