Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 139 additions & 22 deletions codex-rs/codex-api/src/endpoint/realtime_websocket/methods.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use crate::api_bridge::CoreAuthProvider;
use crate::auth::add_auth_headers_to_header_map;
use crate::endpoint::realtime_websocket::methods_common::conversation_handoff_append_message;
use crate::endpoint::realtime_websocket::methods_common::conversation_item_create_message;
use crate::endpoint::realtime_websocket::methods_common::normalized_session_mode;
use crate::endpoint::realtime_websocket::methods_common::session_update_session;
use crate::endpoint::realtime_websocket::methods_common::webrtc_intent;
use crate::endpoint::realtime_websocket::protocol::RealtimeAudioFrame;
use crate::endpoint::realtime_websocket::protocol::RealtimeEvent;
use crate::endpoint::realtime_websocket::protocol::RealtimeEventParser;
Expand All @@ -24,6 +27,9 @@ use interceptor::registry::Registry;
use opus_rs::Application;
use opus_rs::OpusDecoder;
use opus_rs::OpusEncoder;
use reqwest::multipart::Form;
use reqwest::multipart::Part;
use serde_json::json;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
Expand Down Expand Up @@ -362,12 +368,13 @@ fn append_transcript_delta(entries: &mut Vec<RealtimeTranscriptEntry>, role: &st
}

/// Client that pairs a `Provider` with the `CoreAuthProvider` used to authenticate
/// realtime WebRTC requests.
///
/// Construct it with `RealtimeWebRtcClient::new(provider, auth)`.
pub struct RealtimeWebRtcClient {
// Auth source handed to `merge_request_headers` so auth headers are added to the request.
auth: CoreAuthProvider,
// Endpoint configuration; `connect` reads its `base_url`, `query_params`, and `headers`.
provider: Provider,
}

