From 6b70e0c5977b2166adcb86e7b7e81040ce5f8021 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Wed, 13 May 2026 14:10:58 -0400 Subject: [PATCH 1/2] refactor: tag slot response streaming messages --- crates/coglet-python/src/log_writer.rs | 4 +- crates/coglet/src/bridge/codec.rs | 21 ++----- crates/coglet/src/bridge/protocol.rs | 85 +++++++++++++++++++++++--- crates/coglet/src/orchestrator.rs | 22 ++++++- crates/coglet/src/worker.rs | 17 ++++-- 5 files changed, 115 insertions(+), 34 deletions(-) diff --git a/crates/coglet-python/src/log_writer.rs b/crates/coglet-python/src/log_writer.rs index 22e69398ac..53abd43bbe 100644 --- a/crates/coglet-python/src/log_writer.rs +++ b/crates/coglet-python/src/log_writer.rs @@ -615,11 +615,11 @@ mod tests { let msg = rx.try_recv().unwrap(); match msg { - SlotResponse::Log { source, data } => { + SlotResponse::LogLine { source, data } => { assert_eq!(source, LogSource::Stdout); assert_eq!(data, "hello"); } - _ => panic!("expected Log message"), + _ => panic!("expected LogLine message"), } } diff --git a/crates/coglet/src/bridge/codec.rs b/crates/coglet/src/bridge/codec.rs index 979be41a68..32e97356c5 100644 --- a/crates/coglet/src/bridge/codec.rs +++ b/crates/coglet/src/bridge/codec.rs @@ -159,26 +159,17 @@ mod tests { let mut codec = JsonCodec::::new(); let mut buf = BytesMut::new(); - let resp = SlotResponse::Done { - id: "test".to_string(), - output: Some(serde_json::json!("result")), - predict_time: 1.5, - is_stream: false, + let resp = SlotResponse::OutputChunk { + output: serde_json::json!("result"), + index: 3, }; codec.encode(resp, &mut buf).unwrap(); let decoded = codec.decode(&mut buf).unwrap().unwrap(); match decoded { - SlotResponse::Done { - id, - output, - predict_time, - is_stream, - } => { - assert_eq!(id, "test"); - assert_eq!(output, Some(serde_json::json!("result"))); - assert!((predict_time - 1.5).abs() < 0.001); - assert!(!is_stream); + SlotResponse::OutputChunk { output, index } => { + assert_eq!(output, 
serde_json::json!("result")); + assert_eq!(index, 3); } _ => panic!("wrong variant"), } diff --git a/crates/coglet/src/bridge/protocol.rs b/crates/coglet/src/bridge/protocol.rs index cd1851fe4e..09edeba792 100644 --- a/crates/coglet/src/bridge/protocol.rs +++ b/crates/coglet/src/bridge/protocol.rs @@ -294,15 +294,32 @@ pub enum MetricMode { Append, } +/// Current slot response protocol version. +/// +/// The response enum is already serde-tagged with `type`, so this constant and +/// the optional `ProtocolVersion` message give future protocol changes an +/// explicit marker without adding a second envelope around every message. +pub const SLOT_RESPONSE_PROTOCOL_VERSION: u32 = 1; + /// Messages from worker to parent on slot socket. #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] pub enum SlotResponse { - Log { + ProtocolVersion { + version: u32, + }, + + LogLine { source: LogSource, data: String, }, + /// Raw binary payload for future WebSocket streaming. + BinaryChunk { + mime_type: String, + data: Vec, + }, + /// Output for a file/path-like output return type or an output that exceeds the size threshold /// for bridge codec serialization. FileOutput { @@ -313,9 +330,10 @@ pub enum SlotResponse { mime_type: Option, }, - /// Streaming output chunk (for generators). - Output { + /// Streaming output chunk for generator and iterator output. + OutputChunk { output: serde_json::Value, + index: u64, }, /// User-emitted metric from the prediction. @@ -338,7 +356,7 @@ pub enum SlotResponse { output: Option, predict_time: f64, /// Predictor signal: true when the output is a list, generator, or - /// iterator — used as fallback when the schema Output type is `Any` + /// iterator, used as fallback when the schema Output type is `Any` /// or unavailable. 
#[serde(default, skip_serializing_if = "std::ops::Not::not")] is_stream: bool, @@ -499,20 +517,67 @@ mod tests { } #[test] - fn slot_log_serializes() { - let resp = SlotResponse::Log { + fn slot_log_line_serializes() { + let resp = SlotResponse::LogLine { source: LogSource::Stdout, data: "Processing...".to_string(), }; - insta::assert_json_snapshot!(resp); + + assert_eq!( + serde_json::to_value(resp).unwrap(), + json!({ + "type": "log_line", + "source": "stdout", + "data": "Processing..." + }) + ); } #[test] - fn slot_output_serializes() { - let resp = SlotResponse::Output { + fn slot_output_chunk_serializes() { + let resp = SlotResponse::OutputChunk { output: json!("chunk 1"), + index: 7, }; - insta::assert_json_snapshot!(resp); + + assert_eq!( + serde_json::to_value(resp).unwrap(), + json!({ + "type": "output_chunk", + "output": "chunk 1", + "index": 7 + }) + ); + } + + #[test] + fn slot_protocol_version_serializes() { + let resp = SlotResponse::ProtocolVersion { version: 1 }; + + assert_eq!( + serde_json::to_value(resp).unwrap(), + json!({ + "type": "protocol_version", + "version": 1 + }) + ); + } + + #[test] + fn slot_binary_chunk_serializes() { + let resp = SlotResponse::BinaryChunk { + mime_type: "audio/opus".to_string(), + data: vec![1, 2, 3, 4], + }; + + assert_eq!( + serde_json::to_value(resp).unwrap(), + json!({ + "type": "binary_chunk", + "mime_type": "audio/opus", + "data": [1, 2, 3, 4] + }) + ); } #[test] diff --git a/crates/coglet/src/orchestrator.rs b/crates/coglet/src/orchestrator.rs index e4514a7e1d..fceac88601 100644 --- a/crates/coglet/src/orchestrator.rs +++ b/crates/coglet/src/orchestrator.rs @@ -953,7 +953,17 @@ async fn run_event_loop( Some((slot_id, result)) = slot_msg_rx.recv() => { match result { - Ok(SlotResponse::Log { source, data }) => { + Ok(SlotResponse::ProtocolVersion { version }) => { + if version != crate::bridge::protocol::SLOT_RESPONSE_PROTOCOL_VERSION { + tracing::warn!( + %slot_id, + version, + expected = 
crate::bridge::protocol::SLOT_RESPONSE_PROTOCOL_VERSION, + "Worker reported unexpected slot response protocol version" + ); + } + } + Ok(SlotResponse::LogLine { source, data }) => { let (prediction_id, poisoned) = if let Some(pred) = predictions.get(&slot_id) { if let Some(mut p) = try_lock_prediction(pred) { p.append_log(&data); @@ -1005,7 +1015,7 @@ async fn run_event_loop( predictions.remove(&slot_id); } } - Ok(SlotResponse::Output { output }) => { + Ok(SlotResponse::OutputChunk { output, index: _ }) => { let poisoned = if let Some(pred) = predictions.get(&slot_id) { if let Some(mut p) = try_lock_prediction(pred) { p.append_output(output); @@ -1021,6 +1031,14 @@ async fn run_event_loop( predictions.remove(&slot_id); } } + Ok(SlotResponse::BinaryChunk { mime_type, data }) => { + tracing::debug!( + %slot_id, + %mime_type, + bytes = data.len(), + "Ignoring binary chunk until WebSocket streaming is implemented" + ); + } Ok(SlotResponse::FileOutput { filename, kind, mime_type }) => { tracing::debug!(%slot_id, %filename, ?kind, "FileOutput received"); let bytes = match std::fs::read(&filename) { diff --git a/crates/coglet/src/worker.rs b/crates/coglet/src/worker.rs index 6d040dc62f..3869157297 100644 --- a/crates/coglet/src/worker.rs +++ b/crates/coglet/src/worker.rs @@ -14,7 +14,7 @@ use std::io; use std::path::PathBuf; use std::sync::Arc; use std::sync::OnceLock; -use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; use futures::{SinkExt, StreamExt}; use tokio::runtime::Handle; @@ -151,6 +151,7 @@ pub struct SlotSender { tx: mpsc::UnboundedSender, output_dir: PathBuf, file_counter: Arc, + output_counter: Arc, } impl SlotSender { @@ -159,9 +160,14 @@ impl SlotSender { tx, output_dir, file_counter: Arc::new(AtomicUsize::new(0)), + output_counter: Arc::new(AtomicU64::new(0)), } } + fn next_output_index(&self) -> u64 { + self.output_counter.fetch_add(1, Ordering::Relaxed) + } + /// Generate a unique filename in the 
output dir. fn next_output_path(&self, extension: &str) -> PathBuf { let n = self.file_counter.fetch_add(1, Ordering::Relaxed); @@ -173,7 +179,7 @@ impl SlotSender { return Ok(()); } - let msg = SlotResponse::Log { + let msg = SlotResponse::LogLine { source, data: truncate_worker_log(data.to_string()), }; @@ -232,7 +238,7 @@ impl SlotSender { /// Send prediction output, either inline or spilled to disk if too large. pub fn send_output(&self, output: serde_json::Value) -> io::Result<()> { - let msg = build_output_message(&self.output_dir, output)?; + let msg = build_output_message(&self.output_dir, output, self.next_output_index())?; self.tx .send(msg) .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "slot channel closed")) @@ -243,6 +249,7 @@ impl SlotSender { fn build_output_message( output_dir: &std::path::Path, output: serde_json::Value, + index: u64, ) -> io::Result { let serialized = serde_json::to_vec(&output).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; @@ -260,7 +267,7 @@ fn build_output_message( mime_type: None, }) } else { - Ok(SlotResponse::Output { output }) + Ok(SlotResponse::OutputChunk { output, index }) } } @@ -872,7 +879,7 @@ async fn run_prediction( // Send output as a separate message (handles spilling for large values). // Skip if null or empty array — those mean "already streamed" (generators). 
if !output.is_null() && output != serde_json::Value::Array(vec![]) { - let output_msg = match build_output_message(&output_dir, output) { + let output_msg = match build_output_message(&output_dir, output, 0) { Ok(msg) => msg, Err(e) => { tracing::error!(error = %e, "Failed to build output message"); From 261a75d3870159cafddb40c6843a127d670bedc7 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Thu, 14 May 2026 15:22:35 -0400 Subject: [PATCH 2/2] refactor: clean up slot response protocol - Remove orphaned snapshot files from renamed tests - Remove unused BinaryChunk variant and handler - Add doc comment to ProtocolVersion explaining handshake intent - Wire up worker to send ProtocolVersion on each slot at startup --- crates/coglet/src/bridge/protocol.rs | 32 +++++-------------- ..._protocol__tests__slot_log_serializes.snap | 9 ------ ...otocol__tests__slot_output_serializes.snap | 8 ----- crates/coglet/src/orchestrator.rs | 8 ----- crates/coglet/src/worker.rs | 15 ++++++++- 5 files changed, 22 insertions(+), 50 deletions(-) delete mode 100644 crates/coglet/src/bridge/snapshots/coglet__bridge__protocol__tests__slot_log_serializes.snap delete mode 100644 crates/coglet/src/bridge/snapshots/coglet__bridge__protocol__tests__slot_output_serializes.snap diff --git a/crates/coglet/src/bridge/protocol.rs b/crates/coglet/src/bridge/protocol.rs index 09edeba792..0ef46a4e68 100644 --- a/crates/coglet/src/bridge/protocol.rs +++ b/crates/coglet/src/bridge/protocol.rs @@ -305,6 +305,11 @@ pub const SLOT_RESPONSE_PROTOCOL_VERSION: u32 = 1; #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] pub enum SlotResponse { + /// Protocol version handshake message. + /// + /// Sent by the worker on every slot at startup, before its main event loop, + /// so the orchestrator can detect a version mismatch. The orchestrator only + /// logs a warning when the version differs from the expected constant. 
ProtocolVersion { version: u32, }, @@ -314,12 +319,6 @@ pub enum SlotResponse { data: String, }, - /// Raw binary payload for future WebSocket streaming. - BinaryChunk { - mime_type: String, - data: Vec, - }, - /// Output for a file/path-like output return type or an output that exceeds the size threshold /// for bridge codec serialization. FileOutput { @@ -552,7 +551,9 @@ mod tests { #[test] fn slot_protocol_version_serializes() { - let resp = SlotResponse::ProtocolVersion { version: 1 }; + let resp = SlotResponse::ProtocolVersion { + version: SLOT_RESPONSE_PROTOCOL_VERSION, + }; assert_eq!( serde_json::to_value(resp).unwrap(), @@ -563,23 +564,6 @@ mod tests { ); } - #[test] - fn slot_binary_chunk_serializes() { - let resp = SlotResponse::BinaryChunk { - mime_type: "audio/opus".to_string(), - data: vec![1, 2, 3, 4], - }; - - assert_eq!( - serde_json::to_value(resp).unwrap(), - json!({ - "type": "binary_chunk", - "mime_type": "audio/opus", - "data": [1, 2, 3, 4] - }) - ); - } - #[test] fn slot_done_serializes() { let resp = SlotResponse::Done { diff --git a/crates/coglet/src/bridge/snapshots/coglet__bridge__protocol__tests__slot_log_serializes.snap b/crates/coglet/src/bridge/snapshots/coglet__bridge__protocol__tests__slot_log_serializes.snap deleted file mode 100644 index 376290a263..0000000000 --- a/crates/coglet/src/bridge/snapshots/coglet__bridge__protocol__tests__slot_log_serializes.snap +++ /dev/null @@ -1,9 +0,0 @@ ---- -source: coglet/src/bridge/protocol.rs -expression: resp ---- -{ - "type": "log", - "source": "stdout", - "data": "Processing..." 
-} diff --git a/crates/coglet/src/bridge/snapshots/coglet__bridge__protocol__tests__slot_output_serializes.snap b/crates/coglet/src/bridge/snapshots/coglet__bridge__protocol__tests__slot_output_serializes.snap deleted file mode 100644 index 3849e5f268..0000000000 --- a/crates/coglet/src/bridge/snapshots/coglet__bridge__protocol__tests__slot_output_serializes.snap +++ /dev/null @@ -1,8 +0,0 @@ ---- -source: coglet/src/bridge/protocol.rs -expression: resp ---- -{ - "type": "output", - "output": "chunk 1" -} diff --git a/crates/coglet/src/orchestrator.rs b/crates/coglet/src/orchestrator.rs index fceac88601..c0aec8b9ed 100644 --- a/crates/coglet/src/orchestrator.rs +++ b/crates/coglet/src/orchestrator.rs @@ -1031,14 +1031,6 @@ async fn run_event_loop( predictions.remove(&slot_id); } } - Ok(SlotResponse::BinaryChunk { mime_type, data }) => { - tracing::debug!( - %slot_id, - %mime_type, - bytes = data.len(), - "Ignoring binary chunk until WebSocket streaming is implemented" - ); - } Ok(SlotResponse::FileOutput { filename, kind, mime_type }) => { tracing::debug!(%slot_id, %filename, ?kind, "FileOutput received"); let bytes = match std::fs::read(&filename) { diff --git a/crates/coglet/src/worker.rs b/crates/coglet/src/worker.rs index 3869157297..c59a26cc7a 100644 --- a/crates/coglet/src/worker.rs +++ b/crates/coglet/src/worker.rs @@ -133,7 +133,7 @@ fn init_worker_tracing(tx: mpsc::Sender) { use crate::bridge::codec::JsonCodec; use crate::bridge::protocol::{ ControlRequest, ControlResponse, FileOutputKind, LogSource, MAX_INLINE_IPC_SIZE, MetricMode, - SlotId, SlotOutcome, SlotRequest, SlotResponse, + SLOT_RESPONSE_PROTOCOL_VERSION, SlotId, SlotOutcome, SlotRequest, SlotResponse, }; use crate::bridge::transport::{ChildTransportInfo, connect_transport}; use crate::orchestrator::HealthcheckResult; @@ -659,6 +659,19 @@ pub async fn run_worker( .map(|(id, w)| (id, Arc::new(tokio::sync::Mutex::new(w)))) .collect(); + // Send protocol version on each slot so the orchestrator can 
detect mismatches + for (slot_id, writer) in &slot_writers { + let mut w = writer.lock().await; + if let Err(e) = w + .send(SlotResponse::ProtocolVersion { + version: SLOT_RESPONSE_PROTOCOL_VERSION, + }) + .await + { + tracing::warn!(%slot_id, error = %e, "Failed to send protocol version"); + } + } + // Main event loop loop { tokio::select! {