Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions livekit-agents/livekit/agents/telemetry/trace_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@
ATTR_TRANSCRIPTION_DELAY = "lk.transcription_delay"
ATTR_END_OF_TURN_DELAY = "lk.end_of_turn_delay"

# websocket connection
ATTR_WS_CONNECTION_TIME = "lk.ws.connection_time"

# metrics
ATTR_LLM_METRICS = "lk.llm_metrics"
ATTR_TTS_METRICS = "lk.tts_metrics"
Expand Down
3 changes: 2 additions & 1 deletion livekit-agents/livekit/agents/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from . import aio, audio, codecs, http_context, http_server, hw, images
from .audio import AudioBuffer, combine_frames, merge_frames
from .bounded_dict import BoundedDict
from .connection_pool import ConnectionPool
from .connection_pool import ConnectionPool, ConnectionResult
from .exp_filter import ExpFilter
from .log import log_exceptions
from .misc import is_given, nodename, shortuuid, time_ms
Expand Down Expand Up @@ -33,6 +33,7 @@
"hw",
"is_given",
"ConnectionPool",
"ConnectionResult",
"wait_for_agent",
"wait_for_participant",
"wait_for_track_publication",
Expand Down
57 changes: 55 additions & 2 deletions livekit-agents/livekit/agents/utils/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import weakref
from collections.abc import AsyncGenerator, Awaitable, Callable
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Generic, TypeVar

from ..log import logger
Expand All @@ -11,6 +12,19 @@
T = TypeVar("T")


@dataclass
class ConnectionResult(Generic[T]):
    """Result of getting a connection from the pool, including timing metadata.

    Instances are built by ``ConnectionPool.get_with_timing`` and yielded by
    ``ConnectionPool.connection_with_timing``.
    """

    # The connection object handed out by the pool.
    connection: T
    # Elapsed seconds (difference of two ``time.perf_counter()`` readings) spent
    # obtaining the connection. For a reused connection this only covers the
    # lookup of an available pooled connection, not a network handshake.
    connect_time: float
    # True when an existing pooled connection was returned, False when a new
    # connection had to be established.
    from_pool: bool

    @property
    def status(self) -> str:
        """Return ``"reused"`` for a pooled connection and ``"new"`` otherwise."""
        return "reused" if self.from_pool else "new"


class ConnectionPool(Generic[T]):
"""Helper class to manage persistent connections like websockets.

Expand Down Expand Up @@ -89,14 +103,42 @@ async def connection(self, *, timeout: float) -> AsyncGenerator[T, None]:
else:
self.put(conn)

@asynccontextmanager
async def connection_with_timing(
    self, *, timeout: float
) -> AsyncGenerator[ConnectionResult[T], None]:
    """Get a connection from the pool with timing metadata.

    On a clean exit from the ``async with`` body the connection is handed
    back to the pool via ``put``. If the body raises, the connection is
    dropped via ``remove`` and the exception is re-raised.

    Args:
        timeout: Forwarded unchanged to ``get_with_timing``.

    Yields:
        A ConnectionResult containing the connection and timing information
    """
    result = await self.get_with_timing(timeout=timeout)
    try:
        yield result
    except BaseException:
        # BaseException (not just Exception) so that cancellation of the
        # surrounding task also discards the connection instead of
        # returning it to the pool.
        self.remove(result.connection)
        raise
    else:
        self.put(result.connection)

async def get(self, *, timeout: float) -> T:
    """Get an available connection or create a new one if needed.

    Delegates to ``get_with_timing`` and discards the timing metadata.

    Args:
        timeout: Forwarded unchanged to ``get_with_timing``.

    Returns:
        An active connection object
    """
    result = await self.get_with_timing(timeout=timeout)
    return result.connection

async def get_with_timing(self, *, timeout: float) -> ConnectionResult[T]:
"""Get an available connection or create a new one if needed, with timing metadata.

