Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions livekit-agents/livekit/agents/voice/agent_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2467,7 +2467,8 @@ def _tool_execution_completed_cb(out: ToolExecutionOutput) -> None:
ignore_task_switch = True
# TODO(long): should we mark the function call as failed to notify the LLM?

new_agent_task = sanitized_out.agent_task
if sanitized_out.agent_task is not None:
new_agent_task = sanitized_out.agent_task

if new_agent_task and not ignore_task_switch:
fnc_executed_ev._handoff_required = True
Expand All @@ -2480,7 +2481,7 @@ def _tool_execution_completed_cb(out: ToolExecutionOutput) -> None:
draining = True

tool_messages = new_calls + new_fnc_outputs
if fnc_executed_ev._reply_required:
if fnc_executed_ev._reply_required and not fnc_executed_ev._handoff_required:
chat_ctx.items.extend(tool_messages)

# refresh instructions in chat_ctx so that any update_instructions()
Expand Down Expand Up @@ -2992,7 +2993,8 @@ def _create_assistant_message(
)
ignore_task_switch = True

new_agent_task = sanitized_out.agent_task
if sanitized_out.agent_task is not None:
new_agent_task = sanitized_out.agent_task

if new_agent_task and not ignore_task_switch:
fnc_executed_ev._handoff_required = True
Expand Down Expand Up @@ -3029,6 +3031,7 @@ def _create_assistant_message(

if (
fnc_executed_ev._reply_required
and not fnc_executed_ev._handoff_required
and not self.llm.capabilities.auto_tool_reply_generation
):
self._rt_session.interrupt()
Expand Down
54 changes: 54 additions & 0 deletions tests/test_agent_session.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
from unittest import mock

import pytest

Expand Down Expand Up @@ -65,6 +66,29 @@ async def on_user_turn_completed(self, turn_ctx: ChatContext, new_message: ChatM
await asyncio.sleep(self.on_user_turn_completed_delay)


class HandoffTargetAgent(Agent):
def __init__(self, entered_event: asyncio.Event) -> None:
super().__init__(instructions=("You are the target handoff agent."))
self._entered_event = entered_event

async def on_enter(self) -> None:
self._entered_event.set()


class HandoffSourceAgent(Agent):
def __init__(self, entered_event: asyncio.Event) -> None:
super().__init__(instructions=("You are a source agent that can hand off."))
self._entered_event = entered_event

@function_tool
async def switch_to_secondary(self) -> Agent:
return HandoffTargetAgent(self._entered_event)

@function_tool
async def save_data(self, value: str) -> str:
return f"saved:{value}"


SESSION_TIMEOUT = 60.0


Expand Down Expand Up @@ -215,6 +239,36 @@ async def test_tool_call() -> None:
assert chat_ctx_items[6].text_content == "The weather in Tokyo is sunny today."


async def test_handoff_and_reply_required_no_extra_old_agent_reply() -> None:
speed = 5.0
actions = FakeActions()
actions.add_user_speech(0.5, 2.0, "switch")
actions.add_llm(
content="",
tool_calls=[
FunctionToolCall(name="save_data", arguments='{"value": "x"}', call_id="1"),
FunctionToolCall(name="switch_to_secondary", arguments="{}", call_id="2"),
],
)

handoff_entered = asyncio.Event()
session = create_session(actions, speed_factor=speed)
agent = HandoffSourceAgent(handoff_entered)

tool_executed_events: list[FunctionToolsExecutedEvent] = []
session.on("function_tools_executed", tool_executed_events.append)

with mock.patch.object(session.llm, "chat", wraps=session.llm.chat) as mock_chat:
await asyncio.wait_for(run_session(session, agent), timeout=SESSION_TIMEOUT)

assert handoff_entered.is_set()
assert len(tool_executed_events) == 1
assert tool_executed_events[0].has_agent_handoff is True
assert tool_executed_events[0].has_tool_reply is True
# No extra old-agent reply generation after handoff.
assert mock_chat.call_count == 1


@pytest.mark.parametrize(
"resume_false_interruption, expected_interruption_time",
[
Expand Down
Loading