diff --git a/.gitignore b/.gitignore index 4420978f..e9ef7669 100644 --- a/.gitignore +++ b/.gitignore @@ -84,6 +84,7 @@ creds/ # Claude Code .claude/ claude_only_docs/ +CLAUDE.md # Local folders local/ diff --git a/configs/prompts/simulation.yaml b/configs/prompts/simulation.yaml index 9dc36cb6..6914346a 100644 --- a/configs/prompts/simulation.yaml +++ b/configs/prompts/simulation.yaml @@ -125,7 +125,7 @@ audio_llm_agent: realtime_agent: system_prompt: | - You are a friendly voice assistant. + You are an AI voice assistant on a live phone call. Call the appropriate function to process the user's input. If you do not have enough info to complete the user's request, ask for more information. Call the tool as many times as you need until the user's task is complete. Call the tool as quickly as possible. diff --git a/docs/metric_context.md b/docs/metric_context.md index dbcabb06..62778402 100644 --- a/docs/metric_context.md +++ b/docs/metric_context.md @@ -282,9 +282,9 @@ Minor discrepancies are still possible (turn IDs off by one, audio timestamps no Benchmark Execution: ├─ EvaluationRecord (dataset.jsonl) │ ├─ user_goal, user_persona, scenario_db → MetricContext - │ └─ Feeds to AssistantServer + UserSimulator + │ └─ Feeds to PipecatAssistantServer + UserSimulator │ - ├─ AssistantServer writes: + ├─ PipecatAssistantServer writes: │ ├─ audit_log.json (tool calls, user/assistant turns) │ ├─ pipecat_events.jsonl (TTS text, turn boundaries) │ ├─ response_latencies.json (response speed data) diff --git a/src/eva/assistant/audio_bridge.py b/src/eva/assistant/audio_bridge.py new file mode 100644 index 00000000..8051e5aa --- /dev/null +++ b/src/eva/assistant/audio_bridge.py @@ -0,0 +1,267 @@ +"""Shared audio bridge utilities for framework-specific assistant servers. + +All framework servers need to: +1. Accept Twilio-framed WebSocket connections from the user simulator +2. Convert audio between Twilio's mulaw 8kHz and the framework's native format +3. 
Write framework_logs.jsonl with timestamped events + +This module provides the common infrastructure. +""" + +import audioop +import base64 +import json +import struct +import time +from pathlib import Path +from typing import Optional + +import numpy as np +import soxr + +from eva.utils.logging import get_logger + +logger = get_logger(__name__) + + +# ── Audio format conversion ────────────────────────────────────────── + + +def mulaw_8k_to_pcm16_16k(mulaw_bytes: bytes) -> bytes: + """Convert 8kHz mu-law audio to 16kHz 16-bit PCM.""" + # Decode mu-law to 16-bit PCM at 8kHz + pcm_8k = audioop.ulaw2lin(mulaw_bytes, 2) + # Upsample from 8kHz to 16kHz + pcm_16k, _ = audioop.ratecv(pcm_8k, 2, 1, 8000, 16000, None) + return pcm_16k + + +def mulaw_8k_to_pcm16_24k(mulaw_bytes: bytes) -> bytes: + """Convert 8kHz mu-law audio to 24kHz 16-bit PCM.""" + # Decode mu-law to 16-bit PCM at 8kHz + pcm_8k = audioop.ulaw2lin(mulaw_bytes, 2) + # Upsample from 8kHz to 24kHz + pcm_24k, _ = audioop.ratecv(pcm_8k, 2, 1, 8000, 24000, None) + return pcm_24k + + +def pcm16_16k_to_mulaw_8k(pcm_bytes: bytes) -> bytes: + """Convert 16kHz 16-bit PCM to 8kHz mu-law.""" + # Downsample from 16kHz to 8kHz + pcm_8k, _ = audioop.ratecv(pcm_bytes, 2, 1, 16000, 8000, None) + # Encode to mu-law + return audioop.lin2ulaw(pcm_8k, 2) + + +def pcm16_24k_to_mulaw_8k(pcm_bytes: bytes) -> bytes: + """Convert 24kHz 16-bit PCM to 8kHz mu-law. + + Uses soxr VHQ resampling (same as Pipecat) for proper anti-aliasing during the 3:1 downsampling. + audioop.ratecv produces muffled audio because it lacks an anti-aliasing filter. 
+ """ + # Downsample from 24kHz to 8kHz using high-quality resampler + audio_data = np.frombuffer(pcm_bytes, dtype=np.int16) + resampled = soxr.resample(audio_data, 24000, 8000, quality="VHQ") + pcm_8k = resampled.astype(np.int16).tobytes() + # Encode to mu-law + return audioop.lin2ulaw(pcm_8k, 2) + + +def sync_buffer_to_position(buffer: bytearray, target_position: int) -> None: + """Pad *buffer* with silence bytes so it reaches *target_position*. + + Mirrors pipecat's ``AudioBufferProcessor._sync_buffer_to_position``. + Call this **before** extending the *other* track so both tracks stay + positionally aligned. + """ + current_len = len(buffer) + if current_len < target_position: + buffer.extend(b"\x00" * (target_position - current_len)) + + +def pcm16_mix(track_a: bytes, track_b: bytes) -> bytes: + """Mix two 16-bit PCM tracks by sample-wise addition with clipping. + + Both tracks must be the same sample rate. If lengths differ, + the shorter track is zero-padded. + """ + len_a, len_b = len(track_a), len(track_b) + max_len = max(len_a, len_b) + + # Zero-pad shorter track + if len_a < max_len: + track_a = track_a + b"\x00" * (max_len - len_a) + if len_b < max_len: + track_b = track_b + b"\x00" * (max_len - len_b) + + # Mix with clipping + n_samples = max_len // 2 + fmt = f"<{n_samples}h" + samples_a = struct.unpack(fmt, track_a) + samples_b = struct.unpack(fmt, track_b) + mixed = struct.pack(fmt, *(max(-32768, min(32767, a + b)) for a, b in zip(samples_a, samples_b))) + return mixed + + +# ── Twilio WebSocket Protocol ──────────────────────────────────────── + + +def parse_twilio_media_message(message: str) -> Optional[bytes]: + """Parse a Twilio media WebSocket message and extract raw audio bytes. + + Returns None if the message is not a media message. 
+ """ + try: + data = json.loads(message) + if data.get("event") == "media": + payload = data["media"]["payload"] + return base64.b64decode(payload) + except (json.JSONDecodeError, KeyError): + pass + return None + + +def create_twilio_media_message(stream_sid: str, audio_bytes: bytes) -> str: + """Create a Twilio media WebSocket message with the given audio bytes.""" + payload = base64.b64encode(audio_bytes).decode("ascii") + return json.dumps( + { + "event": "media", + "streamSid": stream_sid, + "media": { + "payload": payload, + }, + } + ) + + +def create_twilio_start_response(stream_sid: str) -> str: + """Create a Twilio 'start' event response.""" + return json.dumps( + { + "event": "start", + "streamSid": stream_sid, + "start": { + "streamSid": stream_sid, + "mediaFormat": { + "encoding": "audio/x-mulaw", + "sampleRate": 8000, + "channels": 1, + }, + }, + } + ) + + +# ── Framework Logs Writer ──────────────────────────────────────────── + + +class FrameworkLogWriter: + """Writes framework_logs.jsonl (replacement for pipecat_logs.jsonl). + + Captures turn boundaries, TTS text, and LLM responses with accurate + wall-clock timestamps. + """ + + def __init__(self, output_dir: Path): + self.log_file = output_dir / "framework_logs.jsonl" + output_dir.mkdir(parents=True, exist_ok=True) + + def write(self, event_type: str, data: dict, timestamp_ms: Optional[int] = None) -> None: + """Write a single log entry. + + Args: + event_type: One of 'turn_start', 'turn_end', 'tts_text', 'llm_response' + data: Event data dict. Must contain a 'frame' key for tts_text/llm_response. + timestamp_ms: Wall-clock timestamp in milliseconds. Defaults to now. 
+ """ + if timestamp_ms is None: + timestamp_ms = int(time.time() * 1000) + + entry = { + "timestamp": timestamp_ms, + "type": event_type, + "data": data, + } + try: + with open(self.log_file, "a", encoding="utf-8") as f: + f.write(json.dumps(entry, ensure_ascii=False) + "\n") + except Exception as e: + logger.error(f"Error writing framework log: {e}") + + def turn_start(self, timestamp_ms: Optional[int] = None) -> None: + """Log a turn start event.""" + self.write("turn_start", {"frame": "turn_start"}, timestamp_ms) + + def turn_end(self, was_interrupted: bool = False, timestamp_ms: Optional[int] = None) -> None: + """Log a turn end event.""" + self.write("turn_end", {"frame": "turn_end", "was_interrupted": was_interrupted}, timestamp_ms) + + def tts_text(self, text: str, timestamp_ms: Optional[int] = None) -> None: + """Log TTS text (what was actually spoken).""" + self.write("tts_text", {"frame": text}, timestamp_ms) + + def llm_response(self, text: str, timestamp_ms: Optional[int] = None) -> None: + """Log LLM response text (full intended response).""" + self.write("llm_response", {"frame": text}, timestamp_ms) + + +# ── Metrics Log Writer ─────────────────────────────────────────────── + + +class MetricsLogWriter: + """Writes pipecat_metrics.jsonl equivalent for non-pipecat frameworks.""" + + def __init__(self, output_dir: Path): + self.log_file = output_dir / "pipecat_metrics.jsonl" + output_dir.mkdir(parents=True, exist_ok=True) + + def write_processing_metric(self, processor: str, value_seconds: float, model: str = "") -> None: + """Write a ProcessingMetricsData entry (e.g., for STT latency).""" + entry = { + "timestamp": int(time.time() * 1000), + "type": "ProcessingMetricsData", + "processor": processor, + "model": model, + "value": value_seconds, + } + self._append(entry) + + def write_ttfb_metric(self, processor: str, value_seconds: float, model: str = "") -> None: + """Write a TTFBMetricsData entry (e.g., for TTS time-to-first-byte).""" + entry = { + 
"timestamp": int(time.time() * 1000), + "type": "TTFBMetricsData", + "processor": processor, + "model": model, + "value": value_seconds, + } + self._append(entry) + + def write_token_usage( + self, + processor: str, + model: str, + prompt_tokens: int, + completion_tokens: int, + ) -> None: + """Write an LLMTokenUsageMetricsData entry.""" + entry = { + "timestamp": int(time.time() * 1000), + "type": "LLMTokenUsageMetricsData", + "processor": processor, + "model": model, + "value": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + }, + } + self._append(entry) + + def _append(self, entry: dict) -> None: + try: + with open(self.log_file, "a", encoding="utf-8") as f: + f.write(json.dumps(entry) + "\n") + except Exception as e: + logger.error(f"Error writing metrics log: {e}") diff --git a/src/eva/assistant/base_server.py b/src/eva/assistant/base_server.py new file mode 100644 index 00000000..641d708b --- /dev/null +++ b/src/eva/assistant/base_server.py @@ -0,0 +1,232 @@ +"""Abstract base class for assistant server implementations. + +All framework-specific assistant servers (Pipecat, OpenAI Realtime, Gemini Live, etc.) +must inherit from AbstractAssistantServer and implement the required interface. + +See docs/assistant_server_contract.md for the full specification. +""" + +import json +import wave +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any + +from eva.assistant.agentic.audit_log import AuditLog +from eva.assistant.tools.tool_executor import ToolExecutor +from eva.models.agents import AgentConfig +from eva.models.config import AudioLLMConfig, PipelineConfig, SpeechToSpeechConfig +from eva.utils.logging import get_logger + +logger = get_logger(__name__) + +INITIAL_MESSAGE = "Hello! How can I help you today?" + + +class AbstractAssistantServer(ABC): + """Base class for all assistant server implementations. + + Each implementation must: + 1. 
Expose a WebSocket endpoint at ws://localhost:{port}/ws with Twilio frame format + 2. Bridge audio between the user simulator and the framework's native format + 3. Execute tool calls via the local ToolExecutor + 4. Produce all required output files (audit_log.json, framework_logs.jsonl, audio, etc.) + 5. Populate the AuditLog with conversation events + """ + + def __init__( + self, + current_date_time: str, + pipeline_config: PipelineConfig | SpeechToSpeechConfig | AudioLLMConfig, + agent: AgentConfig, + agent_config_path: str, + scenario_db_path: str, + output_dir: Path, + port: int, + conversation_id: str, + ): + """Initialize the assistant server. + + Args: + current_date_time: Current date/time string from the evaluation record + pipeline_config: Configuration for the model/pipeline + agent: Single agent configuration to use + agent_config_path: Path to agent YAML configuration + scenario_db_path: Path to scenario database JSON + output_dir: Directory for output files + port: Port to listen on + conversation_id: Unique ID for this conversation + """ + self.current_date_time = current_date_time + self.pipeline_config = pipeline_config + self.agent: AgentConfig = agent + self.agent_config_path = agent_config_path + self.scenario_db_path = scenario_db_path + self.output_dir = Path(output_dir) + self.port = port + self.conversation_id = conversation_id + + # Core components - all implementations must use these + self.audit_log = AuditLog() + self.tool_handler = ToolExecutor( + tool_config_path=agent_config_path, + scenario_db_path=scenario_db_path, + tool_module_path=self.agent.tool_module_path, + current_date_time=self.current_date_time, + ) + + # Audio buffers for recording + self._audio_buffer = bytearray() + self.user_audio_buffer = bytearray() + self.assistant_audio_buffer = bytearray() + self._audio_sample_rate: int = 24000 # Subclasses can override + + # Latency tracking + self._latency_measurements: list[float] = [] + + @abstractmethod + async def 
start(self) -> None: + """Start the server. + + Must be non-blocking (return after the server is ready to accept connections). + Must expose a WebSocket endpoint at ws://localhost:{port}/ws using FastAPI+uvicorn + with TwilioFrameSerializer for compatibility with the user simulator. + + The implementation must: + 1. Create a FastAPI app with /ws and / WebSocket endpoints + 2. Start a uvicorn server on the configured port + 3. Return once the server is accepting connections + """ + ... + + @abstractmethod + async def stop(self) -> None: + """Stop the server and save all outputs. + + Must: + 1. Gracefully shut down the server + 2. Call save_outputs() to persist all data + """ + ... + + def get_conversation_stats(self) -> dict[str, Any]: + """Get statistics about the conversation. + + Returns dict with: num_turns, num_tool_calls, tools_called, etc. + """ + return self.audit_log.get_stats() + + def get_initial_scenario_db(self) -> dict[str, Any]: + """Get initial (pristine) scenario database state.""" + return self.tool_handler.original_db + + def get_final_scenario_db(self) -> dict[str, Any]: + """Get final (mutated) scenario database state.""" + return self.tool_handler.db + + # ── Shared output helpers ────────────────────────────────────────── + + async def save_outputs(self) -> None: + """Save all required output files. Called by stop(). + + Subclasses can override to add framework-specific outputs, + but must call super().save_outputs(). 
+ """ + # Save audit log + self.audit_log.save(self.output_dir / "audit_log.json") + + # Save simplified transcript + transcript_path = self.output_dir / "transcript.jsonl" + self.audit_log.save_transcript_jsonl(transcript_path) + + # Save audio recordings + self._save_audio() + + # Save scenario database states (REQUIRED for deterministic metrics) + self._save_scenario_dbs() + + # Save response latencies + self._save_response_latencies() + + logger.info(f"Outputs saved to {self.output_dir}") + + def _save_audio(self) -> None: + """Save accumulated audio buffers to WAV files. + + If _audio_buffer (mixed) is empty but user and assistant buffers are + available, compute mixed audio automatically via sample-wise addition. + """ + # Auto-compute mixed audio from user + assistant tracks when not populated + if not self._audio_buffer and self.user_audio_buffer and self.assistant_audio_buffer: + from eva.assistant.audio_bridge import pcm16_mix + + self._audio_buffer = bytearray(pcm16_mix(bytes(self.user_audio_buffer), bytes(self.assistant_audio_buffer))) + elif not self._audio_buffer and self.user_audio_buffer: + self._audio_buffer = bytearray(self.user_audio_buffer) + elif not self._audio_buffer and self.assistant_audio_buffer: + self._audio_buffer = bytearray(self.assistant_audio_buffer) + + if self._audio_buffer: + self._save_wav_file( + bytes(self._audio_buffer), + self.output_dir / "audio_mixed.wav", + self._audio_sample_rate, + 1, + ) + if self.user_audio_buffer: + self._save_wav_file( + bytes(self.user_audio_buffer), + self.output_dir / "audio_user.wav", + self._audio_sample_rate, + 1, + ) + if self.assistant_audio_buffer: + self._save_wav_file( + bytes(self.assistant_audio_buffer), + self.output_dir / "audio_assistant.wav", + self._audio_sample_rate, + 1, + ) + + def _save_wav_file(self, audio_data: bytes, file_path: Path, sample_rate: int, num_channels: int) -> None: + """Save raw 16-bit PCM audio data to a WAV file.""" + try: + with wave.open(str(file_path), 
"wb") as wav_file: + wav_file.setnchannels(num_channels) + wav_file.setsampwidth(2) # 16-bit PCM + wav_file.setframerate(sample_rate) + wav_file.writeframes(audio_data) + logger.debug(f"Audio saved to {file_path} ({len(audio_data)} bytes)") + except Exception as e: + logger.error(f"Error saving audio to {file_path}: {e}") + + def _save_scenario_dbs(self) -> None: + """Save initial and final scenario database states.""" + try: + initial_db_path = self.output_dir / "initial_scenario_db.json" + with open(initial_db_path, "w") as f: + json.dump(self.get_initial_scenario_db(), f, indent=2, sort_keys=True, default=str) + + final_db_path = self.output_dir / "final_scenario_db.json" + with open(final_db_path, "w") as f: + json.dump(self.get_final_scenario_db(), f, indent=2, sort_keys=True, default=str) + + logger.info(f"Saved scenario database states to {self.output_dir}") + except Exception as e: + logger.error(f"Error saving scenario database states: {e}", exc_info=True) + raise + + def _save_response_latencies(self) -> None: + """Save response latency measurements.""" + if not self._latency_measurements: + return + + latency_data = { + "latencies": self._latency_measurements, + "mean": sum(self._latency_measurements) / len(self._latency_measurements), + "max": max(self._latency_measurements), + "count": len(self._latency_measurements), + } + latency_path = self.output_dir / "response_latencies.json" + with open(latency_path, "w") as f: + json.dump(latency_data, f, indent=2) diff --git a/src/eva/assistant/gemini_live_server.py b/src/eva/assistant/gemini_live_server.py new file mode 100644 index 00000000..be16174c --- /dev/null +++ b/src/eva/assistant/gemini_live_server.py @@ -0,0 +1,630 @@ +"""Gemini Live AssistantServer for EVA-Bench. + +Bridges between Twilio-framed WebSocket (user simulator) and Google's Gemini Live +API via the google-genai Python SDK. 
Audio flows: + + User simulator (8 kHz mulaw) + -> 16 kHz PCM16 -> Gemini Live input + Gemini Live output (24 kHz PCM16) + -> 8 kHz mulaw -> User simulator + +All tool calls are executed locally via ToolExecutor; transcription events +from Gemini populate the audit log. +""" + +from __future__ import annotations + +import asyncio +import audioop +import json +import os +import time +from pathlib import Path +from typing import Any, Optional + +import uvicorn +from fastapi import FastAPI, WebSocket, WebSocketDisconnect + +from google import genai +from google.genai import types + +from eva.assistant.agentic.audit_log import current_timestamp_ms +from eva.assistant.audio_bridge import ( + FrameworkLogWriter, + MetricsLogWriter, + create_twilio_media_message, + mulaw_8k_to_pcm16_16k, + mulaw_8k_to_pcm16_24k, + parse_twilio_media_message, + pcm16_24k_to_mulaw_8k, + sync_buffer_to_position, +) +from eva.assistant.base_server import INITIAL_MESSAGE, AbstractAssistantServer +from eva.models.agents import AgentConfig, AgentTool +from eva.models.config import PipelineConfig, SpeechToSpeechConfig, AudioLLMConfig +from eva.utils.logging import get_logger +from eva.utils.prompt_manager import PromptManager + +logger = get_logger(__name__) + +# Default recording sample rate (Gemini outputs 24 kHz PCM) +_RECORDING_SAMPLE_RATE = 24000 + + +# --------------------------------------------------------------------------- +# Tool schema helpers +# --------------------------------------------------------------------------- + +def _json_schema_type(python_type: str) -> str: + """Map Python/EVA type names to JSON Schema / Gemini type strings.""" + mapping = { + "string": "STRING", + "str": "STRING", + "integer": "INTEGER", + "int": "INTEGER", + "number": "NUMBER", + "float": "NUMBER", + "boolean": "BOOLEAN", + "bool": "BOOLEAN", + "array": "ARRAY", + "list": "ARRAY", + "object": "OBJECT", + "dict": "OBJECT", + } + return mapping.get(python_type.lower(), "STRING") + + +def 
_convert_schema_properties(props: dict[str, Any]) -> dict[str, types.Schema]: + """Recursively convert JSON Schema property dicts to Gemini Schema objects.""" + result: dict[str, types.Schema] = {} + for name, defn in props.items(): + if not isinstance(defn, dict): + result[name] = types.Schema(type="STRING") + continue + + schema_type = _json_schema_type(defn.get("type", "string")) + kwargs: dict[str, Any] = {"type": schema_type} + + if "description" in defn: + kwargs["description"] = defn["description"] + if "enum" in defn: + kwargs["enum"] = defn["enum"] + + # Nested object + if schema_type == "OBJECT" and "properties" in defn: + kwargs["properties"] = _convert_schema_properties(defn["properties"]) + + # Array items + if schema_type == "ARRAY" and "items" in defn: + items = defn["items"] + if isinstance(items, dict): + item_type = _json_schema_type(items.get("type", "string")) + item_kwargs: dict[str, Any] = {"type": item_type} + if "properties" in items: + item_kwargs["properties"] = _convert_schema_properties(items["properties"]) + kwargs["items"] = types.Schema(**item_kwargs) + else: + kwargs["items"] = types.Schema(type="STRING") + + result[name] = types.Schema(**kwargs) + return result + + +def _agent_tools_to_gemini(agent: AgentConfig) -> list[types.Tool] | None: + """Convert EVA AgentConfig tools to Gemini FunctionDeclaration list.""" + if not agent.tools: + return None + + declarations: list[types.FunctionDeclaration] = [] + for tool in agent.tools: + properties = _convert_schema_properties(tool.get_parameter_properties()) + required = tool.get_required_param_names() + + params_schema = types.Schema( + type="OBJECT", + properties=properties, + required=required if required else None, + ) + + declarations.append( + types.FunctionDeclaration( + name=tool.function_name, + description=f"{tool.name}: {tool.description}", + parameters=params_schema, + behavior=types.Behavior.BLOCKING, + ) + ) + + if not declarations: + return None + return 
[types.Tool(function_declarations=declarations)] + + +# --------------------------------------------------------------------------- +# Gemini Live AssistantServer +# --------------------------------------------------------------------------- + +class GeminiLiveAssistantServer(AbstractAssistantServer): + """Bridges Twilio WebSocket <-> Gemini Live API for EVA-Bench evaluation.""" + + def __init__( + self, + current_date_time: str, + pipeline_config: PipelineConfig | SpeechToSpeechConfig | AudioLLMConfig, + agent: AgentConfig, + agent_config_path: str, + scenario_db_path: str, + output_dir: Path, + port: int, + conversation_id: str, + ): + super().__init__( + current_date_time=current_date_time, + pipeline_config=pipeline_config, + agent=agent, + agent_config_path=agent_config_path, + scenario_db_path=scenario_db_path, + output_dir=output_dir, + port=port, + conversation_id=conversation_id, + ) + + # Recording sample rate (Gemini outputs 24 kHz) + self._audio_sample_rate = _RECORDING_SAMPLE_RATE + + # Server state + self._app: Optional[FastAPI] = None + self._server: Optional[uvicorn.Server] = None + self._server_task: Optional[asyncio.Task] = None + self._running = False + + # Gemini model name from s2s_params or default + s2s_params: dict[str, Any] = {} + if isinstance(self.pipeline_config, SpeechToSpeechConfig): + s2s_params = self.pipeline_config.s2s_params or {} + self._model = self.pipeline_config.s2s if isinstance(self.pipeline_config, SpeechToSpeechConfig) else s2s_params.get("model", "gemini-2.0-flash-live-001") + self._voice = s2s_params.get("voice", "Kore") + self._language_code = s2s_params.get("language_code", "en-US") + + # Build system prompt (same pattern as pipecat realtime) + prompt_manager = PromptManager() + self._system_prompt = prompt_manager.get_prompt( + "realtime_agent.system_prompt", + agent_personality=agent.description, + agent_instructions=agent.instructions, + datetime=self.current_date_time, + ) + + # Build Gemini tools + 
self._gemini_tools = _agent_tools_to_gemini(agent) + + # Framework log writers + self._fw_log: Optional[FrameworkLogWriter] = None + self._metrics_log: Optional[MetricsLogWriter] = None + + # ------------------------------------------------------------------ + # Server lifecycle + # ------------------------------------------------------------------ + + async def start(self) -> None: + """Start the FastAPI WebSocket server (non-blocking).""" + if self._running: + logger.warning("Server already running") + return + + self.output_dir.mkdir(parents=True, exist_ok=True) + self._fw_log = FrameworkLogWriter(self.output_dir) + self._metrics_log = MetricsLogWriter(self.output_dir) + + self._app = FastAPI() + + @self._app.websocket("/ws") + async def websocket_endpoint(websocket: WebSocket): + await websocket.accept() + await self._handle_session(websocket) + + @self._app.websocket("/") + async def websocket_root(websocket: WebSocket): + await websocket.accept() + await self._handle_session(websocket) + + config = uvicorn.Config( + self._app, + host="0.0.0.0", + port=self.port, + log_level="warning", + lifespan="off", + ) + self._server = uvicorn.Server(config) + self._running = True + self._server_task = asyncio.create_task(self._server.serve()) + + while not self._server.started: + await asyncio.sleep(0.01) + + logger.info(f"GeminiLive server started on ws://localhost:{self.port}") + + async def stop(self) -> None: + """Stop the server, save outputs.""" + if not self._running: + return + self._running = False + + if self._server: + self._server.should_exit = True + if self._server_task: + try: + await asyncio.wait_for(self._server_task, timeout=5.0) + except asyncio.TimeoutError: + self._server_task.cancel() + try: + await self._server_task + except asyncio.CancelledError: + pass + except (asyncio.CancelledError, KeyboardInterrupt): + pass + self._server = None + self._server_task = None + + await self.save_outputs() + logger.info(f"GeminiLive server stopped on port 
{self.port}") + + # ------------------------------------------------------------------ + # Gemini client factory + # ------------------------------------------------------------------ + + def _create_genai_client(self) -> genai.Client: + """Create a google-genai Client using Vertex AI or API key.""" + api_key = os.environ.get("GEMINI_API_KEY") + if api_key: + logger.info("Using Gemini API key for authentication") + return genai.Client(api_key=api_key) + + project = os.environ.get("GOOGLE_CLOUD_PROJECT") + location = os.environ.get("GOOGLE_CLOUD_LOCATION", "us-central1") + if project: + logger.info(f"Using Vertex AI (project={project}, location={location})") + return genai.Client(vertexai=True, project=project, location=location) + + # Fallback: let the SDK resolve credentials (e.g. ADC) + logger.info("No explicit credentials; relying on google-genai default resolution") + return genai.Client() + + # ------------------------------------------------------------------ + # Live session configuration + # ------------------------------------------------------------------ + + def _build_live_config(self) -> types.LiveConnectConfig: + """Build the LiveConnectConfig for the Gemini session.""" + config_kwargs: dict[str, Any] = { + "response_modalities": [types.Modality.AUDIO], + "system_instruction": self._system_prompt, + "speech_config": types.SpeechConfig( + voice_config=types.VoiceConfig( + prebuilt_voice_config=types.PrebuiltVoiceConfig( + voice_name=self._voice, + ) + ), + language_code=self._language_code, + ), + "realtime_input_config": types.RealtimeInputConfig( + automatic_activity_detection=types.AutomaticActivityDetection( + disabled=False, + start_of_speech_sensitivity=types.StartSensitivity.START_SENSITIVITY_LOW, + end_of_speech_sensitivity=types.EndSensitivity.END_SENSITIVITY_LOW, + silence_duration_ms=200, + ), + activity_handling=types.ActivityHandling.START_OF_ACTIVITY_INTERRUPTS, + ), + "input_audio_transcription": types.AudioTranscriptionConfig(), + 
"output_audio_transcription": types.AudioTranscriptionConfig(), + } + if self._gemini_tools: + config_kwargs["tools"] = self._gemini_tools + + return types.LiveConnectConfig(**config_kwargs) + + # ------------------------------------------------------------------ + # Session handler + # ------------------------------------------------------------------ + + async def _handle_session(self, websocket: WebSocket) -> None: + """Bridge a single Twilio WebSocket session with Gemini Live.""" + logger.info("Client connected to GeminiLive server") + + stream_sid: str = self.conversation_id + client = self._create_genai_client() + live_config = self._build_live_config() + + # Track Twilio stream state + twilio_connected = True + + # Accumulate assistant speech text per turn + _assistant_turn_text: list[str] = [] + _user_turn_text: list[str] = [] + + _in_model_turn = False + _user_speaking = False + + _user_speech_end_ts: Optional[float] = None + _first_audio_in_turn = False + + try: + async with client.aio.live.connect( + model=self._model, config=live_config + ) as session: + logger.info(f"Gemini Live session connected (model={self._model})") + + # Trigger the initial greeting using realtime text input. + # send_client_content with Content turns is not supported by + # some Live models (e.g. gemini-3.1-flash-live-preview), but + # send_realtime_input(text=...) works universally. 
+ await session.send_realtime_input(text=f"Please greet with: {INITIAL_MESSAGE}") + self._fw_log.turn_start() + + # ----- Concurrent tasks ----- + async def _forward_user_audio() -> None: + """Read Twilio WS messages, convert audio, send to Gemini.""" + nonlocal stream_sid, twilio_connected + try: + while twilio_connected and self._running: + try: + raw = await asyncio.wait_for( + websocket.receive_text(), timeout=1.0 + ) + except asyncio.TimeoutError: + continue + + # Parse Twilio JSON envelope + try: + msg = json.loads(raw) + except json.JSONDecodeError: + continue + + event = msg.get("event") + if event == "start": + stream_sid = msg.get("start", {}).get("streamSid", stream_sid) + logger.info(f"Twilio stream started: {stream_sid}") + continue + elif event == "stop": + logger.info("Twilio stream stopped") + twilio_connected = False + break + elif event == "media": + # Extract raw mulaw bytes + mulaw_bytes = parse_twilio_media_message(raw) + if mulaw_bytes is None: + continue + + # Convert 8 kHz mulaw -> 16 kHz PCM for Gemini + pcm_16k = mulaw_8k_to_pcm16_16k(mulaw_bytes) + + pcm_24k = mulaw_8k_to_pcm16_24k(mulaw_bytes) + if not _in_model_turn: + sync_buffer_to_position(self.assistant_audio_buffer, len(self.user_audio_buffer)) + self.user_audio_buffer.extend(pcm_24k) + + # Send to Gemini + await session.send_realtime_input( + audio=types.Blob( + data=pcm_16k, + mime_type="audio/pcm;rate=16000", + ) + ) + except WebSocketDisconnect: + logger.info("Twilio WebSocket disconnected") + twilio_connected = False + except asyncio.CancelledError: + pass + except Exception as e: + logger.error(f"Error in user audio forwarder: {e}", exc_info=True) + finally: + twilio_connected = False + + async def _process_gemini_events() -> None: + """Consume events from the Gemini Live session.""" + nonlocal _assistant_turn_text, _user_turn_text + nonlocal _in_model_turn, _user_speaking, _user_speech_end_ts, _first_audio_in_turn + nonlocal twilio_connected + + logger.info("Gemini event 
processor started") + event_count = 0 + try: + # Use manual receive loop instead of `async for ... in session.receive()` + # because the iterator exits after turn_complete (returns None), + # closing the session prematurely. The manual loop keeps the session + # alive between model turns. + while self._running: + try: + response = await asyncio.wait_for(session._receive(), timeout=2.0) + except asyncio.TimeoutError: + continue + if response is None: + continue + if not self._running: + break + + event_count += 1 + + # --- Server content (audio, transcriptions, turn signals) --- + if response.server_content: + sc = response.server_content + + # Model audio output + if sc.model_turn: + if not _in_model_turn: + _in_model_turn = True + _first_audio_in_turn = True + _assistant_turn_text = [] + self._fw_log.turn_start() + + for part in sc.model_turn.parts: + if part.inline_data and part.inline_data.data: + pcm_24k = bytes(part.inline_data.data) + + # Skip tiny chunks that can't be resampled + if len(pcm_24k) < 6: + continue + + # Measure latency: time from user speech end + # to first assistant audio chunk + if _first_audio_in_turn and _user_speech_end_ts is not None: + latency = time.time() - _user_speech_end_ts + self._latency_measurements.append(latency) + self._metrics_log.write_ttfb_metric( + "gemini_live", latency, component="e2e_ttfb" + ) + logger.debug(f"Response latency: {latency:.3f}s") + _user_speech_end_ts = None + _first_audio_in_turn = False + + if not _user_speaking: + sync_buffer_to_position(self.user_audio_buffer, len(self.assistant_audio_buffer)) + self.assistant_audio_buffer.extend(pcm_24k) + + # Convert to 8 kHz mulaw and send in + # small chunks so the user simulator's + # silence-detection timing works correctly. 
+ if twilio_connected: + try: + mulaw = pcm16_24k_to_mulaw_8k(pcm_24k) + except Exception as conv_err: + logger.warning(f"Audio conversion error ({len(pcm_24k)} bytes): {conv_err}") + continue + _MULAW_CHUNK = 160 + offset = 0 + while offset < len(mulaw): + chunk = mulaw[offset:offset + _MULAW_CHUNK] + offset += _MULAW_CHUNK + twilio_msg = create_twilio_media_message( + stream_sid, chunk + ) + try: + await websocket.send_text(twilio_msg) + except Exception: + twilio_connected = False + break + + # Turn complete + if sc.turn_complete: + logger.debug("Gemini turn complete") + full_text = " ".join(_assistant_turn_text).strip() + if full_text: + self.audit_log.append_assistant_output(full_text) + self._fw_log.tts_text(full_text) + self._fw_log.llm_response(full_text) + self._fw_log.turn_end(was_interrupted=False) + _in_model_turn = False + _assistant_turn_text = [] + + # Barge-in / interruption + if sc.interrupted: + _user_speaking = True + logger.debug("Gemini turn interrupted (barge-in)") + full_text = " ".join(_assistant_turn_text).strip() + if full_text: + self.audit_log.append_assistant_output( + full_text + " [interrupted]" + ) + self._fw_log.tts_text(full_text) + self._fw_log.turn_end(was_interrupted=True) + _in_model_turn = False + _assistant_turn_text = [] + + # Input transcription (user speech) + if sc.input_transcription: + _user_speaking = False + text = sc.input_transcription.text or "" + if text.strip(): + logger.info(f"User transcription: {text.strip()}") + self.audit_log.append_user_input(text.strip()) + _user_speech_end_ts = time.time() + + # Output transcription (model speech) + if sc.output_transcription: + text = sc.output_transcription.text or "" + if text.strip(): + _assistant_turn_text.append(text.strip()) + logger.debug(f"Assistant transcription chunk: {text.strip()}") + + # --- Tool calls --- + if response.tool_call: + for fc in response.tool_call.function_calls: + tool_name = fc.name + tool_args = dict(fc.args) if fc.args else {} + 
logger.info(f"Tool call: {tool_name}({json.dumps(tool_args)})") + + # Record in audit log + self.audit_log.append_realtime_tool_call( + tool_name, tool_args + ) + + # Execute tool + result = await self.tool_handler.execute( + tool_name, tool_args + ) + logger.info(f"Tool result: {tool_name} -> {json.dumps(result)}") + self.audit_log.append_tool_response(tool_name, result) + + # Send result back to Gemini + await session.send_tool_response( + function_responses=[ + types.FunctionResponse( + id=fc.id, + name=fc.name, + response=result, + ) + ] + ) + + # --- Usage metadata --- + if response.usage_metadata: + um = response.usage_metadata + prompt_tokens = getattr(um, "prompt_token_count", 0) or 0 + completion_tokens = ( + getattr(um, "candidates_token_count", 0) or 0 + ) + if prompt_tokens or completion_tokens: + self._metrics_log.write_token_usage( + processor="gemini_live", + model=self._model, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) + + except asyncio.CancelledError: + pass + except Exception as e: + logger.error(f"Error in Gemini event processor: {e}", exc_info=True) + + # Run both tasks; when either exits, cancel the other + user_task = asyncio.create_task(_forward_user_audio()) + gemini_task = asyncio.create_task(_process_gemini_events()) + + done, pending = await asyncio.wait( + [user_task, gemini_task], + return_when=asyncio.FIRST_COMPLETED, + ) + + # Log which task finished first + for task in done: + task_name = "user_audio" if task is user_task else "gemini_events" + exc = task.exception() + if exc: + logger.error(f"Task '{task_name}' failed: {exc}", exc_info=exc) + else: + logger.info(f"Task '{task_name}' completed normally") + + for task in pending: + task_name = "user_audio" if task is user_task else "gemini_events" + logger.info(f"Cancelling pending task '{task_name}'") + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + except Exception as e: + logger.error(f"Gemini Live session error: {e}", 
exc_info=True) + finally: + logger.info("Client disconnected from GeminiLive server") diff --git a/src/eva/assistant/openai_realtime_server.py b/src/eva/assistant/openai_realtime_server.py new file mode 100644 index 00000000..210389d3 --- /dev/null +++ b/src/eva/assistant/openai_realtime_server.py @@ -0,0 +1,778 @@ +"""OpenAI Realtime API assistant server implementation. + +Uses the OpenAI Python SDK's Realtime API (client.beta.realtime.connect()) +to bridge audio between a Twilio-framed WebSocket (user simulator) and the +OpenAI Realtime model. Handles tool calls via the local ToolExecutor and +records all conversation events in the audit log. +""" + +import asyncio +import base64 +import json +import time +from dataclasses import dataclass, field +from typing import Any, Optional + +import uvicorn +from fastapi import FastAPI, WebSocket, WebSocketDisconnect +from openai import AsyncOpenAI + +from eva.assistant.audio_bridge import ( + FrameworkLogWriter, + MetricsLogWriter, + create_twilio_media_message, + mulaw_8k_to_pcm16_24k, + parse_twilio_media_message, + pcm16_24k_to_mulaw_8k, + sync_buffer_to_position, +) +from eva.assistant.base_server import INITIAL_MESSAGE, AbstractAssistantServer +from eva.utils.logging import get_logger +from eva.utils.prompt_manager import PromptManager + +logger = get_logger(__name__) + +# OpenAI Realtime operates at 24 kHz 16-bit mono PCM +OPENAI_SAMPLE_RATE = 24000 + +# Audio output pacing: send 160-byte mulaw chunks (20ms at 8kHz) at real-time rate +# so the user simulator's silence detection works correctly. 
+MULAW_CHUNK_SIZE = 160 # bytes per chunk (20ms at 8kHz, 1 byte per sample) +MULAW_CHUNK_DURATION_S = 0.02 # 20ms per chunk + + +def _wall_ms() -> str: + """Return current wall-clock time as epoch-milliseconds string.""" + return str(int(round(time.time() * 1000))) + + +@dataclass +class _UserTurnRecord: + """Tracks state for a single user speech turn.""" + + speech_started_wall_ms: str = "" + speech_stopped_wall_ms: str = "" + transcript: str = "" + flushed: bool = False + + +@dataclass +class _AssistantResponseState: + """Accumulates state for the current assistant response.""" + + transcript_parts: list[str] = field(default_factory=list) + transcript_done_text: str = "" # Final text from response.audio_transcript.done + first_audio_wall_ms: Optional[str] = None + responding: bool = False + has_function_calls: bool = False + + +class OpenAIRealtimeAssistantServer(AbstractAssistantServer): + """Assistant server backed by the OpenAI Realtime API. + + Exposes a local WebSocket at ``ws://localhost:{port}/ws`` using the Twilio + frame format so the user simulator can connect as if talking to Twilio. + Internally bridges audio between Twilio (8 kHz mulaw) and OpenAI Realtime + (24 kHz PCM16 base64). 
+ """ + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + + self._audio_sample_rate = OPENAI_SAMPLE_RATE + + self._app: Optional[FastAPI] = None + self._server: Optional[uvicorn.Server] = None + self._server_task: Optional[asyncio.Task] = None + self._running: bool = False + + self._fw_log: Optional[FrameworkLogWriter] = None + self._metrics_log: Optional[MetricsLogWriter] = None + + prompt_manager = PromptManager() + self._system_prompt: str = prompt_manager.get_prompt( + "realtime_agent.system_prompt", + agent_personality=self.agent.description, + agent_instructions=self.agent.instructions, + datetime=self.current_date_time, + ) + + self._realtime_tools: list[dict] = self._build_realtime_tools() + + self._user_turn: Optional[_UserTurnRecord] = None + self._assistant_state = _AssistantResponseState() + self._stream_sid: str = "" + + self._user_speaking: bool = False + self._bot_speaking: bool = False + self._user_frame_count: int = 0 + self._delta_count: int = 0 + + # Audio output pacing: absolute time target for next chunk send + self._next_chunk_send_time: float = 0.0 + + self._model: str = self.pipeline_config.s2s + + async def start(self) -> None: + """Start the FastAPI WebSocket server.""" + if self._running: + logger.warning("Server already running") + return + + self.output_dir.mkdir(parents=True, exist_ok=True) + self._fw_log = FrameworkLogWriter(self.output_dir) + self._metrics_log = MetricsLogWriter(self.output_dir) + + self._app = FastAPI() + + @self._app.websocket("/ws") + async def websocket_endpoint(websocket: WebSocket): + await websocket.accept() + await self._handle_session(websocket) + + @self._app.websocket("/") + async def websocket_root(websocket: WebSocket): + await websocket.accept() + await self._handle_session(websocket) + + config = uvicorn.Config( + self._app, + host="0.0.0.0", + port=self.port, + log_level="warning", + lifespan="off", + ) + self._server = uvicorn.Server(config) + self._running = True + 
self._server_task = asyncio.create_task(self._server.serve()) + + while not self._server.started: + await asyncio.sleep(0.01) + + logger.info(f"OpenAI Realtime server started on ws://localhost:{self.port}") + + async def stop(self) -> None: + """Stop the server and save all outputs.""" + if not self._running: + return + + self._running = False + + if self._server: + self._server.should_exit = True + if self._server_task: + try: + await asyncio.wait_for(self._server_task, timeout=5.0) + except asyncio.TimeoutError: + self._server_task.cancel() + try: + await self._server_task + except asyncio.CancelledError: + pass + except (asyncio.CancelledError, KeyboardInterrupt): + pass + self._server = None + self._server_task = None + + await self.save_outputs() + logger.info(f"OpenAI Realtime server stopped on port {self.port}") + + async def save_outputs(self) -> None: + """Save all outputs including mixed audio.""" + await super().save_outputs() + + async def _handle_session(self, websocket: WebSocket) -> None: + """Handle a single WebSocket session. + + 1. Accept Twilio WS connection + 2. Connect to OpenAI Realtime API + 3. Configure session (instructions, tools, voice, VAD) + 4. Run two concurrent tasks: + a. Forward user audio: Twilio WS -> decode mulaw -> PCM16 24kHz base64 -> OpenAI + b. Process OpenAI events: async for event in conn -> handle each type + 5. On tool call: execute via self.tool_handler, send result back + 6. 
On audio: decode base64 PCM16 -> record -> encode mulaw -> send to Twilio WS + """ + logger.info("Client connected to OpenAI Realtime server") + + # Reset per-session state + self._user_turn = None + self._assistant_state = _AssistantResponseState() + self._stream_sid = self.conversation_id + self._user_speaking = False + self._bot_speaking = False + + api_key = self.pipeline_config.s2s_params.get("api_key") + if not api_key: + raise ValueError("API key required for openai realtime") + client = AsyncOpenAI(api_key=api_key) + + try: + async with client.realtime.connect(model=self._model) as conn: + # Configure the session + await conn.session.update( + session={ + "type": "realtime", + "output_modalities": ["audio"], + "instructions": self._system_prompt, + "audio": { + "output": { + "voice": self.pipeline_config.s2s_params.get("voice", "marin"), + "format": {"type": "audio/pcm", "rate": 24000}, + }, + "input": { + "format": {"type": "audio/pcm", "rate": 24000}, + "turn_detection": { + "type": self.pipeline_config.s2s_params.get("vad_settings", {}).get( + "type", "server_vad" + ), + "threshold": self.pipeline_config.s2s_params.get("vad_settings", {}).get( + "threshold", 0.5 + ), + "prefix_padding_ms": self.pipeline_config.s2s_params.get("vad_settings", {}).get( + "prefix_padding_ms", 300 + ), + "silence_duration_ms": self.pipeline_config.s2s_params.get("vad_settings", {}).get( + "silence_duration_ms", 200 + ), + }, + "transcription": { + "model": self.pipeline_config.s2s_params.get("transcription_model", "whisper-1") + }, + }, + }, + "tools": self._realtime_tools, + } + ) + + # Trigger the initial greeting + await conn.conversation.item.create( + item={ + "type": "message", + "role": "user", + "content": [ + { + "type": "input_text", + "text": f"Say: '{INITIAL_MESSAGE}'", + } + ], + } + ) + await conn.response.create() + + # Run forwarding tasks concurrently + forward_task = asyncio.create_task(self._forward_user_audio(websocket, conn)) + receive_task = 
asyncio.create_task(self._process_openai_events(conn, websocket)) + + done, pending = await asyncio.wait( + [forward_task, receive_task], + return_when=asyncio.FIRST_COMPLETED, + ) + for task in pending: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + # Check for exceptions in completed tasks + for task in done: + if task.exception(): + logger.error(f"Session task failed: {task.exception()}") + + except Exception as e: + logger.error(f"OpenAI Realtime session error: {e}", exc_info=True) + finally: + logger.info("Client disconnected from OpenAI Realtime server") + + # ── User audio forwarding (Twilio WS -> OpenAI) ────────────────── + + async def _forward_user_audio(self, websocket: WebSocket, conn: Any) -> None: + """Read Twilio media frames and forward audio to OpenAI Realtime.""" + try: + while True: + raw = await websocket.receive_text() + data = json.loads(raw) + event_type = data.get("event") + + if event_type == "start": + # Twilio stream start - extract streamSid + self._stream_sid = data.get("start", {}).get("streamSid", self.conversation_id) + logger.debug(f"Twilio stream started: streamSid={self._stream_sid}") + continue + + if event_type == "stop": + logger.debug("Twilio stream stopped") + break + + if event_type != "media": + continue + + # Extract raw mulaw audio bytes + mulaw_bytes = parse_twilio_media_message(raw) + if mulaw_bytes is None: + continue + + # Convert 8kHz mulaw -> 24kHz PCM16 + pcm16_24k = mulaw_8k_to_pcm16_24k(mulaw_bytes) + + asst_before = len(self.assistant_audio_buffer) + synced = 0 + if not self._bot_speaking: + sync_target = len(self.user_audio_buffer) + sync_buffer_to_position(self.assistant_audio_buffer, sync_target) + synced = len(self.assistant_audio_buffer) - asst_before + self.user_audio_buffer.extend(pcm16_24k) + self._user_frame_count += 1 + if self._user_frame_count % 50 == 0: + diff = len(self.user_audio_buffer) - len(self.assistant_audio_buffer) + diff_ms = diff / (OPENAI_SAMPLE_RATE * 2) 
* 1000 + logger.debug( + f"[ALIGN DEBUG] user_frame #{self._user_frame_count}: " + f"user={len(self.user_audio_buffer)} asst={len(self.assistant_audio_buffer)} " + f"diff={diff}({diff_ms:.0f}ms) bot_spk={self._bot_speaking} " + f"usr_spk={self._user_speaking} added={len(pcm16_24k)} synced={synced}" + ) + + # Encode as base64 and send to OpenAI + audio_b64 = base64.b64encode(pcm16_24k).decode("ascii") + await conn.input_audio_buffer.append(audio=audio_b64) + + except WebSocketDisconnect: + logger.debug("Twilio WebSocket disconnected") + except asyncio.CancelledError: + pass + except Exception as e: + logger.error(f"Error forwarding user audio: {e}", exc_info=True) + + # ── OpenAI event processing ─────────────────────────────────────── + + async def _process_openai_events(self, conn: Any, websocket: WebSocket) -> None: + """Process events from the OpenAI Realtime connection.""" + try: + async for event in conn: + try: + await self._handle_openai_event(event, conn, websocket) + except Exception as e: + logger.error(f"Error handling event {getattr(event, 'type', '?')}: {e}", exc_info=True) + except asyncio.CancelledError: + pass + except Exception as e: + logger.error(f"Error in OpenAI event loop: {e}", exc_info=True) + + async def _handle_openai_event(self, event: Any, conn: Any, websocket: WebSocket) -> None: + """Dispatch a single OpenAI Realtime event.""" + event_type = getattr(event, "type", "") + + match event_type: + case "session.created": + logger.info("OpenAI Realtime session created") + + case "session.updated": + logger.debug("OpenAI Realtime session updated") + + case "input_audio_buffer.speech_started": + await self._on_speech_started(event) + + case "input_audio_buffer.speech_stopped": + await self._on_speech_stopped(event) + + case "conversation.item.input_audio_transcription.completed": + await self._on_transcription_completed(event) + + case "conversation.item.input_audio_transcription.delta": + logger.debug(f"Transcription delta: {getattr(event, 
'delta', '')}") + + case "conversation.item.input_audio_transcription.failed": + error_info = getattr(event, "error", "") + logger.warning(f"Transcription failed: {error_info}") + # Gracefully handle transcription failure (e.g. API key lacks + # whisper-1 access). If a user turn was active but has no + # transcript yet, record a placeholder so the turn is not lost. + if self._user_turn and not self._user_turn.flushed: + timestamp_ms = self._user_turn.speech_started_wall_ms or None + self.audit_log.append_user_input( + "[user speech - transcription unavailable]", + timestamp_ms=timestamp_ms, + ) + self._user_turn.flushed = True + + case "response.output_audio.delta": + await self._on_audio_delta(event, websocket) + + case "response.output_audio_transcript.delta": + self._on_transcript_delta(event) + + case "response.output_audio_transcript.done": + self._on_transcript_done(event) + + case "response.function_call_arguments.done": + await self._on_function_call_done(event, conn) + + case "response.done": + await self._on_response_done(event) + + case "error": + error_data = getattr(event, "error", None) + logger.error(f"OpenAI Realtime error: {error_data}") + + case _: + logger.debug(f"Unhandled OpenAI event: {event_type}") + + # ── Event handlers ──────────────────────────────────────────────── + + async def _on_speech_started(self, event: Any) -> None: + """Handle input_audio_buffer.speech_started.""" + self._user_speaking = True + diff = len(self.user_audio_buffer) - len(self.assistant_audio_buffer) + diff_ms = diff / (OPENAI_SAMPLE_RATE * 2) * 1000 + logger.debug( + f"[ALIGN DEBUG] speech_started: user={len(self.user_audio_buffer)} " + f"asst={len(self.assistant_audio_buffer)} diff={diff}({diff_ms:.0f}ms) " + f"bot_spk={self._bot_speaking}" + ) + wall = _wall_ms() + + # If assistant was responding, flush interrupted response + if self._assistant_state.responding and self._assistant_state.transcript_parts: + partial_text = 
"".join(self._assistant_state.transcript_parts) + " [interrupted]" + self.audit_log.append_assistant_output( + partial_text, + timestamp_ms=self._assistant_state.first_audio_wall_ms, + ) + if self._fw_log: + self._fw_log.tts_text(partial_text) + self._fw_log.turn_end(was_interrupted=True) + logger.debug(f"Flushed interrupted assistant response: {partial_text[:60]}...") + self._assistant_state = _AssistantResponseState() + + # Start new user turn only if previous one was flushed (or doesn't exist) + # This preserves the original timestamp when VAD fires multiple speech_started + # events during a single logical user utterance (due to brief pauses) + if not self._user_turn or self._user_turn.flushed: + self._user_turn = _UserTurnRecord(speech_started_wall_ms=wall) + if self._fw_log: + self._fw_log.turn_start(timestamp_ms=int(wall)) + logger.debug(f"Speech started at {wall} (new turn)") + else: + logger.debug(f"Speech started at {wall} (continuing existing turn)") + + async def _on_speech_stopped(self, event: Any) -> None: + """Handle input_audio_buffer.speech_stopped.""" + self._user_speaking = False + diff = len(self.user_audio_buffer) - len(self.assistant_audio_buffer) + diff_ms = diff / (OPENAI_SAMPLE_RATE * 2) * 1000 + logger.info( + f"[ALIGN DEBUG] speech_stopped: user={len(self.user_audio_buffer)} " + f"asst={len(self.assistant_audio_buffer)} diff={diff}({diff_ms:.0f}ms) " + f"bot_spk={self._bot_speaking}" + ) + wall = _wall_ms() + if self._user_turn: + self._user_turn.speech_stopped_wall_ms = wall + else: + self._user_turn = _UserTurnRecord(speech_stopped_wall_ms=wall) + + # Record latency measurement start + self._speech_stopped_time = time.time() + logger.debug(f"Speech stopped at {wall}") + + async def _on_transcription_completed(self, event: Any) -> None: + """Handle conversation.item.input_audio_transcription.completed.""" + transcript = getattr(event, "transcript", "") or "" + transcript = transcript.strip() + + if not transcript: + logger.debug("Empty 
transcription, skipping") + return + + timestamp_ms = None + if self._user_turn: + timestamp_ms = self._user_turn.speech_started_wall_ms or None + self._user_turn.transcript = transcript + self._user_turn.flushed = True + + self.audit_log.append_user_input(transcript, timestamp_ms=timestamp_ms) + logger.debug(f"User transcription: {transcript}...") + + async def _on_audio_delta(self, event: Any, websocket: WebSocket) -> None: + """Handle response.audio.delta - assistant audio chunk.""" + delta_b64 = getattr(event, "delta", "") or "" + if not delta_b64: + return + + pcm16_bytes = base64.b64decode(delta_b64) + now = time.time() + + if self._assistant_state.first_audio_wall_ms is None: + self._assistant_state.first_audio_wall_ms = _wall_ms() + self._assistant_state.responding = True + self._bot_speaking = True + + if hasattr(self, "_speech_stopped_time"): + latency = now - self._speech_stopped_time + self._latency_measurements.append(latency) + if self._metrics_log: + self._metrics_log.write_ttfb_metric( + processor="openai_realtime", + value_seconds=latency, + model=self._model, + ) + logger.debug(f"Response latency: {latency:.3f}s") + + user_before = len(self.user_audio_buffer) + synced = 0 + if not self._user_speaking: + sync_buffer_to_position(self.user_audio_buffer, len(self.assistant_audio_buffer)) + synced = len(self.user_audio_buffer) - user_before + self.assistant_audio_buffer.extend(pcm16_bytes) + self._delta_count += 1 + if self._delta_count % 10 == 0: + diff = len(self.user_audio_buffer) - len(self.assistant_audio_buffer) + diff_ms = diff / (OPENAI_SAMPLE_RATE * 2) * 1000 + logger.debug( + f"[ALIGN DEBUG] audio_delta #{self._delta_count}: " + f"user={len(self.user_audio_buffer)} asst={len(self.assistant_audio_buffer)} " + f"diff={diff}({diff_ms:.0f}ms) bot_spk={self._bot_speaking} " + f"usr_spk={self._user_speaking} added={len(pcm16_bytes)} synced_user={synced}" + ) + + # Convert 24kHz PCM16 -> 8kHz mulaw and send in real-time-paced chunks. 
+ # Each 160-byte chunk = 20ms of audio at 8kHz. We sleep between sends + # so the user simulator receives audio at playback rate, which ensures + # its silence-based audio_start/audio_end detection works correctly. + try: + mulaw_bytes = pcm16_24k_to_mulaw_8k(pcm16_bytes) + now = time.monotonic() + + # Initialize pacing clock on first chunk of a new response + if self._next_chunk_send_time <= now: + self._next_chunk_send_time = now + + offset = 0 + while offset < len(mulaw_bytes): + chunk = mulaw_bytes[offset : offset + MULAW_CHUNK_SIZE] + offset += MULAW_CHUNK_SIZE + twilio_msg = create_twilio_media_message(self._stream_sid, chunk) + await websocket.send_text(twilio_msg) + + # Advance absolute clock and sleep until next send time + self._next_chunk_send_time += MULAW_CHUNK_DURATION_S + sleep_duration = self._next_chunk_send_time - time.monotonic() + if sleep_duration > 0: + await asyncio.sleep(sleep_duration) + except Exception as e: + logger.error(f"Error sending audio to Twilio WS: {e}") + + def _on_transcript_delta(self, event: Any) -> None: + """Handle response.audio_transcript.delta - incremental assistant text.""" + delta = getattr(event, "delta", "") or "" + if delta: + self._assistant_state.transcript_parts.append(delta) + + def _on_transcript_done(self, event: Any) -> None: + """Handle response.audio_transcript.done - full assistant transcript. + + This is the most reliable source of what the model actually said. + Store it so _on_response_done can use it if delta accumulation failed. 
+ """ + transcript = getattr(event, "transcript", "") or "" + if transcript: + self._assistant_state.transcript_done_text = transcript.strip() + logger.debug(f"Assistant transcript done: {transcript}...") + if self._fw_log: + self._fw_log.tts_text(transcript) + + async def _on_function_call_done(self, event: Any, conn: Any) -> None: + """Handle response.function_call_arguments.done - execute tool call.""" + call_id = getattr(event, "call_id", "") + func_name = getattr(event, "name", "") + arguments_str = getattr(event, "arguments", "{}") or "{}" + + try: + arguments = json.loads(arguments_str) + except json.JSONDecodeError: + arguments = {} + + logger.info(f"Tool call: {func_name}({json.dumps(arguments)})") + self._assistant_state.has_function_calls = True + + # Record in audit log + self.audit_log.append_realtime_tool_call(func_name, arguments) + + # Execute tool + result = await self.tool_handler.execute(func_name, arguments) + + # Record tool response + self.audit_log.append_tool_response(func_name, result) + + if self._fw_log: + self._fw_log.write( + "tool_call", + { + "frame": "tool_call", + "tool_name": func_name, + "arguments": arguments, + "result": result, + }, + ) + + # Send function call output back to OpenAI + await conn.conversation.item.create( + item={ + "type": "function_call_output", + "call_id": call_id, + "output": json.dumps(result), + } + ) + + # Trigger next response after tool result + await conn.response.create() + + async def _on_response_done(self, event: Any) -> None: + """Handle response.done - assistant response complete. 
+ + Following the pipecat InstrumentedRealtimeLLMService pattern: + - Only call append_assistant_output() (no append_llm_call) + - Token usage goes to pipecat_metrics.jsonl only + """ + # Extract usage metrics + response = getattr(event, "response", None) + if response: + usage = getattr(response, "usage", None) + if usage and self._metrics_log: + input_tokens = getattr(usage, "input_tokens", 0) or 0 + output_tokens = getattr(usage, "output_tokens", 0) or 0 + self._metrics_log.write_token_usage( + processor="openai_realtime", + model=self._model, + prompt_tokens=input_tokens, + completion_tokens=output_tokens, + ) + + # Skip cancelled responses - these were interrupted and not fully spoken + if response and getattr(response, "status", None) == "cancelled": + logger.debug("response_done: cancelled response, skipping transcript entry") + self._reset_assistant_state() + return + + has_function_calls = self._response_has_function_calls(event) + + # Build transcript text from best available source: + # 1. response.audio_transcript.done text (most reliable) + # 2. Accumulated response.audio_transcript.delta parts + # 3. 
Text extracted from response.done output items + content = self._assistant_state.transcript_done_text + if not content: + content = "".join(self._assistant_state.transcript_parts).strip() + if not content: + content = self._extract_response_text(event) + + audio_was_streamed = self._assistant_state.first_audio_wall_ms is not None + + # Skip tool-call-only responses (nothing spoken) + if not content and has_function_calls: + logger.debug("response_done: tool-call-only response, skipping assistant entry") + self._reset_assistant_state() + return + + # Skip mixed responses where audio was not streamed + if content and not audio_was_streamed and has_function_calls: + logger.debug(f"response_done: mixed response with no audio, skipping: '{content[:60]}...'") + self._reset_assistant_state() + return + + # If audio was streamed but we have no transcript at all, skip rather + # than pollute the audit log with a placeholder. The audio recording + # still captures what was said. + if not content and audio_was_streamed: + logger.debug("response_done: audio streamed but no transcript available, skipping text entry") + self._reset_assistant_state() + return + + if not content: + # No audio, no text, no function calls — nothing to log + self._reset_assistant_state() + return + + # Log assistant output (single entry — no append_llm_call) + timestamp = self._assistant_state.first_audio_wall_ms or _wall_ms() + self.audit_log.append_assistant_output(content, timestamp_ms=timestamp) + + if self._fw_log: + self._fw_log.llm_response(content) + self._fw_log.turn_end(was_interrupted=False) + + logger.debug(f"response_done: '{content[:60]}...'") + self._reset_assistant_state() + + # ── Helpers ─────────────────────────────────────────────────────── + + def _reset_assistant_state(self) -> None: + """Clear accumulated assistant response state.""" + audio_was_streamed = self._assistant_state.first_audio_wall_ms is not None + diff = len(self.user_audio_buffer) - 
len(self.assistant_audio_buffer) + diff_ms = diff / (OPENAI_SAMPLE_RATE * 2) * 1000 + logger.debug( + f"[ALIGN DEBUG] reset_state: user={len(self.user_audio_buffer)} " + f"asst={len(self.assistant_audio_buffer)} diff={diff}({diff_ms:.0f}ms) " + f"audio_streamed={audio_was_streamed} bot_spk={self._bot_speaking}" + ) + if audio_was_streamed: + self._bot_speaking = False + self._assistant_state = _AssistantResponseState() + + def _build_realtime_tools(self) -> list[dict]: + """Convert agent tools to OpenAI Realtime session tool format. + + The Realtime API session.tools expects a flat structure: + {type, name, description, parameters: {type, properties, required}} + """ + tools: list[dict] = [] + if not self.agent.tools: + return tools + + for tool in self.agent.tools: + tools.append( + { + "type": "function", + "name": tool.function_name, + "description": f"{tool.name}: {tool.description}", + "parameters": { + "type": "object", + "properties": tool.get_parameter_properties(), + "required": tool.get_required_param_names(), + }, + } + ) + return tools + + @staticmethod + def _response_has_function_calls(event: Any) -> bool: + """Return True if the response.done event contains function_call outputs.""" + response = getattr(event, "response", None) + if not response: + return False + output_items = getattr(response, "output", None) or [] + return any(getattr(item, "type", "") == "function_call" for item in output_items) + + @staticmethod + def _extract_response_text(event: Any) -> str: + """Extract text content from response.done output items.""" + response = getattr(event, "response", None) + if not response: + return "" + + output_items = getattr(response, "output", None) or [] + text_parts: list[str] = [] + + for item in output_items: + content_list = getattr(item, "content", None) or [] + for part in content_list: + part_type = getattr(part, "type", "") + if part_type in ("audio", "text"): + transcript = getattr(part, "transcript", None) or getattr(part, "text", 
None) or "" + if transcript: + text_parts.append(transcript) + + return "".join(text_parts).strip() diff --git a/src/eva/assistant/server.py b/src/eva/assistant/pipecat_server.py similarity index 86% rename from src/eva/assistant/server.py rename to src/eva/assistant/pipecat_server.py index 57a0fc2e..b4824cf5 100644 --- a/src/eva/assistant/server.py +++ b/src/eva/assistant/pipecat_server.py @@ -6,9 +6,8 @@ import asyncio import json -import wave from pathlib import Path -from typing import Any, Optional +from typing import Optional import uvicorn from fastapi import FastAPI, WebSocket @@ -45,7 +44,8 @@ from pipecat.turns.user_turn_strategies import ExternalUserTurnStrategies, UserTurnStrategies from pipecat.utils.time import time_now_iso8601 -from eva.assistant.agentic.audit_log import AuditLog, current_timestamp_ms +from eva.assistant.agentic.audit_log import current_timestamp_ms +from eva.assistant.base_server import INITIAL_MESSAGE, AbstractAssistantServer from eva.assistant.pipeline.agent_processor import BenchmarkAgentProcessor, UserAudioCollector, UserObserver from eva.assistant.pipeline.audio_llm_processor import ( AudioLLMProcessor, @@ -62,7 +62,6 @@ create_tts_service, ) from eva.assistant.services.llm import LiteLLMClient -from eva.assistant.tools.tool_executor import ToolExecutor from eva.models.agents import AgentConfig from eva.models.config import AudioLLMConfig, PipelineConfig, SpeechToSpeechConfig from eva.utils.logging import get_logger @@ -77,10 +76,8 @@ # Should be larger than pipecat's VAD start_secs (0.2s) to account for VAD latency. VAD_PRE_SPEECH_BUFFER_SECS = 0.5 -INITIAL_MESSAGE = "Hello! How can I help you today?" - -class AssistantServer: +class PipecatAssistantServer(AbstractAssistantServer): """Pipecat-based WebSocket server for the assistant in voice conversations. 
This server: @@ -113,35 +110,24 @@ def __init__( port: Port to listen on conversation_id: Unique ID for this conversation """ - self.pipeline_config = pipeline_config - self.agent: AgentConfig = agent - self.agent_config_path = agent_config_path - self.scenario_db_path = scenario_db_path - self.output_dir = Path(output_dir) - self.port = port - self.conversation_id = conversation_id - self.current_date_time = current_date_time - - # Components (initialized on start) - self.audit_log = AuditLog() - self.agentic_system = None # Will be set in _handle_session - - # Initialize Python-based tool executor - self.tool_handler = ToolExecutor( - tool_config_path=agent_config_path, + super().__init__( + current_date_time=current_date_time, + pipeline_config=pipeline_config, + agent=agent, + agent_config_path=agent_config_path, scenario_db_path=scenario_db_path, - tool_module_path=self.agent.tool_module_path, - current_date_time=self.current_date_time, + output_dir=output_dir, + port=port, + conversation_id=conversation_id, ) + self.agentic_system = None # Will be set in _handle_session + # Wall-clock captured at on_user_turn_started for non-instrumented S2S models self._user_turn_started_wall_ms: Optional[str] = None - # Audio buffer for accumulating audio data - self._audio_buffer = bytearray() + # Override audio sample rate for pipecat self._audio_sample_rate = SAMPLE_RATE - self.user_audio_buffer = bytearray() - self.assistant_audio_buffer = bytearray() # Server state self._app = None @@ -151,7 +137,6 @@ def __init__( self._task: Optional[PipelineTask] = None self._running = False self.num_seconds = 0 - self._latency_measurements: list[float] = [] self._metrics_observer: Optional[MetricsFileObserver] = None self.non_instrumented_realtime_llm = False @@ -230,7 +215,7 @@ async def stop(self) -> None: self._server_task = None # Save outputs - await self._save_outputs() + await self.save_outputs() logger.info(f"Assistant server stopped on port {self.port}") @@ -757,112 
+742,18 @@ def _current_iso_timestamp() -> str: """Return the current time as an ISO 8601 string with timezone.""" return time_now_iso8601() - def _save_wav_file(self, audio_data: bytes, file_path: Path, sample_rate: int, num_channels: int) -> None: - """Save audio data to a WAV file. - - Args: - audio_data: Raw audio bytes (16-bit PCM) - file_path: Path to save the WAV file - sample_rate: Sample rate in Hz - num_channels: Number of channels (1=mono, 2=stereo) - """ - try: - with wave.open(str(file_path), "wb") as wav_file: - wav_file.setnchannels(num_channels) - wav_file.setsampwidth(2) # 16-bit PCM - wav_file.setframerate(sample_rate) - wav_file.writeframes(audio_data) - logger.debug(f"Audio saved to {file_path} ({len(audio_data)} bytes)") - except Exception as e: - logger.error(f"Error saving audio to {file_path}: {e}") - - def _save_audio(self) -> None: - """Save accumulated audio to WAV file.""" - if not self._audio_buffer: - logger.warning("No audio data to save") - return - - audio_path = self.output_dir / "audio_mixed.wav" - self._save_wav_file( - bytes(self._audio_buffer), - audio_path, - self._audio_sample_rate, - 1, # Mono - ) - user_audio_path = self.output_dir / "audio_user.wav" - self._save_wav_file( - bytes(self.user_audio_buffer), - user_audio_path, - self._audio_sample_rate, - 1, # Mono - ) - assistant_audio_path = self.output_dir / "audio_assistant.wav" - self._save_wav_file( - bytes(self.assistant_audio_buffer), - assistant_audio_path, - self._audio_sample_rate, - 1, # Mono - ) - logger.info(f"Saved {len(self._audio_buffer)} bytes of audio to {audio_path}") - - async def _save_outputs(self) -> None: - """Save all outputs (audit log, audio files, etc.).""" - # Save audit log - audit_path = self.output_dir / "audit_log.json" - self.audit_log.save(audit_path) - - # Save transcript from audit log. 
- # When using the instrumented realtime pipeline, always overwrite the - # eagerly-written transcript.jsonl with a version derived from the - # (correctly ordered) audit log. - transcript_path = self.output_dir / "transcript.jsonl" - if isinstance(self.pipeline_config, SpeechToSpeechConfig): - self.audit_log.save_transcript_jsonl(transcript_path) - elif not transcript_path.exists(): - self.audit_log.save_transcript_jsonl(transcript_path) - - # Save agent performance stats + async def save_outputs(self) -> None: + """Save all outputs, with pipecat-specific additions.""" + # Save agent performance stats (pipecat-specific: AgenticSystem tracking) if self.agentic_system: try: - logger.info("Saving agent performance stats from _save_outputs()...") + logger.info("Saving agent performance stats from save_outputs()...") self.agentic_system.save_agent_perf_stats() except Exception as e: logger.error(f"Error saving agent perf stats: {e}", exc_info=True) - # Save accumulated audio files - self._save_audio() - - # Save initial and final scenario database states (REQUIRED for deterministic metrics) - try: - initial_db = self.get_initial_scenario_db() - final_db = self.get_final_scenario_db() - - initial_db_path = self.output_dir / "initial_scenario_db.json" - with open(initial_db_path, "w") as f: - json.dump(initial_db, f, indent=2, sort_keys=True, default=str) - - final_db_path = self.output_dir / "final_scenario_db.json" - with open(final_db_path, "w") as f: - json.dump(final_db, f, indent=2, sort_keys=True, default=str) - - logger.info(f"Saved scenario database states to {self.output_dir}") - except Exception as e: - logger.error(f"Error saving scenario database states: {e}", exc_info=True) - raise # Re-raise since this is now required for deterministic metrics - - logger.info(f"Outputs saved to {self.output_dir}") - - def get_conversation_stats(self) -> dict[str, Any]: - """Get statistics about the conversation.""" - return self.audit_log.get_stats() - - def 
get_initial_scenario_db(self) -> dict[str, Any]: - """Get initial scenario database state.""" - return self.tool_handler.original_db - - def get_final_scenario_db(self) -> dict[str, Any]: - """Get final scenario database state.""" - return self.tool_handler.db + # Call base class to save audit_log, audio, scenario DBs, latencies + await super().save_outputs() async def override__maybe_trigger_user_turn_stopped(self): diff --git a/src/eva/assistant/pipeline/audio_llm_processor.py b/src/eva/assistant/pipeline/audio_llm_processor.py index a9154d4e..829c80b6 100644 --- a/src/eva/assistant/pipeline/audio_llm_processor.py +++ b/src/eva/assistant/pipeline/audio_llm_processor.py @@ -52,7 +52,7 @@ logger = get_logger(__name__) -# Pipeline sample rate (matches server.py SAMPLE_RATE) +# Pipeline sample rate (matches pipecat_server.py SAMPLE_RATE) PIPELINE_SAMPLE_RATE = 24000 # Minimum audio size to process (< 10ms of 24kHz 16-bit mono is noise/empty) @@ -199,7 +199,7 @@ def __init__( self._current_query_task: Optional[asyncio.Task] = None self._interrupted = asyncio.Event() - # Optional callback for transcript saving (set by server.py) + # Optional callback for transcript saving (set by pipecat_server.py) self.on_assistant_response: Optional[Awaitable] = None async def process_frame(self, frame: Frame, direction: FrameDirection) -> None: @@ -234,7 +234,7 @@ async def _start_interruption(self): async def process_complete_user_turn(self, text_from_aggregator: str) -> None: """Process a complete user turn with audio. - Called by the on_user_turn_stopped event handler in server.py. + Called by the on_user_turn_stopped event handler in pipecat_server.py. 
The text_from_aggregator is typically empty since there is no STT; Args: @@ -426,7 +426,7 @@ def __init__( base_url, _transcription_url_counter = _resolve_url(params, _transcription_url_counter) self._client: AsyncOpenAI = AsyncOpenAI(api_key=self._api_key, base_url=base_url) - # Callback for when transcription is ready (set by server.py) + # Callback for when transcription is ready (set by pipecat_server.py) self.on_transcription: Optional[Any] = None # Track background transcription tasks so they can complete even during interruptions diff --git a/src/eva/models/config.py b/src/eva/models/config.py index 474d29a8..b1e85802 100644 --- a/src/eva/models/config.py +++ b/src/eva/models/config.py @@ -281,6 +281,19 @@ class ModelDeployment(DeploymentTypedDict): description="Pipeline (STT + LLM + TTS), speech-to-speech, or audio-LLM model configuration", ) + # Framework selection + framework: Literal["pipecat", "openai_realtime", "gemini_live", "elevenlabs", "deepgram"] = Field( + "pipecat", + description=( + "Agent framework to use for the assistant server. " + "'pipecat' (default): Pipecat pipeline. " + "'openai_realtime': OpenAI Realtime API directly. " + "'gemini_live': Gemini Live API via google-genai. " + "'elevenlabs': ElevenLabs Conversational AI. " + "'deepgram': Deepgram Voice Agent API." 
+ ), + ) + # Run identifier run_id: str = Field( default_factory=current_date_and_time, diff --git a/src/eva/orchestrator/worker.py b/src/eva/orchestrator/worker.py index bf549d6b..5d9095c7 100644 --- a/src/eva/orchestrator/worker.py +++ b/src/eva/orchestrator/worker.py @@ -7,7 +7,7 @@ from pathlib import Path from typing import Any, Optional -from eva.assistant.server import AssistantServer +from eva.assistant.base_server import AbstractAssistantServer from eva.models.agents import AgentConfig from eva.models.config import RunConfig from eva.models.record import EvaluationRecord @@ -20,6 +20,38 @@ logger = get_logger(__name__) +def _get_server_class(framework: str) -> type[AbstractAssistantServer]: + """Return the server class for the given framework name. + + Uses lazy imports to avoid importing heavy dependencies (pipecat, openai, etc.) + unless the framework is actually selected. + """ + if framework == "pipecat": + from eva.assistant.pipecat_server import PipecatAssistantServer + + return PipecatAssistantServer + elif framework == "openai_realtime": + from eva.assistant.openai_realtime_server import OpenAIRealtimeAssistantServer + + return OpenAIRealtimeAssistantServer + elif framework == "gemini_live": + from eva.assistant.gemini_live_server import GeminiLiveAssistantServer + + return GeminiLiveAssistantServer + elif framework == "elevenlabs": + from eva.assistant.elevenlabs_server import ElevenLabsAssistantServer + + return ElevenLabsAssistantServer + elif framework == "deepgram": + from eva.assistant.deepgram_server import DeepgramAssistantServer + + return DeepgramAssistantServer + else: + raise ValueError( + f"Unknown framework: {framework!r}. Supported: pipecat, openai_realtime, gemini_live, elevenlabs, deepgram" + ) + + def _percentile(sorted_data: list[float], p: float) -> float: """Calculate the p-th percentile using the nearest-rank method. 
@@ -223,7 +255,7 @@ async def run(self) -> ConversationResult: transcript_path=str(self.output_dir / "transcript.jsonl"), audit_log_path=str(self.output_dir / "audit_log.json"), conversation_log_path=str(self.output_dir / "logs.log"), - pipecat_logs_path=str(self.output_dir / "pipecat_logs.jsonl"), + pipecat_logs_path=self._resolve_framework_logs_path(), elevenlabs_logs_path=str(self.output_dir / "elevenlabs_events.jsonl"), num_turns=self._conversation_stats.get("num_turns", 0), num_tool_calls=self._conversation_stats.get("num_tool_calls", 0), @@ -234,8 +266,9 @@ async def run(self) -> ConversationResult: ) async def _start_assistant(self) -> None: - """Start the assistant server.""" - self._assistant_server = AssistantServer( + """Start the assistant server using the configured framework.""" + server_cls = _get_server_class(self.config.framework) + self._assistant_server = server_cls( current_date_time=self.record.current_date_time, pipeline_config=self.config.model, agent=self.agent, @@ -276,6 +309,14 @@ async def _run_conversation(self) -> str: return ended_reason + def _resolve_framework_logs_path(self) -> str: + """Resolve the framework/pipecat logs path, preferring framework_logs.jsonl.""" + framework_path = self.output_dir / "framework_logs.jsonl" + pipecat_path = self.output_dir / "pipecat_logs.jsonl" + if framework_path.exists(): + return str(framework_path) + return str(pipecat_path) + async def _cleanup(self) -> None: """Clean up resources.""" if self._assistant_server: diff --git a/src/eva/user_simulator/client.py b/src/eva/user_simulator/client.py index 51946a48..380c1277 100644 --- a/src/eva/user_simulator/client.py +++ b/src/eva/user_simulator/client.py @@ -231,6 +231,11 @@ async def _run_elevenlabs_conversation(self, api_key: str) -> str: except Exception as e: logger.warning(f"Failed to check conversation history for end_call: {e}") + try: + await self._fetch_elevenlabs_audio(conversation_id) + except Exception as e: + logger.warning(f"Failed to 
fetch ElevenLabs server audio: {e}") + self.event_logger.log_connection_state("session_ended", {"reason": self._end_reason}) except Exception as e: @@ -295,6 +300,27 @@ async def _check_end_call_via_api(self, conversation_id: str) -> bool: logger.warning(f"Conversation transcript still empty after {max_attempts} attempts") return False + async def _fetch_elevenlabs_audio(self, conversation_id: str) -> None: + max_attempts = 5 + delay = 2.0 + + for attempt in range(max_attempts): + try: + audio_iter = self._client.conversational_ai.conversations.audio.get(conversation_id) + audio_path = self.output_dir / "elevenlabs_audio_recording.mp3" + with open(audio_path, "wb") as f: + for chunk in audio_iter: + f.write(chunk) + logger.info(f"Saved ElevenLabs server-side audio to {audio_path}") + return + except Exception as e: + if attempt < max_attempts - 1: + logger.debug(f"Audio not yet available (attempt {attempt + 1}/{max_attempts}): {e}") + await asyncio.sleep(delay) + delay = min(delay * 2, 10.0) + else: + logger.warning(f"Failed to fetch ElevenLabs server audio after {max_attempts} attempts: {e}") + def _reset_keepalive_counter(self) -> None: """Reset the consecutive keep-alive counter on user/agent activity.""" self._consecutive_keepalive_count = 0 @@ -388,7 +414,7 @@ def _on_assistant_speaks(self, transcript: str) -> None: transcript: The text that the assistant said """ self._reset_keepalive_counter() - logger.info(f"🤖 Assistant (Pipecat): {transcript}") + logger.info(f"🤖 Assistant: {transcript}") self.event_logger.log_event( "assistant_speech", diff --git a/src/eva/utils/conversation_checks.py b/src/eva/utils/conversation_checks.py index d284549e..d77993de 100644 --- a/src/eva/utils/conversation_checks.py +++ b/src/eva/utils/conversation_checks.py @@ -63,7 +63,9 @@ def find_records_with_llm_generic_error(output_dir: Path, record_ids: set[str] | """Find records that have the LLM generic error message in pipecat_logs.jsonl.""" affected = [] for record_id in 
record_ids: - pipecat_logs_path = output_dir / "records" / record_id / "pipecat_logs.jsonl" + pipecat_logs_path = output_dir / "records" / record_id / "framework_logs.jsonl" + if not pipecat_logs_path.exists(): + pipecat_logs_path = output_dir / "records" / record_id / "pipecat_logs.jsonl" if not pipecat_logs_path.exists(): continue with open(pipecat_logs_path) as f: diff --git a/tests/unit/assistant/test_audio_bridge.py b/tests/unit/assistant/test_audio_bridge.py new file mode 100644 index 00000000..117c7a22 --- /dev/null +++ b/tests/unit/assistant/test_audio_bridge.py @@ -0,0 +1,104 @@ +"""Tests for shared audio bridge utilities. + +Covers: PCM↔mulaw round-trip fidelity, PCM16 mixing with clipping, +and Twilio WebSocket protocol message round-trips. +""" + +import audioop +import json +import math +import struct + +import pytest + +from eva.assistant.audio_bridge import ( + create_twilio_media_message, + mulaw_8k_to_pcm16_16k, + mulaw_8k_to_pcm16_24k, + parse_twilio_media_message, + pcm16_16k_to_mulaw_8k, + pcm16_24k_to_mulaw_8k, + pcm16_mix, +) + + +def _generate_mulaw_tone(freq_hz: int = 440, duration_ms: int = 100) -> bytes: + sample_rate = 8000 + n_samples = sample_rate * duration_ms // 1000 + pcm_samples = [int(16000 * math.sin(2 * math.pi * freq_hz * i / sample_rate)) for i in range(n_samples)] + pcm_bytes = struct.pack(f"<{n_samples}h", *pcm_samples) + return audioop.lin2ulaw(pcm_bytes, 2) + + +def _rms(pcm_bytes: bytes) -> float: + n = len(pcm_bytes) // 2 + if n == 0: + return 0.0 + samples = struct.unpack(f"<{n}h", pcm_bytes) + return math.sqrt(sum(s * s for s in samples) / n) + + +class TestAudioConversionRoundTrip: + def test_mulaw_8k_pcm16_24k_round_trip(self): + """Mulaw 8k -> pcm16 24k -> mulaw 8k preserves signal energy.""" + original = _generate_mulaw_tone(440, 100) + + pcm_24k = mulaw_8k_to_pcm16_24k(original) + recovered = pcm16_24k_to_mulaw_8k(pcm_24k) + + assert len(recovered) == len(original) + + orig_pcm = audioop.ulaw2lin(original, 2) + 
recov_pcm = audioop.ulaw2lin(recovered, 2) + orig_rms = _rms(orig_pcm) + recov_rms = _rms(recov_pcm) + assert orig_rms > 0 + assert recov_rms / orig_rms == pytest.approx(1.0, abs=0.15) + + def test_mulaw_8k_pcm16_16k_round_trip(self): + """Mulaw 8k -> pcm16 16k -> mulaw 8k preserves signal energy.""" + original = _generate_mulaw_tone(440, 100) + + pcm_16k = mulaw_8k_to_pcm16_16k(original) + recovered = pcm16_16k_to_mulaw_8k(pcm_16k) + + assert len(recovered) == len(original) + + orig_pcm = audioop.ulaw2lin(original, 2) + recov_pcm = audioop.ulaw2lin(recovered, 2) + orig_rms = _rms(orig_pcm) + recov_rms = _rms(recov_pcm) + assert orig_rms > 0 + assert recov_rms / orig_rms == pytest.approx(1.0, abs=0.15) + + +class TestPcm16Mix: + def test_adds_samples_and_clips_at_int16_boundaries(self): + """Sample-wise addition with clipping; shorter track is zero-padded.""" + track_a = struct.pack("<2h", 30000, -30000) + track_b = struct.pack("<2h", 10000, -10000) + + mixed = pcm16_mix(track_a, track_b) + result = struct.unpack("<2h", mixed) + assert result == (32767, -32768) + + short_track = struct.pack("<1h", 5000) + long_track = struct.pack("<2h", 100, 200) + mixed = pcm16_mix(short_track, long_track) + result = struct.unpack("<2h", mixed) + assert result == (5100, 200) + + +class TestTwilioProtocol: + def test_create_and_parse_round_trip(self): + """create_twilio_media_message -> parse_twilio_media_message recovers bytes.""" + audio = b"\x80\x90\xa0\xb0\xc0" + msg = create_twilio_media_message("stream-1", audio) + recovered = parse_twilio_media_message(msg) + assert recovered == audio + + parsed = json.loads(msg) + assert parsed["streamSid"] == "stream-1" + + assert parse_twilio_media_message(json.dumps({"event": "start"})) is None + assert parse_twilio_media_message("not json at all {{{") is None diff --git a/tests/unit/assistant/test_server.py b/tests/unit/assistant/test_pipecat_server.py similarity index 96% rename from tests/unit/assistant/test_server.py rename to 
tests/unit/assistant/test_pipecat_server.py index 6b2a9f14..effdae64 100644 --- a/tests/unit/assistant/test_server.py +++ b/tests/unit/assistant/test_pipecat_server.py @@ -1,4 +1,4 @@ -"""Tests for AssistantServer.""" +"""Tests for PipecatAssistantServer.""" import asyncio import json @@ -9,12 +9,12 @@ import pytest from eva.assistant.agentic.audit_log import AuditLog -from eva.assistant.server import SAMPLE_RATE, AssistantServer +from eva.assistant.pipecat_server import SAMPLE_RATE, PipecatAssistantServer def _make_server(tmp_path: Path): - """Build a lightweight AssistantServer without invoking __init__ (avoids Pipecat I/O).""" - srv = object.__new__(AssistantServer) + """Build a lightweight PipecatAssistantServer without invoking __init__ (avoids Pipecat I/O).""" + srv = object.__new__(PipecatAssistantServer) srv.output_dir = tmp_path srv.audit_log = AuditLog() srv.agentic_system = None @@ -157,7 +157,7 @@ async def test_saves_audit_log_and_both_scenario_db_snapshots(self, tmp_path): # Add an entry so audit_log is non-trivial srv.audit_log.append_user_input("Hello") - await srv._save_outputs() + await srv.save_outputs() # Audit log contains our entry audit = json.loads((tmp_path / "audit_log.json").read_text()) @@ -181,7 +181,7 @@ async def test_saves_agent_perf_stats_when_agentic_system_present(self, tmp_path mock_system = MagicMock() srv.agentic_system = mock_system - await srv._save_outputs() + await srv.save_outputs() mock_system.save_agent_perf_stats.assert_called_once()