Skip to content
Open
18 changes: 9 additions & 9 deletions crates/owhisper-client/src/live.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,10 @@ impl WebSocketIO for ListenClientIO {
}
}

fn from_message(msg: Message) -> Option<Self::Output> {
fn decode(msg: Message) -> Result<Self::Output, hypr_ws::client::DecodeError> {
match msg {
Message::Text(text) => Some(text.to_string()),
_ => None,
Message::Text(text) => Ok(text.to_string()),
_ => Err(hypr_ws::client::DecodeError::UnsupportedType),
}
}
}
Expand Down Expand Up @@ -170,10 +170,10 @@ impl WebSocketIO for ListenClientDualIO {
}
}

fn from_message(msg: Message) -> Option<Self::Output> {
fn decode(msg: Message) -> Result<Self::Output, hypr_ws::client::DecodeError> {
match msg {
Message::Text(text) => Some(text.to_string()),
_ => None,
Message::Text(text) => Ok(text.to_string()),
_ => Err(hypr_ws::client::DecodeError::UnsupportedType),
}
}
}
Expand Down Expand Up @@ -207,7 +207,7 @@ impl<A: RealtimeSttAdapter> ListenClient<A> {
MixedMessage::Control(control) => TransformedInput::Control(control),
});

let (raw_stream, inner) = ws
let (raw_stream, inner, _send_task) = ws
.from_audio::<ListenClientIO, _>(self.initial_message, Box::pin(transformed_stream))
.await?;

Expand Down Expand Up @@ -262,7 +262,7 @@ impl<A: RealtimeSttAdapter> ListenClientDual<A> {
MixedMessage::Control(control) => TransformedDualInput::Control(control),
});

let (raw_stream, inner) = ws
let (raw_stream, inner, _send_task) = ws
.from_audio::<ListenClientDualIO, _>(self.initial_message, Box::pin(transformed_stream))
.await?;

