diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index dbb195a85d5..4bc31044b65 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -2056,7 +2056,9 @@ dependencies = [ "codex-protocol", "codex-utils-absolute-path", "notify", + "pretty_assertions", "regex-lite", + "reqwest", "serde_json", "shlex", "tempfile", diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 9a21f7bfe28..5f0e322dae9 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -2476,6 +2476,7 @@ async fn try_run_turn( let mut needs_follow_up = false; let mut last_agent_message: Option = None; let mut active_item: Option = None; + let mut should_emit_turn_diff = false; let receiving_span = info_span!("receiving_stream"); let outcome: CodexResult = loop { let handle_responses = info_span!( @@ -2551,14 +2552,7 @@ async fn try_run_turn( } => { sess.update_token_usage_info(&turn_context, token_usage.as_ref()) .await; - let unified_diff = { - let mut tracker = turn_diff_tracker.lock().await; - tracker.get_unified_diff() - }; - if let Ok(Some(unified_diff)) = unified_diff { - let msg = EventMsg::TurnDiff(TurnDiffEvent { unified_diff }); - sess.send_event(&turn_context, msg).await; - } + should_emit_turn_diff = true; break Ok(TurnRunResult { needs_follow_up, @@ -2632,7 +2626,18 @@ async fn try_run_turn( } }; - drain_in_flight(&mut in_flight, sess, turn_context).await?; + drain_in_flight(&mut in_flight, sess.clone(), turn_context.clone()).await?; + + if should_emit_turn_diff { + let unified_diff = { + let mut tracker = turn_diff_tracker.lock().await; + tracker.get_unified_diff() + }; + if let Ok(Some(unified_diff)) = unified_diff { + let msg = EventMsg::TurnDiff(TurnDiffEvent { unified_diff }); + sess.clone().send_event(&turn_context, msg).await; + } + } outcome } diff --git a/codex-rs/core/src/stream_events_utils.rs b/codex-rs/core/src/stream_events_utils.rs index e1c9a652525..1373fdf2482 100644 --- a/codex-rs/core/src/stream_events_utils.rs +++ 
b/codex-rs/core/src/stream_events_utils.rs @@ -16,7 +16,6 @@ use codex_protocol::models::FunctionCallOutputPayload; use codex_protocol::models::ResponseInputItem; use codex_protocol::models::ResponseItem; use futures::Future; -use tracing::Instrument; use tracing::debug; use tracing::instrument; @@ -59,16 +58,10 @@ pub(crate) async fn handle_output_item_done( .await; let cancellation_token = ctx.cancellation_token.child_token(); - let tool_runtime = ctx.tool_runtime.clone(); - let tool_future: InFlightFuture<'static> = Box::pin( - async move { - let response_input = tool_runtime - .handle_tool_call(call, cancellation_token) - .await?; - Ok(response_input) - } - .in_current_span(), + ctx.tool_runtime + .clone() + .handle_tool_call(call, cancellation_token), ); output.needs_follow_up = true; diff --git a/codex-rs/core/src/tools/parallel.rs b/codex-rs/core/src/tools/parallel.rs index feef518e57f..d146507e5ca 100644 --- a/codex-rs/core/src/tools/parallel.rs +++ b/codex-rs/core/src/tools/parallel.rs @@ -47,7 +47,7 @@ impl ToolCallRuntime { #[instrument(skip_all, fields(call = ?call))] pub(crate) fn handle_tool_call( - &self, + self, call: ToolCall, cancellation_token: CancellationToken, ) -> impl std::future::Future> { diff --git a/codex-rs/core/tests/common/Cargo.toml b/codex-rs/core/tests/common/Cargo.toml index 95ea788b4c6..0f0abbb3cb2 100644 --- a/codex-rs/core/tests/common/Cargo.toml +++ b/codex-rs/core/tests/common/Cargo.toml @@ -22,3 +22,7 @@ tokio = { workspace = true, features = ["time"] } walkdir = { workspace = true } wiremock = { workspace = true } shlex = { workspace = true } + +[dev-dependencies] +pretty_assertions = { workspace = true } +reqwest = { workspace = true } diff --git a/codex-rs/core/tests/common/lib.rs b/codex-rs/core/tests/common/lib.rs index 61bac30cb17..dc0cc53f466 100644 --- a/codex-rs/core/tests/common/lib.rs +++ b/codex-rs/core/tests/common/lib.rs @@ -14,6 +14,7 @@ use std::path::PathBuf; use assert_cmd::cargo::cargo_bin; pub mod 
responses; +pub mod streaming_sse; pub mod test_codex; pub mod test_codex_exec; diff --git a/codex-rs/core/tests/common/streaming_sse.rs b/codex-rs/core/tests/common/streaming_sse.rs new file mode 100644 index 00000000000..4f1b3673b0f --- /dev/null +++ b/codex-rs/core/tests/common/streaming_sse.rs @@ -0,0 +1,680 @@ +use std::collections::VecDeque; +use std::sync::Arc; +use std::time::SystemTime; +use std::time::UNIX_EPOCH; + +use tokio::io::AsyncReadExt; +use tokio::io::AsyncWriteExt; +use tokio::net::TcpListener; +use tokio::sync::Mutex as TokioMutex; +use tokio::sync::oneshot; + +/// Streaming SSE chunk payload gated by a per-chunk signal. +#[derive(Debug)] +pub struct StreamingSseChunk { + pub gate: Option>, + pub body: String, +} + +/// Minimal streaming SSE server for tests that need gated per-chunk delivery. +pub struct StreamingSseServer { + uri: String, + shutdown: oneshot::Sender<()>, + task: tokio::task::JoinHandle<()>, +} + +impl StreamingSseServer { + pub fn uri(&self) -> &str { + &self.uri + } + + pub async fn shutdown(self) { + let _ = self.shutdown.send(()); + let _ = self.task.await; + } +} + +/// Starts a lightweight HTTP server that supports: +/// - GET /v1/models -> empty models response +/// - POST /v1/responses -> SSE stream gated per-chunk, served in order +/// +/// Returns the server handle and a list of receivers that fire when each +/// response stream finishes sending its final chunk. 
+pub async fn start_streaming_sse_server( + responses: Vec>, +) -> (StreamingSseServer, Vec>) { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("bind streaming SSE server"); + let addr = listener.local_addr().expect("streaming SSE server address"); + let uri = format!("http://{addr}"); + + let mut completion_senders = Vec::with_capacity(responses.len()); + let mut completion_receivers = Vec::with_capacity(responses.len()); + for _ in 0..responses.len() { + let (tx, rx) = oneshot::channel(); + completion_senders.push(tx); + completion_receivers.push(rx); + } + + let state = Arc::new(TokioMutex::new(StreamingSseState { + responses: VecDeque::from(responses), + completions: VecDeque::from(completion_senders), + })); + let (shutdown_tx, mut shutdown_rx) = oneshot::channel(); + + let task = tokio::spawn(async move { + loop { + tokio::select! { + _ = &mut shutdown_rx => break, + accept_res = listener.accept() => { + let (mut stream, _) = accept_res.expect("accept streaming SSE connection"); + let state = Arc::clone(&state); + tokio::spawn(async move { + let (request, body_prefix) = read_http_request(&mut stream).await; + let Some((method, path)) = parse_request_line(&request) else { + let _ = write_http_response(&mut stream, 400, "bad request", "text/plain").await; + return; + }; + + if method == "GET" && path == "/v1/models" { + if drain_request_body(&mut stream, &request, body_prefix) + .await + .is_err() + { + let _ = write_http_response(&mut stream, 400, "bad request", "text/plain").await; + return; + } + let body = serde_json::json!({ + "data": [], + "object": "list" + }) + .to_string(); + let _ = write_http_response(&mut stream, 200, &body, "application/json").await; + return; + } + + if method == "POST" && path == "/v1/responses" { + if drain_request_body(&mut stream, &request, body_prefix) + .await + .is_err() + { + let _ = write_http_response(&mut stream, 400, "bad request", "text/plain").await; + return; + } + let Some((chunks, completion)) 
= take_next_stream(&state).await else { + let _ = write_http_response(&mut stream, 500, "no responses queued", "text/plain").await; + return; + }; + + if write_sse_headers(&mut stream).await.is_err() { + return; + } + + for chunk in chunks { + if let Some(gate) = chunk.gate + && gate.await.is_err() { + return; + } + if stream.write_all(chunk.body.as_bytes()).await.is_err() { + return; + } + let _ = stream.flush().await; + } + + let _ = completion.send(unix_ms_now()); + let _ = stream.shutdown().await; + return; + } + + let _ = write_http_response(&mut stream, 404, "not found", "text/plain").await; + }); + } + } + } + }); + + ( + StreamingSseServer { + uri, + shutdown: shutdown_tx, + task, + }, + completion_receivers, + ) +} + +struct StreamingSseState { + responses: VecDeque>, + completions: VecDeque>, +} + +async fn take_next_stream( + state: &TokioMutex, +) -> Option<(Vec, oneshot::Sender)> { + let mut guard = state.lock().await; + let chunks = guard.responses.pop_front()?; + let completion = guard.completions.pop_front()?; + Some((chunks, completion)) +} + +async fn read_http_request(stream: &mut tokio::net::TcpStream) -> (String, Vec) { + let mut buf = Vec::new(); + let mut scratch = [0u8; 1024]; + loop { + let read = stream.read(&mut scratch).await.unwrap_or(0); + if read == 0 { + break; + } + buf.extend_from_slice(&scratch[..read]); + if let Some(end) = header_terminator_index(&buf) { + let header_end = end + 4; + let header = String::from_utf8_lossy(&buf[..header_end]).into_owned(); + let rest = buf[header_end..].to_vec(); + return (header, rest); + } + } + (String::from_utf8_lossy(&buf).into_owned(), Vec::new()) +} + +fn parse_request_line(request: &str) -> Option<(&str, &str)> { + let line = request.lines().next()?; + let mut parts = line.split_whitespace(); + let method = parts.next()?; + let path = parts.next()?; + Some((method, path)) +} + +fn header_terminator_index(buf: &[u8]) -> Option { + buf.windows(4).position(|w| w == b"\r\n\r\n") +} + +fn 
content_length(headers: &str) -> Option { + headers.lines().skip(1).find_map(|line| { + let mut parts = line.splitn(2, ':'); + let name = parts.next()?.trim(); + let value = parts.next()?.trim(); + if name.eq_ignore_ascii_case("content-length") { + value.parse::().ok() + } else { + None + } + }) +} + +async fn drain_request_body( + stream: &mut tokio::net::TcpStream, + headers: &str, + mut body_prefix: Vec, +) -> std::io::Result<()> { + let Some(content_len) = content_length(headers) else { + return Ok(()); + }; + + if body_prefix.len() > content_len { + body_prefix.truncate(content_len); + } + + let remaining = content_len.saturating_sub(body_prefix.len()); + if remaining == 0 { + return Ok(()); + } + + let mut rest = vec![0u8; remaining]; + stream.read_exact(&mut rest).await?; + Ok(()) +} + +async fn write_sse_headers(stream: &mut tokio::net::TcpStream) -> std::io::Result<()> { + let headers = "HTTP/1.1 200 OK\r\ncontent-type: text/event-stream\r\ncache-control: no-cache\r\nconnection: close\r\n\r\n"; + stream.write_all(headers.as_bytes()).await +} + +async fn write_http_response( + stream: &mut tokio::net::TcpStream, + status: i64, + body: &str, + content_type: &str, +) -> std::io::Result<()> { + let body_len = body.len(); + let headers = format!( + "HTTP/1.1 {status} OK\r\ncontent-type: {content_type}\r\ncontent-length: {body_len}\r\nconnection: close\r\n\r\n" + ); + stream.write_all(headers.as_bytes()).await?; + stream.write_all(body.as_bytes()).await?; + stream.shutdown().await +} + +fn unix_ms_now() -> i64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_millis() as i64 +} + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::assert_eq; + use reqwest::StatusCode; + use tokio::net::TcpStream; + use tokio::time::Duration; + use tokio::time::timeout; + + fn split_response(response: &str) -> (&str, &str) { + response + .split_once("\r\n\r\n") + .expect("response missing header separator") + } + + fn 
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use reqwest::StatusCode;
    use tokio::net::TcpStream;
    use tokio::time::Duration;
    use tokio::time::timeout;

    /// Body-less POST to the streaming endpoint, shared by most tests.
    const POST_RESPONSES: &str =
        "POST /v1/responses HTTP/1.1\r\nHost: 127.0.0.1\r\nContent-Length: 0\r\n\r\n";

    /// Builds a chunk that is written as soon as the stream reaches it.
    fn ungated(body: &str) -> StreamingSseChunk {
        StreamingSseChunk {
            gate: None,
            body: body.to_string(),
        }
    }

    /// Splits a raw HTTP response into `(headers, body)` at the blank line.
    fn split_response(response: &str) -> (&str, &str) {
        response
            .split_once("\r\n\r\n")
            .expect("response missing header separator")
    }

    /// Parses the numeric status out of the first line of `headers`.
    fn status_code(headers: &str) -> u16 {
        let status_line = headers.lines().next().expect("status line");
        status_line
            .split_whitespace()
            .nth(1)
            .expect("status code")
            .parse()
            .expect("parse status code")
    }

    /// Looks up a header by case-insensitive name, skipping the status line.
    fn header_value<'a>(headers: &'a str, name: &str) -> Option<&'a str> {
        headers
            .lines()
            .skip(1)
            .filter_map(|line| line.split_once(':'))
            .find(|(key, _)| key.trim().eq_ignore_ascii_case(name))
            .map(|(_, value)| value.trim())
    }

    /// Opens a raw TCP connection to the server behind `uri`.
    async fn connect(uri: &str) -> TcpStream {
        let addr = uri.strip_prefix("http://").expect("uri should be http");
        TcpStream::connect(addr)
            .await
            .expect("connect to streaming SSE server")
    }

    /// Reads until the peer closes the connection.
    async fn read_to_end(stream: &mut TcpStream) -> String {
        let mut received = Vec::new();
        stream
            .read_to_end(&mut received)
            .await
            .expect("read response");
        String::from_utf8_lossy(&received).into_owned()
    }

    /// Reads until `needle` has been seen; returns everything up to and
    /// including it, plus whatever extra bytes came with the last read.
    async fn read_until(stream: &mut TcpStream, needle: &str) -> (String, String) {
        let needle = needle.as_bytes();
        let mut received = Vec::new();
        let mut scratch = [0u8; 256];
        loop {
            let count = stream.read(&mut scratch).await.expect("read response");
            if count == 0 {
                return (
                    String::from_utf8_lossy(&received).into_owned(),
                    String::new(),
                );
            }
            received.extend_from_slice(&scratch[..count]);
            let found = received
                .windows(needle.len())
                .position(|window| window == needle);
            if let Some(start) = found {
                let (head, tail) = received.split_at(start + needle.len());
                return (
                    String::from_utf8_lossy(head).into_owned(),
                    String::from_utf8_lossy(tail).into_owned(),
                );
            }
        }
    }

    /// Writes a raw request to the connection.
    async fn send_request(stream: &mut TcpStream, request: &str) {
        stream
            .write_all(request.as_bytes())
            .await
            .expect("write request");
    }

    #[tokio::test]
    async fn get_models_returns_empty_list() {
        let (server, _) = start_streaming_sse_server(Vec::new()).await;
        let mut stream = connect(server.uri()).await;
        send_request(
            &mut stream,
            "GET /v1/models HTTP/1.1\r\nHost: 127.0.0.1\r\n\r\n",
        )
        .await;

        let response = read_to_end(&mut stream).await;
        let (headers, body) = split_response(&response);
        assert_eq!(status_code(headers), 200);
        assert_eq!(
            header_value(headers, "content-type"),
            Some("application/json")
        );
        let parsed: serde_json::Value = serde_json::from_str(body).expect("parse json body");
        let expected = serde_json::json!({
            "data": [],
            "object": "list"
        });
        assert_eq!(parsed, expected);
        server.shutdown().await;
    }

    #[tokio::test]
    async fn post_responses_streams_in_order_and_closes() {
        let chunks = vec![ungated("event: one\n\n"), ungated("event: two\n\n")];
        let (server, mut completions) = start_streaming_sse_server(vec![chunks]).await;
        let mut stream = connect(server.uri()).await;
        send_request(&mut stream, POST_RESPONSES).await;

        let response = read_to_end(&mut stream).await;
        let (headers, body) = split_response(&response);
        assert_eq!(status_code(headers), 200);
        assert_eq!(
            header_value(headers, "content-type"),
            Some("text/event-stream")
        );
        assert_eq!(body, "event: one\n\nevent: two\n\n");

        // The server must have closed the connection after the last chunk.
        let mut extra = [0u8; 1];
        let read = stream.read(&mut extra).await.expect("read after eof");
        assert_eq!(read, 0);

        let completion = completions.pop().expect("completion receiver");
        let timestamp = completion.await.expect("completion timestamp");
        assert!(timestamp > 0);
        server.shutdown().await;
    }

    #[tokio::test]
    async fn none_gate_streams_immediately() {
        let chunks = vec![ungated("event: immediate\n\n")];
        let (server, _) = start_streaming_sse_server(vec![chunks]).await;
        let mut stream = connect(server.uri()).await;
        send_request(&mut stream, POST_RESPONSES).await;

        let (raw_headers, remainder) = read_until(&mut stream, "\r\n\r\n").await;
        let (headers, _) = split_response(&raw_headers);
        assert_eq!(status_code(headers), 200);

        let tail = read_to_end(&mut stream).await;
        let immediate = format!("{remainder}{tail}");
        assert_eq!(immediate, "event: immediate\n\n");
        server.shutdown().await;
    }

    #[tokio::test]
    async fn post_responses_with_no_queue_returns_500() {
        let (server, _) = start_streaming_sse_server(Vec::new()).await;
        let mut stream = connect(server.uri()).await;
        send_request(&mut stream, POST_RESPONSES).await;

        let response = read_to_end(&mut stream).await;
        let (headers, body) = split_response(&response);
        assert_eq!(status_code(headers), 500);
        assert_eq!(header_value(headers, "content-type"), Some("text/plain"));
        assert_eq!(body, "no responses queued");
        server.shutdown().await;
    }

    #[tokio::test]
    async fn gated_chunks_wait_for_signal_and_preserve_order() {
        let (gate_one_tx, gate_one_rx) = oneshot::channel();
        let (gate_two_tx, gate_two_rx) = oneshot::channel();
        let chunks = vec![
            StreamingSseChunk {
                gate: Some(gate_one_rx),
                body: "event: one\n\n".to_string(),
            },
            StreamingSseChunk {
                gate: Some(gate_two_rx),
                body: "event: two\n\n".to_string(),
            },
        ];
        let (server, _) = start_streaming_sse_server(vec![chunks]).await;
        let mut stream = connect(server.uri()).await;
        send_request(&mut stream, POST_RESPONSES).await;

        let (raw_headers, remainder) = read_until(&mut stream, "\r\n\r\n").await;
        let (headers, _) = split_response(&raw_headers);
        assert_eq!(status_code(headers), 200);
        assert_eq!(
            header_value(headers, "content-type"),
            Some("text/event-stream")
        );
        assert!(
            remainder.is_empty(),
            "unexpected body before gate: {remainder:?}"
        );

        // Nothing may arrive while the first gate is still closed.
        let quiet_window = Duration::from_millis(200);
        let mut scratch = [0u8; 32];
        let pending = timeout(quiet_window, stream.read(&mut scratch)).await;
        assert!(pending.is_err());

        let _ = gate_one_tx.send(());
        let mut first_chunk = vec![0u8; "event: one\n\n".len()];
        stream
            .read_exact(&mut first_chunk)
            .await
            .expect("read first chunk");
        assert_eq!(String::from_utf8_lossy(&first_chunk), "event: one\n\n");

        // The second chunk stays held back until its own gate opens.
        let pending = timeout(quiet_window, stream.read(&mut scratch)).await;
        assert!(pending.is_err());

        let _ = gate_two_tx.send(());
        let remaining = read_to_end(&mut stream).await;
        assert_eq!(remaining, "event: two\n\n");
        server.shutdown().await;
    }

    #[tokio::test]
    async fn multiple_responses_are_fifo_and_completion_timestamps_monotonic() {
        let queued = vec![
            vec![ungated("event: first\n\n")],
            vec![ungated("event: second\n\n")],
        ];
        let (server, mut completions) = start_streaming_sse_server(queued).await;

        // Sequential requests must receive the queued bodies in order.
        for expected_body in ["event: first\n\n", "event: second\n\n"] {
            let mut stream = connect(server.uri()).await;
            send_request(&mut stream, POST_RESPONSES).await;
            let response = read_to_end(&mut stream).await;
            let (_, body) = split_response(&response);
            assert_eq!(body, expected_body);
        }

        let first_completion = completions.remove(0);
        let second_completion = completions.remove(0);
        let first_timestamp = first_completion.await.expect("first completion");
        let second_timestamp = second_completion.await.expect("second completion");
        assert!(first_timestamp > 0);
        assert!(second_timestamp > 0);
        assert!(first_timestamp <= second_timestamp);
        assert!(completions.is_empty());
        server.shutdown().await;
    }

    #[tokio::test]
    async fn unknown_route_returns_404() {
        let (server, _) = start_streaming_sse_server(Vec::new()).await;
        let mut stream = connect(server.uri()).await;
        send_request(
            &mut stream,
            "GET /v1/unknown HTTP/1.1\r\nHost: 127.0.0.1\r\n\r\n",
        )
        .await;

        let response = read_to_end(&mut stream).await;
        let (headers, body) = split_response(&response);
        assert_eq!(status_code(headers), 404);
        assert_eq!(header_value(headers, "content-type"), Some("text/plain"));
        assert_eq!(body, "not found");
        server.shutdown().await;
    }

    #[tokio::test]
    async fn malformed_request_returns_400() {
        let (server, _) = start_streaming_sse_server(Vec::new()).await;
        let mut stream = connect(server.uri()).await;
        send_request(&mut stream, "BAD\r\n\r\n").await;

        let response = read_to_end(&mut stream).await;
        let (headers, body) = split_response(&response);
        assert_eq!(status_code(headers), 400);
        assert_eq!(header_value(headers, "content-type"), Some("text/plain"));
        assert_eq!(body, "bad request");
        server.shutdown().await;
    }

    #[tokio::test]
    async fn responses_post_drains_request_body() {
        let response_body = "event: response.completed\ndata: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\"}}\n\n";
        let (server, mut completions) =
            start_streaming_sse_server(vec![vec![ungated(response_body)]]).await;

        let url = format!("{}/v1/responses", server.uri());
        let payload = serde_json::json!({
            "model": "gpt-5.1",
            "instructions": "test",
            "input": [{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "hello"}]}],
            "stream": true
        });

        // A real HTTP client sends a non-empty JSON body the server must drain.
        let resp = reqwest::Client::new()
            .post(url)
            .json(&payload)
            .send()
            .await
            .expect("send request");
        assert_eq!(resp.status(), StatusCode::OK);

        let bytes = resp.bytes().await.expect("read response body");
        assert_eq!(bytes, response_body.as_bytes());

        let completion = completions.remove(0);
        let completed_at = completion.await.expect("completion timestamp");
        assert!(completed_at > 0);

        server.shutdown().await;
    }

    #[tokio::test]
    async fn read_http_request_returns_after_header_terminator() {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind test listener");
        let addr = listener.local_addr().expect("listener address");
        let (tx, rx) = oneshot::channel();
        let server_task = tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.expect("accept client");
            let parsed = read_http_request(&mut stream).await;
            let _ = tx.send(parsed);
        });

        let mut client = TcpStream::connect(addr)
            .await
            .expect("connect to test listener");
        let request = "GET / HTTP/1.1\r\nHost: 127.0.0.1\r\n\r\n";
        client
            .write_all(request.as_bytes())
            .await
            .expect("write request");

        // The client keeps the socket open, so the reader must return on the
        // terminator rather than on EOF.
        let (received, body) = timeout(Duration::from_millis(200), rx)
            .await
            .expect("read_http_request timed out")
            .expect("receive request");
        assert_eq!(received, request);
        assert!(body.is_empty());
        drop(client);
        let _ = server_task.await;
    }

    #[test]
    fn parse_request_line_handles_valid_and_invalid() {
        assert_eq!(parse_request_line(""), None);
        assert_eq!(parse_request_line("BAD"), None);
        assert_eq!(
            parse_request_line("GET /v1/models HTTP/1.1"),
            Some(("GET", "/v1/models"))
        );
    }

    #[tokio::test]
    async fn take_next_stream_consumes_in_lockstep() {
        let (first_tx, first_rx) = oneshot::channel();
        let (second_tx, second_rx) = oneshot::channel();
        let state = TokioMutex::new(StreamingSseState {
            responses: VecDeque::from(vec![vec![ungated("first")], vec![ungated("second")]]),
            completions: VecDeque::from(vec![first_tx, second_tx]),
        });

        let (first_chunks, first_completion) =
            take_next_stream(&state).await.expect("first stream");
        assert_eq!(first_chunks[0].body, "first");
        let _ = first_completion.send(11);
        assert_eq!(first_rx.await.expect("first completion"), 11);

        let (second_chunks, second_completion) =
            take_next_stream(&state).await.expect("second stream");
        assert_eq!(second_chunks[0].body, "second");
        let _ = second_completion.send(22);
        assert_eq!(second_rx.await.expect("second completion"), 22);

        let third = take_next_stream(&state).await;
        assert!(third.is_none());
    }

    #[tokio::test]
    async fn shutdown_terminates_accept_loop() {
        let (server, _) = start_streaming_sse_server(Vec::new()).await;
        let shutdown = timeout(Duration::from_millis(200), server.shutdown()).await;
        assert!(shutdown.is_ok());
    }
}
= self.prepare_config(server, &home).await?; + let base_url = format!("{}/v1", server.uri()); + let (config, cwd) = self.prepare_config(base_url, &home).await?; + self.build_from_config(config, cwd, home, resume_from).await + } + + async fn build_with_home_and_base_url( + &mut self, + base_url: String, + home: Arc, + resume_from: Option, + ) -> anyhow::Result { + let (config, cwd) = self.prepare_config(base_url, &home).await?; + self.build_from_config(config, cwd, home, resume_from).await + } + async fn build_from_config( + &mut self, + config: Config, + cwd: Arc, + home: Arc, + resume_from: Option, + ) -> anyhow::Result { let auth = self.auth.clone(); let conversation_manager = ConversationManager::with_models_provider_and_home( auth.clone(), @@ -139,11 +170,11 @@ impl TestCodexBuilder { async fn prepare_config( &mut self, - server: &wiremock::MockServer, + base_url: String, home: &TempDir, ) -> anyhow::Result<(Config, Arc)> { let model_provider = ModelProviderInfo { - base_url: Some(format!("{}/v1", server.uri())), + base_url: Some(base_url), ..built_in_model_providers()["openai"].clone() }; let cwd = Arc::new(TempDir::new()?); diff --git a/codex-rs/core/tests/suite/tool_parallelism.rs b/codex-rs/core/tests/suite/tool_parallelism.rs index abb55f3a14d..0ec5221f407 100644 --- a/codex-rs/core/tests/suite/tool_parallelism.rs +++ b/codex-rs/core/tests/suite/tool_parallelism.rs @@ -1,6 +1,7 @@ #![cfg(not(target_os = "windows"))] #![allow(clippy::unwrap_used)] +use std::fs; use std::time::Duration; use std::time::Instant; @@ -13,16 +14,22 @@ use codex_protocol::user_input::UserInput; use core_test_support::responses::ev_assistant_message; use core_test_support::responses::ev_completed; use core_test_support::responses::ev_function_call; +use core_test_support::responses::ev_response_created; +use core_test_support::responses::ev_shell_command_call_with_args; use core_test_support::responses::mount_sse_once; use core_test_support::responses::mount_sse_sequence; use 
core_test_support::responses::sse; use core_test_support::responses::start_mock_server; use core_test_support::skip_if_no_network; +use core_test_support::streaming_sse::StreamingSseChunk; +use core_test_support::streaming_sse::start_streaming_sse_server; use core_test_support::test_codex::TestCodex; use core_test_support::test_codex::test_codex; use core_test_support::wait_for_event; +use pretty_assertions::assert_eq; use serde_json::Value; use serde_json::json; +use tokio::sync::oneshot; async fn run_turn(test: &TestCodex, prompt: &str) -> anyhow::Result<()> { let session_model = test.session_configured.model.clone(); @@ -280,3 +287,123 @@ async fn tool_results_grouped() -> anyhow::Result<()> { Ok(()) } + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn shell_tools_start_before_response_completed_when_stream_delayed() -> anyhow::Result<()> { + skip_if_no_network!(Ok(())); + + let output_file = tempfile::NamedTempFile::new()?; + let output_path = output_file.path(); + let first_response_id = "resp-1"; + let second_response_id = "resp-2"; + + let command = format!( + "perl -MTime::HiRes -e 'print int(Time::HiRes::time()*1000), \"\\n\"' >> \"{}\"", + output_path.display() + ); + let args = json!({ + "command": command, + "timeout_ms": 1_000, + }); + + let first_chunk = sse(vec![ + ev_response_created(first_response_id), + ev_shell_command_call_with_args("call-1", &args), + ev_shell_command_call_with_args("call-2", &args), + ev_shell_command_call_with_args("call-3", &args), + ev_shell_command_call_with_args("call-4", &args), + ]); + let second_chunk = sse(vec![ev_completed(first_response_id)]); + let follow_up = sse(vec![ + ev_assistant_message("msg-1", "done"), + ev_completed(second_response_id), + ]); + + let (first_gate_tx, first_gate_rx) = oneshot::channel(); + let (completion_gate_tx, completion_gate_rx) = oneshot::channel(); + let (follow_up_gate_tx, follow_up_gate_rx) = oneshot::channel(); + let (streaming_server, completion_receivers) = 
start_streaming_sse_server(vec![ + vec![ + StreamingSseChunk { + gate: Some(first_gate_rx), + body: first_chunk, + }, + StreamingSseChunk { + gate: Some(completion_gate_rx), + body: second_chunk, + }, + ], + vec![StreamingSseChunk { + gate: Some(follow_up_gate_rx), + body: follow_up, + }], + ]) + .await; + + let mut builder = test_codex().with_model("gpt-5.1"); + let test = builder + .build_with_streaming_server(&streaming_server) + .await?; + + let session_model = test.session_configured.model.clone(); + test.codex + .submit(Op::UserTurn { + items: vec![UserInput::Text { + text: "stream delayed completion".into(), + }], + final_output_json_schema: None, + cwd: test.cwd.path().to_path_buf(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::DangerFullAccess, + model: session_model, + effort: None, + summary: ReasoningSummary::Auto, + }) + .await?; + + let _ = first_gate_tx.send(()); + let _ = follow_up_gate_tx.send(()); + + let timestamps = tokio::time::timeout(Duration::from_secs(1), async { + loop { + let contents = fs::read_to_string(output_path)?; + let timestamps = contents + .lines() + .filter(|line| !line.trim().is_empty()) + .map(|line| { + line.trim() + .parse::() + .map_err(|err| anyhow::anyhow!("invalid timestamp {line:?}: {err}")) + }) + .collect::, _>>()?; + if timestamps.len() == 4 { + return Ok::<_, anyhow::Error>(timestamps); + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await??; + + let _ = completion_gate_tx.send(()); + wait_for_event(&test.codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; + + let mut completion_iter = completion_receivers.into_iter(); + let completed_at = completion_iter + .next() + .expect("completion receiver missing") + .await + .expect("completion timestamp missing"); + let count = i64::try_from(timestamps.len()).expect("timestamp count fits in i64"); + assert_eq!(count, 4); + + for timestamp in timestamps { + assert!( + timestamp < completed_at, + "timestamp 
{timestamp} should be before completed {completed_at}" + ); + } + + streaming_server.shutdown().await; + + Ok(()) +}