Returns:
A ConnectionResult containing the connection and timing information
"""
async with self._connect_lock:
await self._drain_to_close()
start_time = time.perf_counter()
now = time.time()

# try to reuse an available connection that hasn't expired
Expand All @@ -108,11 +150,22 @@ async def get(self, *, timeout: float) -> T:
):
if self._mark_refreshed_on_get:
self._connections[conn] = now
return conn
connect_time = time.perf_counter() - start_time
return ConnectionResult(
connection=conn,
connect_time=connect_time,
from_pool=True,
)
# connection expired; mark it for resetting.
self.remove(conn)

return await self._connect(timeout)
conn = await self._connect(timeout)
connect_time = time.perf_counter() - start_time
return ConnectionResult(
connection=conn,
connect_time=connect_time,
from_pool=False,
)

def put(self, conn: T) -> None:
"""Mark a connection as available for reuse.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
tts,
utils,
)
from livekit.agents.telemetry import trace_types
from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, NOT_GIVEN, NotGivenOr
from livekit.agents.utils import is_given
from livekit.agents.voice.io import TimedString
Expand Down Expand Up @@ -514,7 +515,22 @@ async def _recv_task(ws: aiohttp.ClientWebSocketResponse, cartesia_context_id: s

cartesia_context_id = utils.shortuuid()
try:
async with self._tts._pool.connection(timeout=self._conn_options.timeout) as ws:
async with self._tts._pool.connection_with_timing(
timeout=self._conn_options.timeout
) as conn_result:
ws = conn_result.connection
from opentelemetry import trace
trace.get_current_span().set_attribute(
trace_types.ATTR_WS_CONNECTION_TIME, conn_result.connect_time
)
logger.debug(
"Cartesia TTS WebSocket connected (%s)",
conn_result.status,
extra={
"connection_time": conn_result.connect_time,
"cartesia_context_id": cartesia_context_id,
},
)
tasks = [
asyncio.create_task(_input_task()),
asyncio.create_task(_sentence_stream_task(ws, cartesia_context_id)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import dataclasses
import json
import os
import time
import weakref
from collections import Counter
from collections.abc import Sequence
Expand All @@ -37,6 +38,7 @@
stt,
utils,
)
from livekit.agents.telemetry import trace_types
from livekit.agents.types import (
NOT_GIVEN,
NotGivenOr,
Expand Down Expand Up @@ -623,6 +625,7 @@ async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
if self._opts.tags:
live_config["tag"] = self._opts.tags

start_time = time.perf_counter()
try:
ws = await asyncio.wait_for(
self._session.ws_connect(
Expand All @@ -631,12 +634,17 @@ async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
),
self._conn_options.timeout,
)
ws_connection_time = time.perf_counter() - start_time
ws_headers = {
k: v for k, v in ws._response.headers.items() if k.startswith("dg-") or k == "Date"
}
from opentelemetry import trace
trace.get_current_span().set_attribute(
trace_types.ATTR_WS_CONNECTION_TIME, ws_connection_time
)
logger.debug(
"Established new Deepgram STT WebSocket connection:",
extra={"headers": ws_headers},
"Deepgram STT WebSocket connected (new)",
extra={"headers": ws_headers, "connection_time": ws_connection_time},
)
except (aiohttp.ClientConnectorError, asyncio.TimeoutError) as e:
raise APIConnectionError("failed to connect to deepgram") from e
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
tts,
utils,
)
from livekit.agents.telemetry import trace_types
from livekit.agents.types import (
DEFAULT_API_CONNECT_OPTIONS,
NOT_GIVEN,
Expand Down Expand Up @@ -349,7 +350,19 @@ async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
else:
logger.debug("Unknown message type: %s", resp)

async with self._tts._pool.connection(timeout=self._conn_options.timeout) as ws:
async with self._tts._pool.connection_with_timing(
timeout=self._conn_options.timeout
) as conn_result:
ws = conn_result.connection
from opentelemetry import trace
trace.get_current_span().set_attribute(
trace_types.ATTR_WS_CONNECTION_TIME, conn_result.connect_time
)
logger.debug(
"Deepgram TTS WebSocket connected (%s)",
conn_result.status,
extra={"connection_time": conn_result.connect_time, "segment_id": segment_id},
)
tasks = [
asyncio.create_task(send_task(ws)),
asyncio.create_task(recv_task(ws)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import dataclasses
import json
import os
import time
import weakref
from dataclasses import dataclass, replace
from functools import cached_property
Expand All @@ -37,6 +38,7 @@
tts,
utils,
)
from livekit.agents.telemetry import trace_types
from livekit.agents.tokenize.basic import split_words
from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, NOT_GIVEN, NotGivenOr
from livekit.agents.utils import is_given
Expand Down Expand Up @@ -270,21 +272,26 @@ def update_options(
self._current_connection.mark_non_current()
self._current_connection = None

async def current_connection(self) -> _Connection:
"""Get the current connection, creating one if needed"""
async def current_connection(self) -> tuple[_Connection, bool]:
"""Get the current connection, creating one if needed.

Returns:
A tuple of (_Connection, is_reused) where is_reused is True if the
connection was reused from the pool.
"""
async with self._connection_lock:
if (
self._current_connection
and self._current_connection.is_current
and not self._current_connection._closed
):
return self._current_connection
return self._current_connection, True # reused connection

session = self._ensure_session()
conn = _Connection(self._opts, session)
await conn.connect()
self._current_connection = conn
return conn
return conn, False # new connection

def synthesize(
self, text: str, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
Expand Down Expand Up @@ -401,9 +408,22 @@ async def _run(self, output_emitter: tts.AudioEmitter) -> None:

connection: _Connection
try:
connection = await asyncio.wait_for(
start_time = time.perf_counter()
connection, is_reused = await asyncio.wait_for(
self._tts.current_connection(), self._conn_options.timeout
)
total_time = time.perf_counter() - start_time
ws_connection_time = connection._connect_time or total_time
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 `or` operator used instead of `is not None` for the `_connect_time` fallback; incorrect when the value is 0.0

On line 416, `connection._connect_time or total_time` uses Python's truthiness to decide between the original WS connect time and the fallback. Since `_connect_time` is `float | None`, a value of exactly `0.0` is falsy, so the `or` would incorrectly fall through to `total_time`. The correct pattern is `connection._connect_time if connection._connect_time is not None else total_time`. While 0.0 is practically impossible for a real WS handshake, truthiness checks on numeric values are a known anti-pattern.

Suggested change
ws_connection_time = connection._connect_time or total_time
ws_connection_time = connection._connect_time if connection._connect_time is not None else total_time
Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.

status = "reused" if is_reused else "new"
from opentelemetry import trace
trace.get_current_span().set_attribute(
trace_types.ATTR_WS_CONNECTION_TIME, ws_connection_time
)
logger.debug(
"ElevenLabs TTS WebSocket connected (%s)",
status,
extra={"connection_time": ws_connection_time, "context_id": self._context_id},
)
except asyncio.TimeoutError as e:
raise APITimeoutError() from e
except Exception as e:
Expand Down Expand Up @@ -539,6 +559,9 @@ def __init__(self, opts: _TTSOptions, session: aiohttp.ClientSession):
self._recv_task: asyncio.Task | None = None
self._closed = False

# WebSocket connection timing
self._connect_time: float | None = None

@property
def voice_id(self) -> str:
return self._opts.voice_id
Expand Down Expand Up @@ -573,7 +596,13 @@ async def connect(self) -> None:

url = _multi_stream_url(self._opts)
headers = {AUTHORIZATION_HEADER: self._opts.api_key}
start_time = time.perf_counter()
self._ws = await self._session.ws_connect(url, headers=headers)
self._connect_time = time.perf_counter() - start_time
logger.debug(
"established ElevenLabs TTS WebSocket connection",
extra={"connection_time": self._connect_time},
)

self._send_task = asyncio.create_task(self._send_loop())
self._recv_task = asyncio.create_task(self._recv_loop())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from livekit.agents import APIConnectionError, LanguageCode, llm, utils
from livekit.agents.metrics import RealtimeModelMetrics
from livekit.agents.metrics.base import Metadata
from livekit.agents.telemetry import trace_types
from livekit.agents.types import (
DEFAULT_API_CONNECT_OPTIONS,
NOT_GIVEN,
Expand Down Expand Up @@ -789,9 +790,19 @@ async def _main_task(self) -> None:
session = None
try:
logger.debug("connecting to Gemini Realtime API...")
connect_start_time = time.perf_counter()
async with self._client.aio.live.connect(
model=self._opts.model, config=config
) as session:
ws_connection_time = time.perf_counter() - connect_start_time
from opentelemetry import trace
trace.get_current_span().set_attribute(
trace_types.ATTR_WS_CONNECTION_TIME, ws_connection_time
)
logger.debug(
"Gemini Realtime API WebSocket connected (new)",
extra={"connection_time": ws_connection_time},
)
async with self._session_lock:
self._active_session = session

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
stt,
utils,
)
from livekit.agents.telemetry import trace_types
from livekit.agents.types import (
NOT_GIVEN,
NotGivenOr,
Expand Down Expand Up @@ -822,7 +823,19 @@ async def process_stream(
while True:
audio_pushed = False
try:
async with self._pool.connection(timeout=self._conn_options.timeout) as client:
async with self._pool.connection_with_timing(
timeout=self._conn_options.timeout
) as conn_result:
client = conn_result.connection
from opentelemetry import trace
trace.get_current_span().set_attribute(
trace_types.ATTR_WS_CONNECTION_TIME, conn_result.connect_time
)
logger.debug(
"Google STT connected (%s)",
conn_result.status,
extra={"connection_time": conn_result.connect_time},
)
self._streaming_config = self._build_streaming_config()

should_stop = asyncio.Event()
Expand Down
Loading
Loading