Expand Down Expand Up @@ -302,7 +302,7 @@ impl<A: RealtimeSttAdapter> ListenClientDual<A> {
let spk_connect =
spk_ws.from_audio::<ListenClientIO, _>(self.initial_message, spk_outbound);

let ((mic_raw, mic_handle), (spk_raw, spk_handle)) =
let ((mic_raw, mic_handle, _mic_send_task), (spk_raw, spk_handle, _spk_send_task)) =
tokio::try_join!(mic_connect, spk_connect)?;

tokio::spawn(forward_dual_to_single(
Expand Down
2 changes: 1 addition & 1 deletion crates/transcribe-whisper-local/src/service/streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ where
}
};

let guard = connection_manager.acquire_connection();
let guard = connection_manager.acquire_connection().await;

Ok(ws_upgrade
.on_upgrade(move |socket| async move {
Expand Down
4 changes: 1 addition & 3 deletions crates/ws-client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ edition = "2024"
[dependencies]
bytes = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
thiserror = { workspace = true }

async-stream = { workspace = true }
Expand All @@ -14,6 +15,3 @@ futures-util = { workspace = true }
tokio = { workspace = true, features = ["rt-multi-thread", "time", "sync", "macros"] }
tokio-tungstenite = { workspace = true, features = ["native-tls-vendored"] }
tracing = { workspace = true }

[dev-dependencies]
serde_json.workspace = true
133 changes: 93 additions & 40 deletions crates/ws-client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,55 +7,94 @@ use futures_util::{
};
use tokio_tungstenite::{connect_async, tungstenite::client::IntoClientRequest};

pub use crate::config::{ConnectionConfig, KeepAliveConfig, RetryConfig};
pub use tokio_tungstenite::tungstenite::{ClientRequestBuilder, Utf8Bytes, protocol::Message};

#[derive(Debug)]
enum ControlCommand {
Finalize(Option<Message>),
}

#[derive(Clone)]
struct KeepAliveConfig {
interval: std::time::Duration,
message: Message,
}

#[derive(Clone)]
pub struct WebSocketHandle {
control_tx: tokio::sync::mpsc::UnboundedSender<ControlCommand>,
}

impl WebSocketHandle {
pub async fn finalize_with_text(&self, text: Utf8Bytes) {
let _ = self
if self
.control_tx
.send(ControlCommand::Finalize(Some(Message::Text(text))));
.send(ControlCommand::Finalize(Some(Message::Text(text))))
.is_err()
{
tracing::warn!("control channel closed, cannot send finalize command");
}
}
}

pub struct SendTask {
handle: tokio::task::JoinHandle<Result<(), crate::Error>>,
}

impl SendTask {
pub async fn wait(self) -> Result<(), crate::Error> {
match self.handle.await {
Ok(result) => result,
Err(join_err) if join_err.is_panic() => {
std::panic::resume_unwind(join_err.into_panic());
}
Err(join_err) => {
tracing::error!("send task cancelled: {:?}", join_err);
Err(crate::Error::UnexpectedClose)
}
}
}
}

#[derive(Debug, thiserror::Error)]
pub enum DecodeError {
#[error("unsupported message type")]
UnsupportedType,

#[error("deserialization failed: {0}")]
DeserializationError(#[from] serde_json::Error),
}

pub trait WebSocketIO: Send + 'static {
type Data: Send;
type Input: Send;
type Output: DeserializeOwned;

fn to_input(data: Self::Data) -> Self::Input;
fn to_message(input: Self::Input) -> Message;
fn from_message(msg: Message) -> Option<Self::Output>;
fn decode(msg: Message) -> Result<Self::Output, DecodeError>;
}

pub struct WebSocketClient {
request: ClientRequestBuilder,
keep_alive: Option<KeepAliveConfig>,
config: ConnectionConfig,
}

impl WebSocketClient {
pub fn new(request: ClientRequestBuilder) -> Self {
Self {
request,
keep_alive: None,
config: ConnectionConfig::default(),
}
}

pub fn with_config(mut self, config: ConnectionConfig) -> Self {
self.config = config;
self
}

pub fn with_keep_alive(mut self, config: KeepAliveConfig) -> Self {
self.keep_alive = Some(config);
self
}

pub fn with_keep_alive_message(
mut self,
interval: std::time::Duration,
Expand All @@ -73,15 +112,18 @@ impl WebSocketClient {
(
impl Stream<Item = Result<T::Output, crate::Error>> + use<T, S>,
WebSocketHandle,
SendTask,
),
crate::Error,
> {
let keep_alive_config = self.keep_alive.clone();
let close_grace_period = self.config.close_grace_period;
let retry_config = self.config.retry_config.clone();
let ws_stream = (|| self.try_connect(self.request.clone()))
.retry(
ConstantBuilder::default()
.with_max_times(5)
.with_delay(std::time::Duration::from_millis(500)),
.with_max_times(retry_config.max_attempts)
.with_delay(retry_config.delay),
)
.when(|e| {
tracing::error!("ws_connect_failed: {:?}", e);
Expand All @@ -96,13 +138,17 @@ impl WebSocketClient {
let (error_tx, mut error_rx) = tokio::sync::mpsc::unbounded_channel::<crate::Error>();
let handle = WebSocketHandle { control_tx };

let _send_task = tokio::spawn(async move {
let send_task = tokio::spawn(async move {
if let Some(msg) = initial_message
&& let Err(e) = ws_sender.send(msg).await
{
tracing::error!("ws_initial_message_failed: {:?}", e);
let _ = error_tx.send(e.into());
return;
if error_tx.send(e.into()).is_err() {
tracing::warn!("output stream already closed, cannot propagate error");
}
return Err(crate::Error::DataSend {
context: "initial message".to_string(),
});
}

let mut last_outbound_at = tokio::time::Instant::now();
Expand All @@ -120,7 +166,9 @@ impl WebSocketClient {
if let Some(cfg) = keep_alive_config.as_ref() {
if let Err(e) = ws_sender.send(cfg.message.clone()).await {
tracing::error!("ws_keepalive_failed: {:?}", e);
let _ = error_tx.send(e.into());
if error_tx.send(e.into()).is_err() {
tracing::warn!("output stream already closed, cannot propagate keepalive error");
}
break;
}
last_outbound_at = tokio::time::Instant::now();
Expand All @@ -132,7 +180,9 @@ impl WebSocketClient {

if let Err(e) = ws_sender.send(msg).await {
tracing::error!("ws_send_failed: {:?}", e);
let _ = error_tx.send(e.into());
if error_tx.send(e.into()).is_err() {
tracing::warn!("output stream already closed, cannot propagate send error");
}
break;
}
last_outbound_at = tokio::time::Instant::now();
Expand All @@ -141,7 +191,9 @@ impl WebSocketClient {
if let Some(msg) = maybe_msg {
if let Err(e) = ws_sender.send(msg).await {
tracing::error!("ws_finalize_failed: {:?}", e);
let _ = error_tx.send(e.into());
if error_tx.send(e.into()).is_err() {
tracing::warn!("output stream already closed, cannot propagate finalize error");
}
break;
}
last_outbound_at = tokio::time::Instant::now();
Expand All @@ -151,36 +203,32 @@ impl WebSocketClient {
}
}

// Wait 5 seconds before closing the connection
// TODO: This might not be enough to ensure receiving remaining transcripts from the server.
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
let _ = ws_sender.close().await;
tracing::debug!("draining remaining messages before close");
tokio::time::sleep(close_grace_period).await;
if let Err(e) = ws_sender.close().await {
tracing::debug!("ws_close_failed: {:?}", e);
}
Ok(())
});

let send_task_handle = SendTask { handle: send_task };

let output_stream = async_stream::stream! {
loop {
tokio::select! {
Some(msg_result) = ws_receiver.next() => {
match msg_result {
Ok(msg) => {
let is_text = matches!(msg, Message::Text(_));
let is_binary = matches!(msg, Message::Binary(_));
let text_preview = if let Message::Text(ref t) = msg {
Some(t.to_string())
} else {
None
};

match msg {
Message::Text(_) | Message::Binary(_) => {
if let Some(output) = T::from_message(msg) {
yield Ok(output);
} else if is_text {
if let Some(text) = text_preview {
tracing::warn!("ws_message_parse_failed: {}", text);
match T::decode(msg) {
Ok(output) => yield Ok(output),
Err(DecodeError::UnsupportedType) => {
tracing::debug!("ws_message_unsupported_type");
}
Err(DecodeError::DeserializationError(e)) => {
tracing::warn!("ws_message_parse_failed: {}", e);
}
} else if is_binary {
tracing::warn!("ws_binary_message_parse_failed");
}
},
Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => continue,
Expand All @@ -207,7 +255,7 @@ impl WebSocketClient {
}
};

Ok((output_stream, handle))
Ok((output_stream, handle, send_task_handle))
}

async fn try_connect(
Expand All @@ -219,12 +267,17 @@ impl WebSocketClient {
>,
crate::Error,
> {
let req = req.into_client_request().unwrap();
let req = req
.into_client_request()
.map_err(|e| crate::Error::InvalidRequest(e.to_string()))?;

tracing::info!("connect_async: {:?}", req.uri());

let (ws_stream, _) =
tokio::time::timeout(std::time::Duration::from_secs(8), connect_async(req)).await??;
let timeout_duration = self.config.connect_timeout;
let (ws_stream, _) = tokio::time::timeout(timeout_duration, connect_async(req))
.await
.map_err(|e| crate::Error::timeout(e, timeout_duration))?
.map_err(crate::Error::Connection)?;

Ok(ws_stream)
}
Expand Down
40 changes: 40 additions & 0 deletions crates/ws-client/src/config.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
use std::time::Duration;
use tokio_tungstenite::tungstenite::protocol::Message;

#[derive(Clone, Debug)]
pub struct ConnectionConfig {
pub connect_timeout: Duration,
pub retry_config: RetryConfig,
pub close_grace_period: Duration,
}

impl Default for ConnectionConfig {
fn default() -> Self {
Self {
connect_timeout: Duration::from_secs(8),
retry_config: RetryConfig::default(),
close_grace_period: Duration::from_secs(5),
}
}
}

#[derive(Clone, Debug)]
pub struct RetryConfig {
pub max_attempts: usize,
pub delay: Duration,
}

impl Default for RetryConfig {
fn default() -> Self {
Self {
max_attempts: 5,
delay: Duration::from_millis(500),
}
}
}

#[derive(Clone, Debug)]
pub struct KeepAliveConfig {
pub interval: Duration,
pub message: Message,
}
Loading