|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | import asyncio |
| 4 | +from unittest import mock |
4 | 5 |
|
5 | 6 | import pytest |
6 | 7 |
|
@@ -65,6 +66,29 @@ async def on_user_turn_completed(self, turn_ctx: ChatContext, new_message: ChatM |
65 | 66 | await asyncio.sleep(self.on_user_turn_completed_delay) |
66 | 67 |
|
67 | 68 |
|
| 69 | +class HandoffTargetAgent(Agent): |
| 70 | + def __init__(self, entered_event: asyncio.Event) -> None: |
| 71 | + super().__init__(instructions=("You are the target handoff agent.")) |
| 72 | + self._entered_event = entered_event |
| 73 | + |
| 74 | + async def on_enter(self) -> None: |
| 75 | + self._entered_event.set() |
| 76 | + |
| 77 | + |
| 78 | +class HandoffSourceAgent(Agent): |
| 79 | + def __init__(self, entered_event: asyncio.Event) -> None: |
| 80 | + super().__init__(instructions=("You are a source agent that can hand off.")) |
| 81 | + self._entered_event = entered_event |
| 82 | + |
| 83 | + @function_tool |
| 84 | + async def switch_to_secondary(self) -> Agent: |
| 85 | + return HandoffTargetAgent(self._entered_event) |
| 86 | + |
| 87 | + @function_tool |
| 88 | + async def save_data(self, value: str) -> str: |
| 89 | + return f"saved:{value}" |
| 90 | + |
| 91 | + |
68 | 92 | SESSION_TIMEOUT = 60.0 |
69 | 93 |
|
70 | 94 |
|
@@ -215,6 +239,36 @@ async def test_tool_call() -> None: |
215 | 239 | assert chat_ctx_items[6].text_content == "The weather in Tokyo is sunny today." |
216 | 240 |
|
217 | 241 |
|
| 242 | +async def test_handoff_and_reply_required_no_extra_old_agent_reply() -> None: |
| 243 | + speed = 5.0 |
| 244 | + actions = FakeActions() |
| 245 | + actions.add_user_speech(0.5, 2.0, "switch") |
| 246 | + actions.add_llm( |
| 247 | + content="", |
| 248 | + tool_calls=[ |
| 249 | + FunctionToolCall(name="save_data", arguments='{"value": "x"}', call_id="1"), |
| 250 | + FunctionToolCall(name="switch_to_secondary", arguments="{}", call_id="2"), |
| 251 | + ], |
| 252 | + ) |
| 253 | + |
| 254 | + handoff_entered = asyncio.Event() |
| 255 | + session = create_session(actions, speed_factor=speed) |
| 256 | + agent = HandoffSourceAgent(handoff_entered) |
| 257 | + |
| 258 | + tool_executed_events: list[FunctionToolsExecutedEvent] = [] |
| 259 | + session.on("function_tools_executed", tool_executed_events.append) |
| 260 | + |
| 261 | + with mock.patch.object(session.llm, "chat", wraps=session.llm.chat) as mock_chat: |
| 262 | + await asyncio.wait_for(run_session(session, agent), timeout=SESSION_TIMEOUT) |
| 263 | + |
| 264 | + assert handoff_entered.is_set() |
| 265 | + assert len(tool_executed_events) == 1 |
| 266 | + assert tool_executed_events[0].has_agent_handoff is True |
| 267 | + assert tool_executed_events[0].has_tool_reply is True |
| 268 | + # No extra old-agent reply generation after handoff. |
| 269 | + assert mock_chat.call_count == 1 |
| 270 | + |
| 271 | + |
218 | 272 | @pytest.mark.parametrize( |
219 | 273 | "resume_false_interruption, expected_interruption_time", |
220 | 274 | [ |
|
0 commit comments