Skip to content

Commit 573e483

Browse files
committed
improve the concurrency of event handling
1 parent 7f1672a commit 573e483

File tree

2 files changed

+116
-7
lines changed

2 files changed

+116
-7
lines changed

src/agents/agent.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, cast
99

1010
from openai.types.responses.response_prompt_param import ResponsePromptParam
11-
from typing_extensions import NotRequired, TypeAlias, TypedDict
11+
from typing_extensions import NotRequired, TypedDict
1212

1313
from .agent_output import AgentOutputSchemaBase
1414
from .guardrail import InputGuardrail, OutputGuardrail
@@ -457,12 +457,11 @@ async def run_agent(context: RunContextWrapper, input: str) -> Any:
457457
conversation_id=conversation_id,
458458
session=session,
459459
)
460-
async for event in run_result.stream_events():
461-
payload: AgentToolStreamEvent = {
462-
"event": event,
463-
"agent": self,
464-
"tool_call": getattr(context, "tool_call", None),
465-
}
460+
# Dispatch streaming callbacks in the background so slow handlers do not block event consumption.
461+
event_queue: asyncio.Queue[AgentToolStreamEvent | None] = asyncio.Queue()
462+
463+
async def _run_handler(payload: AgentToolStreamEvent) -> None:
464+
"""Execute the user callback while capturing exceptions."""
466465
try:
467466
maybe_result = on_stream(payload)
468467
if inspect.isawaitable(maybe_result):
@@ -472,6 +471,34 @@ async def run_agent(context: RunContextWrapper, input: str) -> Any:
472471
"Error while handling on_stream event for agent tool %s.",
473472
self.name,
474473
)
474+
475+
async def dispatch_stream_events() -> None:
476+
while True:
477+
payload = await event_queue.get()
478+
is_sentinel = payload is None # None marks the end of the stream.
479+
try:
480+
if payload is not None:
481+
await _run_handler(payload)
482+
finally:
483+
event_queue.task_done()
484+
485+
if is_sentinel:
486+
break
487+
488+
dispatch_task = asyncio.create_task(dispatch_stream_events())
489+
490+
try:
491+
async for event in run_result.stream_events():
492+
payload: AgentToolStreamEvent = {
493+
"event": event,
494+
"agent": self,
495+
"tool_call": getattr(context, "tool_call", None),
496+
}
497+
await event_queue.put(payload)
498+
finally:
499+
await event_queue.put(None)
500+
await event_queue.join()
501+
await dispatch_task
475502
else:
476503
run_result = await Runner.run(
477504
starting_agent=self,

tests/test_agent_as_tool.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import asyncio
34
from typing import Any, cast
45

56
import pytest
@@ -612,6 +613,87 @@ def sync_handler(event: AgentToolStreamEvent) -> None:
612613
assert calls == ["raw_response_event"]
613614

614615

616+
@pytest.mark.asyncio
617+
async def test_agent_as_tool_streaming_dispatches_without_blocking(
618+
monkeypatch: pytest.MonkeyPatch,
619+
) -> None:
620+
"""on_stream handlers should not block streaming iteration."""
621+
agent = Agent(name="nonblocking_agent")
622+
623+
first_handler_started = asyncio.Event()
624+
allow_handler_to_continue = asyncio.Event()
625+
second_event_yielded = asyncio.Event()
626+
second_event_handled = asyncio.Event()
627+
628+
first_event = RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"}))
629+
second_event = RawResponsesStreamEvent(
630+
data=cast(Any, {"type": "output_text_delta", "delta": "hi"})
631+
)
632+
633+
class DummyStreamingResult:
634+
def __init__(self) -> None:
635+
self.final_output = "ok"
636+
637+
async def stream_events(self):
638+
yield first_event
639+
second_event_yielded.set()
640+
yield second_event
641+
642+
dummy_result = DummyStreamingResult()
643+
644+
monkeypatch.setattr(
645+
Runner, "run_streamed", classmethod(lambda *args, **kwargs: dummy_result)
646+
)
647+
monkeypatch.setattr(
648+
Runner,
649+
"run",
650+
classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no run"))),
651+
)
652+
653+
async def on_stream(payload: AgentToolStreamEvent) -> None:
654+
if payload["event"] is first_event:
655+
first_handler_started.set()
656+
await allow_handler_to_continue.wait()
657+
else:
658+
second_event_handled.set()
659+
660+
tool_call = ResponseFunctionToolCall(
661+
id="call_nonblocking",
662+
arguments='{"input": "go"}',
663+
call_id="call-nonblocking",
664+
name="nonblocking_tool",
665+
type="function_call",
666+
)
667+
668+
tool = cast(
669+
FunctionTool,
670+
agent.as_tool(
671+
tool_name="nonblocking_tool",
672+
tool_description="Uses non-blocking streaming handler",
673+
on_stream=on_stream,
674+
),
675+
)
676+
tool_context = ToolContext(
677+
context=None,
678+
tool_name="nonblocking_tool",
679+
tool_call_id=tool_call.call_id,
680+
tool_arguments=tool_call.arguments,
681+
tool_call=tool_call,
682+
)
683+
684+
invoke_task = asyncio.create_task(tool.on_invoke_tool(tool_context, '{"input": "go"}'))
685+
686+
await asyncio.wait_for(first_handler_started.wait(), timeout=1.0)
687+
await asyncio.wait_for(second_event_yielded.wait(), timeout=1.0)
688+
assert invoke_task.done() is False
689+
690+
allow_handler_to_continue.set()
691+
await asyncio.wait_for(second_event_handled.wait(), timeout=1.0)
692+
output = await asyncio.wait_for(invoke_task, timeout=1.0)
693+
694+
assert output == "ok"
695+
696+
615697
@pytest.mark.asyncio
616698
async def test_agent_as_tool_streaming_handler_exception_does_not_fail_call(
617699
monkeypatch: pytest.MonkeyPatch,

0 commit comments

Comments
 (0)