diff --git a/engine/baml-runtime/src/internal/llm_client/orchestrator/stream.rs b/engine/baml-runtime/src/internal/llm_client/orchestrator/stream.rs
index a33385f5e6..0ca67627ba 100644
--- a/engine/baml-runtime/src/internal/llm_client/orchestrator/stream.rs
+++ b/engine/baml-runtime/src/internal/llm_client/orchestrator/stream.rs
@@ -9,6 +9,7 @@ use internal_baml_core::ir::repr::IntermediateRepr;
 use jsonish::BamlValueWithFlags;
 use serde_json::json;
 use stream_cancel::Tripwire;
+use tokio::sync::{watch, Mutex};
 #[cfg(not(target_family = "wasm"))]
 use tokio::time::*;
 #[cfg(target_family = "wasm")]
@@ -22,7 +23,7 @@ use crate::{
             orchestrator::ExecutionScope,
             parsed_value_to_response,
             traits::{HttpContext, WithClientProperties, WithPrompt, WithStreamable},
-            LLMErrorResponse, LLMResponse, ResponseBamlValue,
+            ErrorCode, LLMCompleteResponse, LLMErrorResponse, LLMResponse, ResponseBamlValue,
         },
         prompt_renderer::PromptRenderer,
     },
@@ -30,6 +31,121 @@ use crate::{
     FunctionResult, RuntimeContext,
 };
