Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
263 changes: 234 additions & 29 deletions livekit-plugins/livekit-plugins-rime/livekit/plugins/rime/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,22 @@
from __future__ import annotations

import asyncio
import base64
import json
import os
import weakref
from dataclasses import dataclass, replace
from urllib.parse import urlencode

import aiohttp

from livekit.agents import (
APIConnectionError,
APIConnectOptions,
APIError,
APIStatusError,
APITimeoutError,
tokenize,
tts,
utils,
)
Expand All @@ -34,6 +40,7 @@
NotGivenOr,
)
from livekit.agents.utils import is_given
from livekit.agents.voice.io import TimedString

from .langs import TTSLangs
from .log import logger
Expand All @@ -43,6 +50,8 @@
ARCANA_MODEL_TIMEOUT = 60 * 4
MIST_MODEL_TIMEOUT = 30
RIME_BASE_URL = "https://users.rime.ai/v1/rime-tts"
RIME_WS_BASE_URL = "wss://users-ws.rime.ai"
NUM_CHANNELS = 1


@dataclass
Expand Down Expand Up @@ -73,9 +82,6 @@ class _MistOptions:
phonemize_between_brackets: NotGivenOr[bool] = NOT_GIVEN


NUM_CHANNELS = 1


def _is_mist_model(model: TTSModels | str) -> bool:
return "mist" in model

Expand All @@ -86,11 +92,40 @@ def _timeout_for_model(model: TTSModels | str) -> int:
return MIST_MODEL_TIMEOUT


def _model_params(opts: _TTSOptions) -> dict[str, object]:
    """Per-model option fields shared between the HTTP body and the WS query string."""
    # Build a list of (wire_key, value) candidates for the active model, then
    # keep only the ones the caller explicitly provided.
    if opts.model == "arcana" and opts.arcana_options is not None:
        ao = opts.arcana_options
        candidates = [
            ("lang", ao.lang),
            ("repetition_penalty", ao.repetition_penalty),
            ("temperature", ao.temperature),
            ("top_p", ao.top_p),
            ("max_tokens", ao.max_tokens),
        ]
    elif _is_mist_model(opts.model) and opts.mist_options is not None:
        mo = opts.mist_options
        candidates = [
            ("lang", mo.lang),
            ("speedAlpha", mo.speed_alpha),
            ("pauseBetweenBrackets", mo.pause_between_brackets),
            ("phonemizeBetweenBrackets", mo.phonemize_between_brackets),
        ]
    else:
        candidates = []

    params: dict[str, object] = {}
    for key, value in candidates:
        if is_given(value):
            params[key] = value
    return params