impl RealtimeWebRtcClient {
pub fn new(provider: Provider) -> Self {
Self { provider }
pub fn new(provider: Provider, auth: CoreAuthProvider) -> Self {
Self { auth, provider }
}

pub async fn connect(
Expand All @@ -380,9 +387,12 @@ impl RealtimeWebRtcClient {
let calls_url = calls_url_from_api_url(
self.provider.base_url.as_str(),
self.provider.query_params.as_ref(),
config.model.as_deref(),
config.event_parser,
)?;

let headers = merge_request_headers(
&self.auth,
&self.provider.headers,
with_session_id_header(extra_headers, config.session_id.as_deref())?,
default_headers,
Expand Down Expand Up @@ -412,12 +422,14 @@ impl RealtimeWebRtcClient {
}

fn merge_request_headers(
auth: &CoreAuthProvider,
provider_headers: &HeaderMap,
extra_headers: HeaderMap,
default_headers: HeaderMap,
) -> HeaderMap {
let mut headers = provider_headers.clone();
headers.extend(extra_headers);
add_auth_headers_to_header_map(auth, &mut headers);
for (name, value) in &default_headers {
if let http::header::Entry::Vacant(entry) = headers.entry(name) {
entry.insert(value.clone());
Expand Down Expand Up @@ -445,6 +457,8 @@ fn with_session_id_header(
fn calls_url_from_api_url(
api_url: &str,
query_params: Option<&HashMap<String, String>>,
model: Option<&str>,
event_parser: RealtimeEventParser,
) -> Result<Url, ApiError> {
let mut url = Url::parse(api_url)
.map_err(|err| ApiError::Stream(format!("failed to parse realtime api_url: {err}")))?;
Expand All @@ -466,10 +480,27 @@ fn calls_url_from_api_url(
}
}

if let Some(query_params) = query_params {
let intent = webrtc_intent(event_parser);
let has_extra_query_params = query_params.is_some_and(|query_params| {
query_params
.iter()
.any(|(key, _)| key != "intent" && !(key == "model" && model.is_some()))
});
if intent.is_some() || model.is_some() || has_extra_query_params {
let mut query = url.query_pairs_mut();
for (key, value) in query_params {
query.append_pair(key, value);
if let Some(intent) = intent {
query.append_pair("intent", intent);
}
if let Some(model) = model {
query.append_pair("model", model);
}
if let Some(query_params) = query_params {
for (key, value) in query_params {
if key == "intent" || (key == "model" && model.is_some()) {
continue;
}
query.append_pair(key, value);
}
}
}

Expand Down Expand Up @@ -509,7 +540,10 @@ fn normalize_realtime_calls_path(url: &mut Url) {

if path.ends_with("/v1/") {
url.set_path(&format!("{path}realtime/calls"));
return;
}

url.set_path(&format!("{}/realtime/calls", path.trim_end_matches('/')));
}

async fn connect_webrtc_transport(
Expand Down Expand Up @@ -825,18 +859,44 @@ async fn fetch_realtime_answer(
{
session_json.insert("model".to_string(), serde_json::Value::String(model));
}
let session_payload = serde_json::to_string(&session_json)
.map_err(|err| ApiError::Stream(format!("failed to encode realtime session: {err}")))?;
let form = reqwest::multipart::Form::new()
.text("sdp", offer_sdp)
.text("session", session_payload);
let response = client
.post(calls_url)
.headers(headers)
.multipart(form)
.send()
.await
.map_err(|err| ApiError::Stream(format!("failed to create realtime WebRTC call: {err}")))?;
let response = if calls_url.path().ends_with("/v1/realtime/calls") {
let session_payload = serde_json::to_string(&session_json)
.map_err(|err| ApiError::Stream(format!("failed to encode realtime session: {err}")))?;
let form = Form::new()
.part(
"sdp",
Part::text(offer_sdp)
.mime_str("application/sdp")
.map_err(|err| {
ApiError::Stream(format!("failed to encode realtime SDP part: {err}"))
})?,
)
.part(
"session",
Part::text(session_payload)
.mime_str("application/json")
.map_err(|err| {
ApiError::Stream(format!("failed to encode realtime session part: {err}"))
})?,
);
client
.post(calls_url)
.headers(headers)
.multipart(form)
.send()
.await
} else {
client
.post(calls_url)
.headers(headers)
.json(&json!({
"sdp": offer_sdp,
"session": session_json,
}))
.send()
.await
}
.map_err(|err| ApiError::Stream(format!("failed to create realtime WebRTC call: {err}")))?;
let status = response.status();
let body = response.text().await.map_err(|err| {
ApiError::Stream(format!("failed to read realtime WebRTC answer SDP: {err}"))
Expand Down Expand Up @@ -961,8 +1021,13 @@ mod tests {
#[test]
fn calls_url_from_api_url_normalizes_http_root() {
let query_params = HashMap::from([("model".to_string(), "gpt-realtime".to_string())]);
let calls_url =
calls_url_from_api_url("http://example.com", Some(&query_params)).expect("calls url");
let calls_url = calls_url_from_api_url(
"http://example.com",
Some(&query_params),
Some("gpt-realtime"),
RealtimeEventParser::RealtimeV2,
)
.expect("calls url");

assert_eq!(
calls_url.as_str(),
Expand All @@ -973,16 +1038,68 @@ mod tests {
#[test]
fn calls_url_from_api_url_preserves_v1_realtime_path_and_query() {
let query_params = HashMap::from([("model".to_string(), "gpt-realtime".to_string())]);
let calls_url =
calls_url_from_api_url("wss://example.com/v1/realtime?foo=bar", Some(&query_params))
.expect("calls url");
let calls_url = calls_url_from_api_url(
"wss://example.com/v1/realtime?foo=bar",
Some(&query_params),
Some("gpt-realtime"),
RealtimeEventParser::RealtimeV2,
)
.expect("calls url");

assert_eq!(
calls_url.as_str(),
"https://example.com/v1/realtime/calls?foo=bar&model=gpt-realtime"
);
}

// With the V1 event parser and no provider query params, the calls URL must carry
// `intent=quicksilver` followed by the explicitly supplied model.
#[test]
fn calls_url_from_api_url_appends_quicksilver_intent_for_v1() {
    let expected =
        "https://example.com/v1/realtime/calls?intent=quicksilver&model=quicksilver-test-model";

    let actual = calls_url_from_api_url(
        "wss://example.com/v1/realtime",
        /*query_params*/ None,
        Some("quicksilver-test-model"),
        RealtimeEventParser::V1,
    )
    .expect("calls url");

    assert_eq!(actual.as_str(), expected);
}

// With the RealtimeV2 event parser no `intent` pair is emitted; only the model is
// appended to the normalized calls URL.
#[test]
fn calls_url_from_api_url_omits_intent_for_v2() {
    let expected = "https://example.com/v1/realtime/calls?model=gpt-realtime";

    let actual = calls_url_from_api_url(
        "wss://example.com/v1/realtime",
        /*query_params*/ None,
        Some("gpt-realtime"),
        RealtimeEventParser::RealtimeV2,
    )
    .expect("calls url");

    assert_eq!(actual.as_str(), expected);
}

// A base URL whose path is not a `/v1` realtime path gets `/realtime/calls` appended
// to its existing path, and the model is added as the only query parameter.
#[test]
fn calls_url_from_api_url_appends_calls_path_to_chatgpt_base_url() {
    let expected = "https://chatgpt.com/backend-api/codex/realtime/calls?model=gpt-realtime";

    let actual = calls_url_from_api_url(
        "https://chatgpt.com/backend-api/codex",
        /*query_params*/ None,
        Some("gpt-realtime"),
        RealtimeEventParser::RealtimeV2,
    )
    .expect("calls url");

    assert_eq!(actual.as_str(), expected);
}

#[test]
fn parse_session_updated_event() {
let payload = json!({
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use crate::endpoint::realtime_websocket::methods_v1::conversation_handoff_append_message as v1_conversation_handoff_append_message;
use crate::endpoint::realtime_websocket::methods_v1::conversation_item_create_message as v1_conversation_item_create_message;
use crate::endpoint::realtime_websocket::methods_v1::session_update_session as v1_session_update_session;
use crate::endpoint::realtime_websocket::methods_v1::webrtc_intent as v1_webrtc_intent;
use crate::endpoint::realtime_websocket::methods_v2::conversation_handoff_append_message as v2_conversation_handoff_append_message;
use crate::endpoint::realtime_websocket::methods_v2::conversation_item_create_message as v2_conversation_item_create_message;
use crate::endpoint::realtime_websocket::methods_v2::session_update_session as v2_session_update_session;
use crate::endpoint::realtime_websocket::methods_v2::webrtc_intent as v2_webrtc_intent;
use crate::endpoint::realtime_websocket::protocol::RealtimeEventParser;
use crate::endpoint::realtime_websocket::protocol::RealtimeOutboundMessage;
use crate::endpoint::realtime_websocket::protocol::RealtimeSessionMode;
Expand Down Expand Up @@ -57,3 +59,10 @@ pub(super) fn session_update_session(
RealtimeEventParser::RealtimeV2 => v2_session_update_session(instructions, session_mode),
}
}

/// Returns the WebRTC intent string for the protocol version selected by `event_parser`,
/// or `None` when that version defines no intent.
///
/// This is a pure dispatcher: it forwards to the `webrtc_intent` helper of the matching
/// versioned module (`methods_v1` or `methods_v2`). The match is exhaustive, so adding a
/// new `RealtimeEventParser` variant forces a decision here.
pub(super) fn webrtc_intent(event_parser: RealtimeEventParser) -> Option<&'static str> {
match event_parser {
RealtimeEventParser::V1 => v1_webrtc_intent(),
RealtimeEventParser::RealtimeV2 => v2_webrtc_intent(),
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,7 @@ pub(super) fn session_update_session(instructions: String) -> SessionUpdateSessi
tool_choice: None,
}
}

/// Returns the WebRTC intent value for this protocol version: always `Some("quicksilver")`.
pub(super) fn webrtc_intent() -> Option<&'static str> {
Some("quicksilver")
}
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,7 @@ pub(super) fn session_update_session(
},
}
}

/// Returns the WebRTC intent value for this protocol version: always `None`, meaning no
/// intent is supplied.
pub(super) fn webrtc_intent() -> Option<&'static str> {
None
}
Loading
Loading