+// Shared state between the SSE consumer and the throttled parser.
+#[derive(Default)]
+struct ParserState {
+    last_sent_partial_serialized: Option<String>,
+    last_processed_snapshot_ptr: Option<usize>,
+}
+
+// Attempts to parse the latest SSE snapshot. We split this out in case parsing takes longer than the SSE interval.
+async fn run_parser_loop<'a, ParseFn, EventFn>(
+    scope: OrchestrationScope,
+    parse_state: Arc<Mutex<ParserState>>,
+    partial_parse_fn: &'a ParseFn,
+    on_event: &'a EventFn,
+    mut snapshot_rx: watch::Receiver<Option<Arc<LLMCompleteResponse>>>,
+) where
+    ParseFn: Fn(&str) -> Result<BamlValueWithFlags> + 'a,
+    EventFn: Fn(FunctionResult) + 'a,
+{
+    let mut parse_interval = interval(web_time::Duration::from_millis(50));
+    parse_interval.set_missed_tick_behavior(MissedTickBehavior::Skip);
+
+    loop {
+        tokio::select! {
+            _ = parse_interval.tick() => {
+                process_latest_snapshot(
+                    &scope,
+                    &parse_state,
+                    partial_parse_fn,
+                    on_event,
+                    &mut snapshot_rx,
+                ).await;
+            }
+            changed = snapshot_rx.changed() => {
+                if changed.is_err() {
+                    process_latest_snapshot(
+                        &scope,
+                        &parse_state,
+                        partial_parse_fn,
+                        on_event,
+                        &mut snapshot_rx,
+                    ).await;
+                    break;
+                }
+            }
+        }
+    }
+}
+
+async fn process_latest_snapshot<'a, ParseFn, EventFn>(
+    scope: &OrchestrationScope,
+    parse_state: &Arc<Mutex<ParserState>>,
+    partial_parse_fn: &'a ParseFn,
+    on_event: &'a EventFn,
+    snapshot_rx: &mut watch::Receiver<Option<Arc<LLMCompleteResponse>>>,
+) where
+    ParseFn: Fn(&str) -> Result<BamlValueWithFlags> + 'a,
+    EventFn: Fn(FunctionResult) + 'a,
+{
+    let Some(snapshot) = snapshot_rx.borrow().clone() else {
+        return;
+    };
+
+    let snapshot_ptr = Arc::as_ptr(&snapshot) as usize;
+    let should_attempt = {
+        let state = parse_state.lock().await;
+        state.last_processed_snapshot_ptr != Some(snapshot_ptr)
+    };
+
+    if !should_attempt {
+        return;
+    }
+
+    match partial_parse_fn(&snapshot.content) {
+        Ok(baml_value) => {
+            let parsed = ResponseBamlValue(
+                baml_value
+                    .0
+                    .map_meta_owned(|m| jsonish::ResponseValueMeta(vec![], m.1, m.2, m.3)),
+            );
+            let partial = parsed.serialize_partial();
+            let serialized = serde_json::to_string(&partial).ok();
+
+            let should_emit = {
+                let mut state = parse_state.lock().await;
+                let should_emit = match serialized.as_ref() {
+                    Some(serialized_str) => {
+                        state.last_sent_partial_serialized.as_deref()
+                            != Some(serialized_str.as_str())
+                    }
+                    None => true,
+                };
+                state.last_processed_snapshot_ptr = Some(snapshot_ptr);
+                if should_emit {
+                    if let Some(serialized_str) = serialized.clone() {
+                        state.last_sent_partial_serialized = Some(serialized_str);
+                    }
+                }
+                should_emit
+            };
+
+            if should_emit {
+                on_event(FunctionResult::new(
+                    scope.clone(),
+                    LLMResponse::Success((*snapshot).clone()),
+                    Some(Ok(parsed)),
+                ));
+            }
+        }
+        Err(_) => {
+            // Parsing failed on this snapshot; clear the pointer so the next tick retries it.
+            let mut state = parse_state.lock().await;
+            state.last_processed_snapshot_ptr = None;
+        }
+    }
+}
+
 pub async fn orchestrate_stream(
     iter: OrchestratorNodeIterator,
     ir: &IntermediateRepr,
@@ -103,126 +219,64 @@ where
     let final_response = match stream_res {
         Ok(mut response_stream) => {
             let mut last_response: Option<LLMResponse> = None;
-            let mut latest_success_snapshot: Option<LLMCompleteResponse> = None;
-            let mut latest_content_for_parse: Option<String> = None;
-            // Track last parsed payload surfaced to downstream listeners so we can dedupe events
-            let mut last_sent_partial_serialized: Option<String> = None;
-            let mut parse_interval = interval(web_time::Duration::from_millis(20));
-            // If parsing falls behind, skip missed ticks so we only parse latest.
-            parse_interval.set_missed_tick_behavior(MissedTickBehavior::Skip);
+            let parse_state = Arc::new(Mutex::new(ParserState::default()));
+            let (snapshot_tx, snapshot_rx) = watch::channel::<Option<Arc<LLMCompleteResponse>>>(None);
-            loop {
-                tokio::select! {
-                    // Prioritize consuming SSE events over parsing.
-                    biased;
-                    maybe_item = FuturesStreamExt::next(&mut response_stream) => {
-                        match maybe_item {
-                            Some(stream_part) => {
-                                if let Some(on_tick) = on_tick_fn.as_ref() {
-                                    on_tick();
-                                }
-                                match &stream_part {
-                                    LLMResponse::Success(s) => {
-                                        // Track latest snapshot and content
-                                        latest_success_snapshot = Some(s.clone());
-                                        latest_content_for_parse = Some(s.content.clone());
-                                        last_response = Some(LLMResponse::Success(s.clone()));
-                                    }
-                                    other => {
-                                        last_response = Some(other.clone());
-                                    }
-                                }
-                            }
-                            None => {
-                                // End of stream
-                                break;
-                            }
-                        }
-                    }
-                    // Periodically surface the latest partial parse to downstream listeners.
-                    _ = parse_interval.tick(), if on_event.is_some() => {
-                        if let Some(on_event) = on_event.as_ref() {
-                            if let Some(snap) = latest_success_snapshot.as_ref() {
-                                if let Some(mut content) = latest_content_for_parse.take() {
-                                    match partial_parse_fn(&content) {
-                                        Ok(baml_value) => {
-                                            // Strip flags to reduce memory usage
-                                            let parsed = ResponseBamlValue(baml_value.0.map_meta_owned(|m| {
-                                                jsonish::ResponseValueMeta(vec![], m.1, m.2, m.3)
-                                            }));
-                                            if let Ok(serialized) = serde_json::to_string(&parsed.serialize_partial()) {
-                                                if last_sent_partial_serialized
-                                                    .as_deref()
-                                                    != Some(serialized.as_str())
-                                                {
-                                                    // only successful events sent to the client
-                                                    on_event(FunctionResult::new(
-                                                        node.scope.clone(),
-                                                        LLMResponse::Success(snap.clone()),
-                                                        Some(Ok(parsed)),
-                                                    ));
-                                                    last_sent_partial_serialized = Some(serialized);
-                                                }
-                                            } else {
-                                                // If serialization fails, still emit the parsed event instead of dropping it.
-                                                on_event(FunctionResult::new(
-                                                    node.scope.clone(),
-                                                    LLMResponse::Success(snap.clone()),
-                                                    Some(Ok(parsed)),
-                                                ));
-                                                // Intentionally do not update last_sent_partial_serialized here.
-                                            }
-                                        }
-                                        Err(_) => {
-                                            // Only restore the content if nothing newer has arrived since we took it.
- if latest_content_for_parse.is_none() { - latest_content_for_parse = Some(content); - } - } - } - } - } - } - } + let parser_future = on_event.as_ref().map(|on_event_cb| { + let scope = node.scope.clone(); + let parse_state = parse_state.clone(); + let partial_parse_fn = &partial_parse_fn; + let snapshot_rx = snapshot_rx.clone(); + async move { + run_parser_loop( + scope, + parse_state, + partial_parse_fn, + on_event_cb, + snapshot_rx, + ) + .await; } - } + }); + + let on_tick_cb = on_tick_fn.as_ref(); + let parse_state_for_sse = parse_state.clone(); + let sse_future = async move { + let snapshot_sender = snapshot_tx; + while let Some(stream_part) = FuturesStreamExt::next(&mut response_stream).await { + if let Some(on_tick) = on_tick_cb { + on_tick(); + } - if let Some(on_event) = on_event.as_ref() { - if let Some(snap) = latest_success_snapshot.as_ref() { - if let Some(mut content) = latest_content_for_parse.take() { - if let Ok(baml_value) = partial_parse_fn(&content) { - // Strip flags to reduce memory usage - let parsed = ResponseBamlValue(baml_value.0.map_meta_owned(|m| { - jsonish::ResponseValueMeta(vec![], m.1, m.2, m.3) - })); - if let Ok(serialized) = serde_json::to_string(&parsed.serialize_partial()) { - if last_sent_partial_serialized - .as_deref() - != Some(serialized.as_str()) - { - // Only successful events should reach downstream listeners - on_event(FunctionResult::new( - node.scope.clone(), - LLMResponse::Success(snap.clone()), - Some(Ok(parsed)), - )); - last_sent_partial_serialized = Some(serialized); - } - } else { - // If serialization fails, still emit the parsed event instead of dropping it. - on_event(FunctionResult::new( - node.scope.clone(), - LLMResponse::Success(snap.clone()), - Some(Ok(parsed)), - )); - // Intentionally do not update last_sent_partial_serialized here. 
- } + match &stream_part { + LLMResponse::Success(s) => { + let snapshot = Arc::new(s.clone()); + let _ = snapshot_sender.send_replace(Some(snapshot.clone())); + last_response = Some(LLMResponse::Success((*snapshot).clone())); + + let mut state = parse_state_for_sse.lock().await; + state.last_processed_snapshot_ptr = None; + } + other => { + last_response = Some(other.clone()); } } } - } - last_response.unwrap_or_else(|| { + drop(snapshot_sender); + last_response + }; + + let final_last_response = if let Some(parser_future) = parser_future { + let (last_response_opt, _) = futures::future::join(sse_future, parser_future).await; + last_response_opt + } else { + sse_future.await + }; + + if let Some(response) = final_last_response { + response + } else { LLMResponse::LLMFailure(LLMErrorResponse { client: node.provider.name().into(), model: None, @@ -230,10 +284,10 @@ where start_time: system_start, latency: instant_start.elapsed(), request_options: node.provider.request_options().clone(), - message: "Stream ended without response".to_string(), - code: crate::internal::llm_client::ErrorCode::from_u16(2), + message: "Stream ended and no events were received".to_string(), + code: ErrorCode::Other(2), }) - }) + } } Err(response) => response, }; diff --git a/engine/baml-runtime/src/internal/llm_client/primitive/openai/openai_client.rs b/engine/baml-runtime/src/internal/llm_client/primitive/openai/openai_client.rs index 3ca23964fb..8b59839252 100644 --- a/engine/baml-runtime/src/internal/llm_client/primitive/openai/openai_client.rs +++ b/engine/baml-runtime/src/internal/llm_client/primitive/openai/openai_client.rs @@ -147,12 +147,23 @@ impl ProviderStrategy { .iter() .map(|part| match part { ChatMessagePart::Text(text) => { + let content_type = if msg.role == "assistant" { + "output_text" + } else { + "input_text" + }; Ok(json!({ - "type": if msg.role == "assistant" { "output_text" } else { "input_text" }, + "type": content_type, "text": text })) } ChatMessagePart::Media(media) => { + // For assistant role, we only support text outputs in Responses API + if msg.role == "assistant" { + anyhow::bail!( + "BAML internal error (openai-responses): assistant messages must be text; media not supported for assistant in Responses API" + ); + } match media.media_type { baml_types::BamlMediaType::Image => { let image_url = match &media.content { @@ -165,7 +176,8 @@ impl ProviderStrategy { } }; Ok(json!({ - "type": if msg.role == "assistant" { "output_image" } else { "input_image" }, + "type": "input_image", + "detail": "auto", "image_url": image_url })) } @@ -177,7 +189,7 @@ impl ProviderStrategy { .strip_prefix("audio/") .unwrap_or(&mime_type); Ok(json!({ - "type": if msg.role == "assistant" { "output_audio" } else { "input_audio" }, + "type": "input_audio", "input_audio": { "data": b64_media.base64, "format": format @@ -193,8 +205,9 @@ impl ProviderStrategy { match &media.content { baml_types::BamlMediaContent::Url(url_content) => { Ok(json!({ - "type": if msg.role == "assistant" { "output_file" } else { "input_file" }, - "file_url": url_content.url + "type": "input_file", + "file_url": url_content.url, + "filename": "document.pdf" })) } baml_types::BamlMediaContent::File(file_content) => { @@ -202,8 +215,9 @@ impl ProviderStrategy { } baml_types::BamlMediaContent::Base64(b64_media) => { Ok(json!({ - "type": if msg.role == "assistant" { "output_file" } else { "input_file" }, - "file_data": format!("data:{};base64,{}", media.mime_type_as_ok()?, b64_media.base64) + "type": "input_file", + "file_data": 
format!("data:{};base64,{}", media.mime_type_as_ok()?, b64_media.base64), + "filename": "document.pdf" })) } } @@ -217,13 +231,23 @@ impl ProviderStrategy { // Recursively handle the inner part, ignoring metadata for now match inner_part.as_ref() { ChatMessagePart::Text(text) => { + let content_type = if msg.role == "assistant" { + "output_text" + } else { + "input_text" + }; Ok(json!({ - "type": if msg.role == "assistant" { "output_text" } else { "input_text" }, + "type": content_type, "text": text })) } ChatMessagePart::Media(media) => { // Handle media same as above - could refactor into helper function + if msg.role == "assistant" { + anyhow::bail!( + "BAML internal error (openai-responses): assistant messages must be text; media not supported for assistant in Responses API" + ); + } match media.media_type { baml_types::BamlMediaType::Image => { let image_url = match &media.content { @@ -236,7 +260,8 @@ impl ProviderStrategy { } }; Ok(json!({ - "type": if msg.role == "assistant" { "output_image" } else { "input_image" }, + "type": "input_image", + "detail": "auto", "image_url": image_url })) } @@ -698,7 +723,8 @@ impl ToProviderMessage for OpenAIClient { payload_key.into(), json!({ "type": "input_file", - "file_url": url_content.url + "file_url": url_content.url, + "filename": "document.pdf" }), ); } @@ -916,4 +942,102 @@ mod tests { let endpoint = strategy.get_endpoint("https://api.openai.com/v1", true); assert_eq!(endpoint, "https://api.openai.com/v1/completions"); } + + #[test] + fn test_responses_api_builds_input_message_with_text_and_file() { + let strategy = ProviderStrategy::ResponsesApi; + + // Properties include model + let mut props = BamlMap::new(); + props.insert("model".into(), json!("gpt-5-mini")); + + // Build a user message with text and file (PDF url) + let msg = RenderedChatMessage { + role: "user".to_string(), + allow_duplicate_role: false, + parts: vec![ + ChatMessagePart::Text("what is in this file?".to_string()), + ChatMessagePart::Media(baml_types::BamlMedia::url( + BamlMediaType::Pdf, + "https://www.berkshirehathaway.com/letters/2024ltr.pdf".to_string(), + Some("application/pdf".to_string()), + )), + ], + }; + + // chat_converter is not used in ResponsesApi branch; construct a minimal client + let responses_client = OpenAIClient { + name: "test".to_string(), + provider: "openai-responses".to_string(), + retry_policy: None, + context: RenderContext_Client { + name: "test".to_string(), + provider: "openai-responses".to_string(), + default_role: "user".to_string(), + allowed_roles: vec!["user".to_string(), "assistant".to_string()], + remap_role: HashMap::new(), + options: IndexMap::new(), + }, + features: ModelFeatures { + chat: true, + completion: false, + max_one_system_prompt: false, + resolve_audio_urls: ResolveMediaUrls::Always, + resolve_image_urls: ResolveMediaUrls::Never, + resolve_pdf_urls: ResolveMediaUrls::Never, + resolve_video_urls: ResolveMediaUrls::Never, + allowed_metadata: AllowedRoleMetadata::All, + }, + properties: ResolvedOpenAI { + base_url: "https://api.openai.com/v1".to_string(), + api_key: None, + role_selection: RolesSelection::default(), + allowed_metadata: AllowedRoleMetadata::All, + supported_request_modes: SupportedRequestModes::default(), + headers: IndexMap::new(), + properties: BamlMap::new(), + query_params: IndexMap::new(), + proxy_url: None, + finish_reason_filter: FinishReasonFilter::All, + client_response_type: ResponseType::OpenAIResponses, + }, + client: reqwest::Client::new(), + }; + + let body_value = strategy + 
.build_body(either::Either::Right(&[msg]), &props, &responses_client) + .expect("should build body"); + + let obj = body_value.as_object().expect("body should be an object"); + assert_eq!(obj.get("model"), Some(&json!("gpt-5-mini"))); + + let input = obj + .get("input") + .and_then(|v| v.as_array()) + .expect("input should be array"); + assert_eq!(input.len(), 1); + + let first_msg = input[0].as_object().expect("message should be object"); + assert_eq!(first_msg.get("role"), Some(&json!("user"))); + let content = first_msg + .get("content") + .and_then(|v| v.as_array()) + .expect("content should be array"); + assert_eq!(content.len(), 2); + + // Validate text part + let t = content[0].as_object().expect("text part object"); + assert_eq!(t.get("type"), Some(&json!("input_text"))); + assert_eq!(t.get("text"), Some(&json!("what is in this file?"))); + + // Validate file part + let f = content[1].as_object().expect("file part object"); + assert_eq!(f.get("type"), Some(&json!("input_file"))); + assert_eq!( + f.get("file_url"), + Some(&json!( + "https://www.berkshirehathaway.com/letters/2024ltr.pdf" + )) + ); + } } diff --git a/engine/baml-runtime/src/internal/llm_client/primitive/stream_request.rs b/engine/baml-runtime/src/internal/llm_client/primitive/stream_request.rs index 51bf1a5f25..dc2f9b3325 100644 --- a/engine/baml-runtime/src/internal/llm_client/primitive/stream_request.rs +++ b/engine/baml-runtime/src/internal/llm_client/primitive/stream_request.rs @@ -92,7 +92,7 @@ pub async fn make_stream_request( std::future::ready(event.as_ref().is_ok_and(|e| e.data != "[DONE]")) }) .map(|event| -> Result { Ok(serde_json::from_str(&event?.data)?) }) - .inspect(|event| log::trace!("{event:#?}")) + .inspect(|event| log::debug!("{event:#?}")) .scan( Ok(LLMCompleteResponse { client: client_name.clone(), diff --git a/engine/baml-runtime/src/internal/llm_client/primitive/vertex/response_handler.rs b/engine/baml-runtime/src/internal/llm_client/primitive/vertex/response_handler.rs index cccc42fa76..24a9835519 100644 --- a/engine/baml-runtime/src/internal/llm_client/primitive/vertex/response_handler.rs +++ b/engine/baml-runtime/src/internal/llm_client/primitive/vertex/response_handler.rs @@ -125,7 +125,9 @@ pub fn scan_vertex_response_stream( let inner = match accumulated { Ok(accumulated) => accumulated, // We'll just keep the first error and return it - Err(e) => return Ok(()), + Err(e) => { + return Ok(()); + } }; let event = VertexResponse::deserialize(&event_body) diff --git a/engine/baml-runtime/src/internal/llm_client/primitive/vertex/vertex_client.rs b/engine/baml-runtime/src/internal/llm_client/primitive/vertex/vertex_client.rs index 1630024e69..fd2a856877 100644 --- a/engine/baml-runtime/src/internal/llm_client/primitive/vertex/vertex_client.rs +++ b/engine/baml-runtime/src/internal/llm_client/primitive/vertex/vertex_client.rs @@ -132,7 +132,10 @@ impl WithStreamChat for VertexClient { self, either::Either::Right(prompt), Some(self.properties.model.clone()), - ResponseType::Vertex, + match self.properties.anthropic_version { + Some(ref anthropic_version) => ResponseType::Anthropic, + None => ResponseType::Vertex, + }, ctx, ) .await @@ -350,6 +353,13 @@ impl RequestBuilder for VertexClient { } } + // If this is an Anthropic-on-Vertex request and streaming is enabled, add `stream: true` + // to the JSON body to mirror Anthropic API behavior. 
+ // See docs here: https://console.cloud.google.com/vertex-ai/publishers/anthropic/model-garden/claude-3-5-sonnet?authuser=1&hl=en&project=gloo-ai + if stream && self.properties.anthropic_version.is_some() { + json_body.insert("stream".into(), json!(true)); + } + let req = req.json(&json_body); Ok(req) diff --git a/engine/baml-runtime/src/tracingv2/publisher/publisher.rs b/engine/baml-runtime/src/tracingv2/publisher/publisher.rs index d9dbc2c0fe..1bbaea87ae 100644 --- a/engine/baml-runtime/src/tracingv2/publisher/publisher.rs +++ b/engine/baml-runtime/src/tracingv2/publisher/publisher.rs @@ -547,7 +547,7 @@ impl TracePublisher { { Ok(response) => response, Err(e) => { - tracing::error!("Failed to check BAML source upload status: {}", e); + tracing::warn!("Failed to check BAML source upload status: {}", e); return Err(e.into()); } }; diff --git a/integ-tests/python/tests/providers/test_openai_responses.py b/integ-tests/python/tests/providers/test_openai_responses.py index 3222f134e8..b2a70779d6 100644 --- a/integ-tests/python/tests/providers/test_openai_responses.py +++ b/integ-tests/python/tests/providers/test_openai_responses.py @@ -24,6 +24,7 @@ async def test_expose_request_openai_responses_multimodal(): { "type": "input_image", "image_url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", + "detail": "auto", }, ], } @@ -84,6 +85,7 @@ async def test_expose_request_openai_responses_pdf_base64(): { "type": "input_file", "file_data": f"data:application/pdf;base64,{test_pdf_b64}", + "filename": "document.pdf", }, ], } @@ -266,6 +268,7 @@ async def test_expose_request_openai_responses_pdf_url(): { "type": "input_file", "file_url": "https://www.usenix.org/system/files/conference/nsdi13/nsdi13-final85.pdf", + "filename": "document.pdf", }, ], } diff --git a/integ-tests/python/tests/test_functions.py b/integ-tests/python/tests/test_functions.py index 11892de2ee..0861e039d2 100644 --- a/integ-tests/python/tests/test_functions.py +++ b/integ-tests/python/tests/test_functions.py @@ -624,10 +624,29 @@ async def test_anthropic_shorthand(): @pytest.mark.asyncio async def test_anthropic_shorthand_streaming(): - res = await b.stream.TestAnthropicShorthand( - input="Mt Rainier is tall" - ).get_final_response() - assert len(res) > 0, "Expected non-empty result but got empty." + res = b.stream.TestAnthropicShorthand(input="Mt Rainier is tall") + chunks = [] + async for chunk in res: + chunks.append(chunk) + print("chunk", chunk) + final = await res.get_final_response() + print("final", final) + + assert len(chunks) > 0, "Expected non-empty result but got empty." + assert len(final) > 0, "Expected non-empty result but got empty." + + +@pytest.mark.asyncio +async def test_vertex_anthropic_streaming(): + res = b.stream.TestVertexClaude(input="Mt Rainier is tall") + chunks = [] + async for chunk in res: + chunks.append(chunk) + print("chunk", chunk) + final = await res.get_final_response() + print("final", final) + assert len(chunks) > 0, "Expected non-empty result but got empty." + assert len(final) > 0, "Expected non-empty result but got empty." 
@pytest.mark.asyncio @@ -639,7 +658,7 @@ async def test_fallback_to_shorthand(): @pytest.mark.asyncio -async def test_streaming(): +async def test_streaming_long(): stream = b.stream.PromptTestStreaming( input="Programming languages are fun to create" ) @@ -658,21 +677,20 @@ async def test_streaming(): final = await stream.get_final_response() - assert first_msg_time - start_time <= 1.5, ( - "Expected first message within 1 second but it took longer." - ) - assert last_msg_time - start_time >= 1, ( - "Expected last message after 1.5 seconds but it was earlier." - ) + assert ( + first_msg_time - start_time <= 1.5 + ), "Expected first message within 1 second but it took longer." + assert ( + last_msg_time - start_time >= 1 + ), "Expected last message after 1.5 seconds but it was earlier." assert len(final) > 0, "Expected non-empty final but got empty." assert len(msgs) > 0, "Expected at least one streamed response but got none." for prev_msg, msg in zip(msgs, msgs[1:]): - assert msg.startswith(prev_msg), ( - "Expected messages to be continuous, but prev was %r and next was %r" - % ( - prev_msg, - msg, - ) + assert msg.startswith( + prev_msg + ), "Expected messages to be continuous, but prev was %r and next was %r" % ( + prev_msg, + msg, ) assert msgs[-1] == final, "Expected last stream message to match final response." @@ -713,12 +731,11 @@ def test_streaming_sync(): assert len(final) > 0, "Expected non-empty final but got empty." assert len(msgs) > 5, "Expected at least one streamed response but got none." for prev_msg, msg in zip(msgs, msgs[1:]): - assert msg.startswith(prev_msg), ( - "Expected messages to be continuous, but prev was %r and next was %r" - % ( - prev_msg, - msg, - ) + assert msg.startswith( + prev_msg + ), "Expected messages to be continuous, but prev was %r and next was %r" % ( + prev_msg, + msg, ) assert msgs[-1] == final, "Expected last stream message to match final response." @@ -741,12 +758,11 @@ async def test_streaming_claude(): assert len(final) > 0, "Expected non-empty final but got empty." assert len(msgs) > 0, "Expected at least one streamed response but got none." for prev_msg, msg in zip(msgs, msgs[1:]): - assert msg.startswith(prev_msg), ( - "Expected messages to be continuous, but prev was %r and next was %r" - % ( - prev_msg, - msg, - ) + assert msg.startswith( + prev_msg + ), "Expected messages to be continuous, but prev was %r and next was %r" % ( + prev_msg, + msg, ) print("msgs:") print(msgs[-1]) @@ -767,12 +783,11 @@ async def test_streaming_gemini(): assert len(final) > 0, "Expected non-empty final but got empty." assert len(msgs) > 0, "Expected at least one streamed response but got none." for prev_msg, msg in zip(msgs, msgs[1:]): - assert msg.startswith(prev_msg), ( - "Expected messages to be continuous, but prev was %r and next was %r" - % ( - prev_msg, - msg, - ) + assert msg.startswith( + prev_msg + ), "Expected messages to be continuous, but prev was %r and next was %r" % ( + prev_msg, + msg, ) print("msgs:") print(msgs[-1]) @@ -1180,9 +1195,9 @@ async def test_caching(): print("Duration no caching: ", duration) print("Duration with caching: ", duration2) - assert duration2 < duration, ( - f"{duration2} < {duration}. Expected second call to be faster than first by a large margin." - ) + assert ( + duration2 < duration + ), f"{duration2} < {duration}. Expected second call to be faster than first by a large margin." 
@pytest.mark.asyncio @@ -1250,9 +1265,9 @@ async def test_baml_validation_error_format(): except errors.BamlValidationError as e: print("Error: ", e) assert hasattr(e, "prompt"), "Error object should have 'prompt' attribute" - assert hasattr(e, "raw_output"), ( - "Error object should have 'raw_output' attribute" - ) + assert hasattr( + e, "raw_output" + ), "Error object should have 'raw_output' attribute" assert hasattr(e, "message"), "Error object should have 'message' attribute" assert 'Say "hello there"' in e.prompt @@ -1563,9 +1578,9 @@ async def test_gemini_thinking(): except Exception as e: # If it fails with the thinking config, ensure it's not due to parsing multiple non-thought parts - assert "Too many matches" not in str(e), ( - f"Parsing error with thinking response: {e}" - ) + assert "Too many matches" not in str( + e + ), f"Parsing error with thinking response: {e}" raise diff --git a/typescript/packages/playground-common/src/shared/baml-project-panel/playground-panel/prompt-preview/prompt-preview-curl.tsx b/typescript/packages/playground-common/src/shared/baml-project-panel/playground-panel/prompt-preview/prompt-preview-curl.tsx index 671ae5e966..4f51882ba8 100644 --- a/typescript/packages/playground-common/src/shared/baml-project-panel/playground-panel/prompt-preview/prompt-preview-curl.tsx +++ b/typescript/packages/playground-common/src/shared/baml-project-panel/playground-panel/prompt-preview/prompt-preview-curl.tsx @@ -3,7 +3,7 @@ import { useAtomValue } from 'jotai'; import { atom } from 'jotai'; import { loadable } from 'jotai/utils'; import { useTheme } from 'next-themes'; -import { useEffect, useState } from 'react'; +import { useEffect, useState, memo } from 'react'; import type React from 'react'; import { useMemo } from 'react'; import { apiKeysAtom } from '../../../../components/api-keys-dialog/atoms'; @@ -76,10 +76,10 @@ const baseCurlAtom = atom>(async (get) => { }; }); -const curlAtom = loadable(baseCurlAtom); +export const curlAtom = loadable(baseCurlAtom); // Syntax highlighting component for curl commands -const SyntaxHighlightedCurl = ({ text }: { text: string }) => { +const SyntaxHighlightedCurl = memo(({ text }: { text: string }) => { const [highlightedHtml, setHighlightedHtml] = useState(''); const [highlighter, setHighlighter] = useState(undefined); const { theme } = useTheme(); @@ -221,14 +221,39 @@ const SyntaxHighlightedCurl = ({ text }: { text: string }) => { /> ); -}; +}, (prev, next) => prev.text === next.text); export const PromptPreviewCurl = () => { const curl = useAtomValue(curlAtom); + const [lastCurl, setLastCurl] = useState< + | { curlTextWithoutSecrets: string; curlTextWithSecrets: string } + | undefined + >(undefined); + + useEffect(() => { + if (curl.state === 'hasData' && curl.data && !(curl.data instanceof Error)) { + setLastCurl(curl.data); + } + }, [curl]); // Memoize the rendered content to prevent unnecessary re-renders const renderedContent = useMemo(() => { if (curl.state === 'loading') { + // While loading, show the last known cURL if available, otherwise a loader + if (lastCurl) { + return ( +
+          <SyntaxHighlightedCurl text={lastCurl.curlTextWithoutSecrets} />
+ ); + } return ; } @@ -264,7 +289,7 @@ export const PromptPreviewCurl = () => { ); - }, [curl]); + }, [curl, lastCurl]); return renderedContent; }; diff --git a/typescript/packages/playground-common/src/shared/baml-project-panel/playground-panel/prompt-preview/prompt-render-wrapper.tsx b/typescript/packages/playground-common/src/shared/baml-project-panel/playground-panel/prompt-preview/prompt-render-wrapper.tsx index b21065327b..7506ef30c9 100644 --- a/typescript/packages/playground-common/src/shared/baml-project-panel/playground-panel/prompt-preview/prompt-render-wrapper.tsx +++ b/typescript/packages/playground-common/src/shared/baml-project-panel/playground-panel/prompt-preview/prompt-render-wrapper.tsx @@ -12,7 +12,7 @@ import { selectionAtom } from '../atoms'; import { displaySettingsAtom } from '../preview-toolbar'; import { PromptPreviewContent } from './prompt-preview-content'; import { renderedPromptAtom } from './prompt-preview-content'; -import { PromptPreviewCurl } from './prompt-preview-curl'; +import { PromptPreviewCurl, curlAtom } from './prompt-preview-curl'; import { ClientGraphView } from './test-panel/components/ClientGraphView'; import { MermaidGraphView } from './test-panel/components/MermaidGraphView'; @@ -87,6 +87,8 @@ export const PromptRenderWrapper = () => { const [showCopied, setShowCopied] = React.useState(false); const { open: isSidebarOpen } = useSidebar(); const isBetaEnabled = useAtomValue(betaFeatureEnabledAtom); + const [activeTab, setActiveTab] = React.useState<'preview' | 'curl' | 'client-graph' | 'mermaid-graph'>('preview'); + const curl = useAtomValue(curlAtom); // Hide text when sidebar is open or on smaller screens const getButtonTextClass = () => { @@ -97,6 +99,20 @@ export const PromptRenderWrapper = () => { }; const handleCopy = () => { + // If the cURL tab is active, copy the generated cURL (without secrets) + if (activeTab === 'curl') { + if (curl.state === 'hasData' && curl.data && !(curl.data instanceof Error)) { + const text = curl.data.curlTextWithoutSecrets ?? ''; + if (text) { + void navigator.clipboard.writeText(text); + setShowCopied(true); + setTimeout(() => setShowCopied(false), 1500); + } + } + return; + } + + // Otherwise copy the human-readable prompt preview if (!renderedPrompt) return; navigator.clipboard.writeText( renderedPrompt @@ -113,7 +129,7 @@ export const PromptRenderWrapper = () => { return ( // this used to be flex flex-col h-full min-h-0 - + setActiveTab(v as any)} className="flex flex-col min-h-0">
@@ -137,7 +153,7 @@ export const PromptRenderWrapper = () => { )} - {showCopied ? 'Copied!' : 'Copy Prompt'} + {showCopied ? 'Copied!' : activeTab === 'curl' ? 'Copy cURL' : 'Copy Prompt'}
diff --git a/typescript/packages/playground-common/src/shared/baml-project-panel/vscode.ts b/typescript/packages/playground-common/src/shared/baml-project-panel/vscode.ts index 14acc320a8..fdbdcfb693 100644 --- a/typescript/packages/playground-common/src/shared/baml-project-panel/vscode.ts +++ b/typescript/packages/playground-common/src/shared/baml-project-panel/vscode.ts @@ -349,7 +349,7 @@ class VSCodeAPIWrapper { * @returns Promise Binary file contents * @throws Error if file cannot be read or loaded */ - public async loadMediaFile(path: string): Promise { + public loadMediaFile = async (path: string): Promise => { try { if (this.isVscode()) { // VSCode: Request file contents directly via workspace API