class TTS(tts.TTS):
def __init__(
self,
*,
base_url: str = RIME_BASE_URL,
ws_base_url: str = RIME_WS_BASE_URL,
model: TTSModels | str = "arcana",
speaker: NotGivenOr[ArcanaVoices | str] = NOT_GIVEN,
lang: TTSLangs | str = "eng",
Expand All @@ -107,10 +142,14 @@ def __init__(
phonemize_between_brackets: NotGivenOr[bool] = NOT_GIVEN,
api_key: NotGivenOr[str] = NOT_GIVEN,
http_session: aiohttp.ClientSession | None = None,
use_websocket: bool = False,
segment: NotGivenOr[str] = NOT_GIVEN,
tokenizer: NotGivenOr[tokenize.SentenceTokenizer] = NOT_GIVEN,
) -> None:
super().__init__(
capabilities=tts.TTSCapabilities(
streaming=False,
streaming=use_websocket,
aligned_transcript=use_websocket,
),
sample_rate=sample_rate,
num_channels=NUM_CHANNELS,
Expand Down Expand Up @@ -148,9 +187,23 @@ def __init__(
)
self._session = http_session
self._base_url = base_url
self._ws_base_url = ws_base_url
self._use_websocket = use_websocket
self._segment = segment if is_given(segment) else "bySentence"

self._total_timeout = _timeout_for_model(model)

self._streams: weakref.WeakSet[SynthesizeStream] = weakref.WeakSet()
self._sentence_tokenizer = (
tokenizer if is_given(tokenizer) else tokenize.blingfire.SentenceTokenizer()
)
self._pool = utils.ConnectionPool[aiohttp.ClientWebSocketResponse](
connect_cb=self._connect_ws,
close_cb=self._close_ws,
max_session_duration=300,
mark_refreshed_on_get=True,
)

@property
def model(self) -> str:
return self._opts.model
Expand All @@ -165,6 +218,61 @@ def _ensure_session(self) -> aiohttp.ClientSession:

return self._session

def _ws_url(self) -> str:
params: dict[str, object] = {
"speaker": self._opts.speaker,
"modelId": self._opts.model,
"audioFormat": "pcm",
"samplingRate": self._sample_rate,
"segment": self._segment,
**_model_params(self._opts),
}
encoded = {
k: ("true" if v else "false") if isinstance(v, bool) else v for k, v in params.items()
}
return f"{self._ws_base_url}/ws3?{urlencode(encoded)}"

async def _connect_ws(self, timeout: float) -> aiohttp.ClientWebSocketResponse:
session = self._ensure_session()
return await asyncio.wait_for(
session.ws_connect(
self._ws_url(), headers={"Authorization": f"Bearer {self._api_key}"}
),
timeout,
)

    async def _close_ws(self, ws: aiohttp.ClientWebSocketResponse) -> None:
        """Gracefully shut down a pooled Rime websocket.

        Sends the end-of-stream operation so the server can finish cleanly,
        waits up to 1s for its reply, and always closes the socket, even when
        the EOS handshake fails.
        """
        try:
            await ws.send_str(json.dumps({"operation": "eos"}))
            try:
                # Give the server a short window to acknowledge; a slow or
                # silent server should not block pool shutdown.
                await asyncio.wait_for(ws.receive(), timeout=1.0)
            except asyncio.TimeoutError:
                pass
        except Exception as e:
            # Best-effort close: log and fall through to ws.close().
            logger.warning(f"Error during Rime WS close sequence: {e}")
        finally:
            await ws.close()

    def prewarm(self) -> None:
        """Eagerly open a pooled websocket so the first synthesis request
        does not pay the connection/handshake latency."""
        self._pool.prewarm()

def stream(
self, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
) -> SynthesizeStream:
if not self._use_websocket:
raise RuntimeError(
"Rime TTS streaming requires use_websocket=True at construction time"
)
s = SynthesizeStream(tts=self, conn_options=conn_options)
self._streams.add(s)
return s

async def aclose(self) -> None:
for s in list(self._streams):
await s.aclose()
self._streams.clear()
await self._pool.aclose()

def synthesize(
self, text: str, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
) -> ChunkedStream:
Expand All @@ -189,6 +297,8 @@ def update_options(
phonemize_between_brackets: NotGivenOr[bool] = NOT_GIVEN,
base_url: NotGivenOr[str] = NOT_GIVEN,
) -> None:
# WS URL is bound at pool connect; invalidate if any URL-affecting param changed.
prev_ws_url = self._ws_url() if self._use_websocket else None
if is_given(base_url):
self._base_url = base_url
if is_given(model):
Expand Down Expand Up @@ -231,6 +341,9 @@ def update_options(
if is_given(phonemize_between_brackets):
self._opts.mist_options.phonemize_between_brackets = phonemize_between_brackets

if prev_ws_url is not None and self._ws_url() != prev_ws_url:
self._pool.invalidate()


class ChunkedStream(tts.ChunkedStream):
"""Synthesize using the chunked api endpoint"""
Expand All @@ -245,38 +358,18 @@ async def _run(self, output_emitter: tts.AudioEmitter) -> None:
"speaker": self._opts.speaker,
"text": self._input_text,
"modelId": self._opts.model,
**_model_params(self._opts),
}
format = "audio/pcm"
if self._opts.model == "arcana":
arcana_opts = self._opts.arcana_options
assert arcana_opts is not None
if is_given(arcana_opts.repetition_penalty):
payload["repetition_penalty"] = arcana_opts.repetition_penalty
if is_given(arcana_opts.temperature):
payload["temperature"] = arcana_opts.temperature
if is_given(arcana_opts.top_p):
payload["top_p"] = arcana_opts.top_p
if is_given(arcana_opts.max_tokens):
payload["max_tokens"] = arcana_opts.max_tokens
if is_given(arcana_opts.lang):
payload["lang"] = arcana_opts.lang
if is_given(arcana_opts.sample_rate):
payload["samplingRate"] = arcana_opts.sample_rate
elif _is_mist_model(self._opts.model):
if self._opts.model == "arcana" and self._opts.arcana_options is not None:
if is_given(self._opts.arcana_options.sample_rate):
payload["samplingRate"] = self._opts.arcana_options.sample_rate
elif _is_mist_model(self._opts.model) and self._opts.mist_options is not None:
mist_opts = self._opts.mist_options
assert mist_opts is not None
if is_given(mist_opts.lang):
payload["lang"] = mist_opts.lang
if is_given(mist_opts.sample_rate):
payload["samplingRate"] = mist_opts.sample_rate
if is_given(mist_opts.speed_alpha):
payload["speedAlpha"] = mist_opts.speed_alpha
if self._opts.model == "mistv2" and is_given(mist_opts.reduce_latency):
payload["reduceLatency"] = mist_opts.reduce_latency
if is_given(mist_opts.pause_between_brackets):
payload["pauseBetweenBrackets"] = mist_opts.pause_between_brackets
if is_given(mist_opts.phonemize_between_brackets):
payload["phonemizeBetweenBrackets"] = mist_opts.phonemize_between_brackets

try:
async with self._tts._ensure_session().post(
Expand Down Expand Up @@ -316,3 +409,115 @@ async def _run(self, output_emitter: tts.AudioEmitter) -> None:
) from None
except Exception as e:
raise APIConnectionError() from e


class SynthesizeStream(tts.SynthesizeStream):
    """One stream = one utterance. Server-side bySentence segmentation by default;
    pass segment="immediate" on the TTS to disable server buffering when the agent
    is already feeding sentence-tokenized text."""

    def __init__(self, *, tts: TTS, conn_options: APIConnectOptions) -> None:
        super().__init__(tts=tts, conn_options=conn_options)
        self._tts: TTS = tts

    async def _run(self, output_emitter: tts.AudioEmitter) -> None:
        """Stream sentence-tokenized text over a pooled websocket and emit PCM audio.

        Three cooperating tasks run concurrently:
          * _input_task — feeds raw agent text into the sentence tokenizer
          * _send_task  — forwards tokenized sentences to the Rime websocket
          * _recv_task  — decodes audio chunks / word timestamps until "done"
        """
        request_id = utils.shortuuid()
        context_id = utils.shortuuid()
        output_emitter.initialize(
            request_id=request_id,
            sample_rate=self._tts.sample_rate,
            num_channels=NUM_CHANNELS,
            mime_type="audio/pcm",
            stream=True,
        )
        output_emitter.start_segment(segment_id=context_id)

        sent_stream = self._tts._sentence_tokenizer.stream()
        # Set once the first sentence is sent (or when input turns out empty),
        # so the receiver doesn't wait on a server that was never asked anything.
        input_sent_event = asyncio.Event()
        empty_input = False

        async def _input_task() -> None:
            async for data in self._input_ch:
                if isinstance(data, self._FlushSentinel):
                    sent_stream.flush()
                    continue
                sent_stream.push_text(data)
            sent_stream.end_input()

        async def _send_task(ws: aiohttp.ClientWebSocketResponse) -> None:
            nonlocal empty_input
            sent_count = 0
            async for ev in sent_stream:
                pkt = {"text": ev.token + " ", "contextId": context_id}
                self._mark_started()
                await ws.send_str(json.dumps(pkt))
                input_sent_event.set()
                sent_count += 1
            if sent_count == 0:
                # Nothing to synthesize: unblock the receiver and close the segment
                # without sending a flush for an empty context.
                empty_input = True
                input_sent_event.set()
                output_emitter.end_input()
                return
            await ws.send_str(json.dumps({"operation": "flush", "contextId": context_id}))

        async def _recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
            await input_sent_event.wait()
            if empty_input:
                return
            while True:
                msg = await ws.receive(timeout=self._conn_options.timeout)
                if msg.type in (
                    aiohttp.WSMsgType.CLOSE,
                    aiohttp.WSMsgType.CLOSED,
                    aiohttp.WSMsgType.CLOSING,
                    # Transport-level failure: treat like an unexpected close
                    # instead of looping on the "unexpected type" branch below.
                    aiohttp.WSMsgType.ERROR,
                ):
                    raise APIStatusError(
                        "Rime ws closed unexpectedly",
                        request_id=request_id,
                    )
                if msg.type != aiohttp.WSMsgType.TEXT:
                    logger.warning("unexpected Rime ws message type %s", msg.type)
                    continue
                data = json.loads(msg.data)
                t = data.get("type")
                if t == "chunk":
                    # Audio arrives base64-encoded PCM.
                    output_emitter.push(base64.b64decode(data["data"]))
                elif t == "timestamps":
                    wt = data.get("word_timestamps") or {}
                    words = wt.get("words") or []
                    starts = wt.get("start") or []
                    ends = wt.get("end") or []
                    for w, s, e in zip(words, starts, ends, strict=False):
                        output_emitter.push_timed_transcript(
                            TimedString(text=w + " ", start_time=s, end_time=e)
                        )
                elif t == "done":
                    output_emitter.end_input()
                    break
                elif t == "error":
                    msg_text = data.get("message", "(no message)")
                    raise APIError(f"Rime ws error: {msg_text}")

        try:
            async with self._tts._pool.connection(timeout=self._conn_options.timeout) as ws:
                tasks = [
                    asyncio.create_task(_input_task()),
                    asyncio.create_task(_send_task(ws)),
                    asyncio.create_task(_recv_task(ws)),
                ]
                try:
                    await asyncio.gather(*tasks)
                finally:
                    # Unblock a receiver still waiting on the first send, then
                    # tear everything down regardless of which task failed.
                    input_sent_event.set()
                    await sent_stream.aclose()
                    await utils.aio.gracefully_cancel(*tasks)
        except asyncio.TimeoutError:
            raise APITimeoutError() from None
        except aiohttp.ClientResponseError as e:
            raise APIStatusError(
                message=e.message, status_code=e.status, request_id=None, body=None
            ) from None
        except APIError:
            raise
        except Exception as e:
            raise APIConnectionError(f"Rime WS error: {e}") from e
7 changes: 7 additions & 0 deletions tests/test_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,13 @@ async def test_tts_synthesize_error_propagation():
},
id="google",
),
pytest.param(
lambda: {
"tts": rime.TTS(use_websocket=True),
"proxy-upstream": "users-ws.rime.ai:443",
},
id="rime",
),
pytest.param(
lambda: {
"tts": tts.StreamAdapter(tts=inworld.TTS()),
Expand Down
Loading