|
8 | 8 | Annotated, |
9 | 9 | Any, |
10 | 10 | AsyncGenerator, |
| 11 | + Awaitable, |
| 12 | + Callable, |
11 | 13 | Generic, |
12 | 14 | Sequence, |
13 | 15 | TypeVar, |
|
55 | 57 | DurationSummary, |
56 | 58 | EndOfTurnItem, |
57 | 59 | FileSource, |
| 60 | + GeneratedImage, |
| 61 | + GeneratedImageItem, |
58 | 62 | HiddenContextItem, |
59 | 63 | SDKHiddenContextItem, |
60 | 64 | Task, |
@@ -105,6 +109,7 @@ class AgentContext(BaseModel, Generic[TContext]): |
105 | 109 | previous_response_id: str | None = None |
106 | 110 | client_tool_call: ClientToolCall | None = None |
107 | 111 | workflow_item: WorkflowItem | None = None |
| 112 | + generated_image_item: GeneratedImageItem | None = None |
108 | 113 | _events: asyncio.Queue[ThreadStreamEvent | _QueueCompleteSentinel] = asyncio.Queue() |
109 | 114 |
|
110 | 115 | def generate_id( |
@@ -357,7 +362,10 @@ class StreamingThoughtTracker(BaseModel): |
357 | 362 |
|
358 | 363 |
|
359 | 364 | async def stream_agent_response( |
360 | | - context: AgentContext, result: RunResultStreaming |
| 365 | + context: AgentContext, |
| 366 | + result: RunResultStreaming, |
| 367 | + *, |
| 368 | + base64_to_generated_image: Callable[[str], Awaitable[GeneratedImage]] | None = None, |
361 | 369 | ) -> AsyncIterator[ThreadStreamEvent]: |
362 | 370 | """Convert a streamed Agents SDK run into ChatKit ThreadStreamEvents.""" |
363 | 371 | current_item_id = None |
@@ -527,6 +535,15 @@ def end_workflow(item: WorkflowItem): |
527 | 535 | created_at=datetime.now(), |
528 | 536 | ), |
529 | 537 | ) |
| 538 | + elif item.type == "image_generation_call": |
| 539 | + ctx.generated_image_item = GeneratedImageItem( |
| 540 | + id=ctx.generate_id("message"), |
| 541 | + thread_id=thread.id, |
| 542 | + created_at=datetime.now(), |
| 543 | + image=None, |
| 544 | + ) |
| 545 | + produced_items.add(ctx.generated_image_item.id) |
| 546 | + yield ThreadItemAddedEvent(item=ctx.generated_image_item) |
530 | 547 | elif event.type == "response.reasoning_summary_text.delta": |
531 | 548 | if not ctx.workflow_item: |
532 | 549 | continue |
@@ -604,6 +621,22 @@ def end_workflow(item: WorkflowItem): |
604 | 621 | created_at=datetime.now(), |
605 | 622 | ), |
606 | 623 | ) |
| 624 | + elif item.type == "image_generation_call" and item.result: |
| 625 | + if not ctx.generated_image_item: |
| 626 | + continue |
| 627 | + |
| 628 | + # Agents SDK only produces png and base64 output currently. |
| 629 | + if base64_to_generated_image: |
| 630 | + image = await base64_to_generated_image(item.result) |
| 631 | + else: |
| 632 | + image = GeneratedImage( |
| 633 | + id=item.id, url=f"data:image/png;base64,{item.result}" |
| 634 | + ) |
| 635 | + |
| 636 | + ctx.generated_image_item.image = image |
| 637 | + yield ThreadItemDoneEvent(item=ctx.generated_image_item) |
| 638 | + |
| 639 | + ctx.generated_image_item = None |
607 | 640 |
|
608 | 641 | except (InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered): |
609 | 642 | for item_id in produced_items: |
@@ -694,6 +727,17 @@ async def tag_to_message_content( |
694 | 727 | "A Tag was included in a UserMessageItem but Converter.tag_to_message_content is not implemented" |
695 | 728 | ) |
696 | 729 |
|
| 730 | + async def generated_image_to_input( |
| 731 | + self, item: GeneratedImageItem |
| 732 | + ) -> TResponseInputItem | list[TResponseInputItem] | None: |
| 733 | + """ |
| 734 | + Convert a GeneratedImageItem into a message content part to send to the model. |
| 735 | + Required when generated images are enabled. |
| 736 | + """ |
| 737 | + raise NotImplementedError( |
| 738 | + "A GeneratedImageItem was included in a UserMessageItem but Converter.generated_image_to_message_content was not implemented" |
| 739 | + ) |
| 740 | + |
697 | 741 | async def hidden_context_to_input( |
698 | 742 | self, item: HiddenContextItem |
699 | 743 | ) -> TResponseInputItem | list[TResponseInputItem] | None: |
@@ -984,6 +1028,9 @@ async def _thread_item_to_input_item( |
984 | 1028 | case SDKHiddenContextItem(): |
985 | 1029 | out = await self.sdk_hidden_context_to_input(item) or [] |
986 | 1030 | return out if isinstance(out, list) else [out] |
| 1031 | + case GeneratedImageItem(): |
| 1032 | + out = await self.generated_image_to_input(item) or [] |
| 1033 | + return out if isinstance(out, list) else [out] |
987 | 1034 | case _: |
988 | 1035 | assert_never(item) |
989 | 1036 |
|
|
0 commit comments