|
25 | 25 | EasyInputMessageParam, |
26 | 26 | ResponseFunctionToolCallParam, |
27 | 27 | ResponseInputContentParam, |
| 28 | + ResponseInputImageParam, |
28 | 29 | ResponseInputMessageContentListParam, |
29 | 30 | ResponseInputTextParam, |
30 | 31 | ResponseOutputText, |
|
55 | 56 | DurationSummary, |
56 | 57 | EndOfTurnItem, |
57 | 58 | FileSource, |
| 59 | + GeneratedImage, |
| 60 | + GeneratedImageItem, |
| 61 | + GeneratedImageUpdated, |
58 | 62 | HiddenContextItem, |
59 | 63 | SDKHiddenContextItem, |
60 | 64 | Task, |
@@ -105,6 +109,7 @@ class AgentContext(BaseModel, Generic[TContext]): |
105 | 109 | previous_response_id: str | None = None |
106 | 110 | client_tool_call: ClientToolCall | None = None |
107 | 111 | workflow_item: WorkflowItem | None = None |
| 112 | + generated_image_item: GeneratedImageItem | None = None |
108 | 113 | _events: asyncio.Queue[ThreadStreamEvent | _QueueCompleteSentinel] = asyncio.Queue() |
109 | 114 |
|
110 | 115 | def generate_id( |
@@ -356,10 +361,97 @@ class StreamingThoughtTracker(BaseModel): |
356 | 361 | task: ThoughtTask |
357 | 362 |
|
358 | 363 |
|
| 364 | +class ResponseStreamConverter: |
| 365 | + """Used by `stream_agent_response` to convert streamed Agents SDK output |
| 366 | + into values used by ChatKit thread items and thread stream events. |
| 367 | +
|
| 368 | + Defines overridable methods for adapting streamed data (such as image |
| 369 | + generation results and partial updates) into the forms expected by ChatKit. |
| 370 | + """ |
| 371 | + |
| 372 | + partial_images: int | None = None |
| 373 | + """ |
| 374 | + The expected number of partial image updates for an image generation result. |
| 375 | +
|
| 376 | + When set, this value is used to normalize partial image indices into a |
| 377 | + progress value in the range [0, 1]. If unset, all partial image updates are |
| 378 | + assigned a progress value of 0. |
| 379 | + """ |
| 380 | + |
| 381 | + def __init__(self, *, partial_images: int | None = None): |
| 382 | + """ |
| 383 | + Args: |
| 384 | + partial_images: The expected number of partial image updates for image |
| 385 | + generation results, or None if no progress normalization should |
| 386 | + be performed. |
| 387 | + """ |
| 388 | + self.partial_images = partial_images |
| 389 | + |
| 390 | + async def base64_image_to_url( |
| 391 | + self, |
| 392 | + image_id: str, |
| 393 | + base64_image: str, |
| 394 | + partial_image_index: int | None = None, |
| 395 | + ) -> str: |
| 396 | + """ |
| 397 | + Convert a base64-encoded image into a URL. |
| 398 | +
|
| 399 | + This method is used to produce the URL stored on thread items for image |
| 400 | + generation results. |
| 401 | +
|
| 402 | + Args: |
| 403 | + image_id: The ID of the image generation call. This stays stable across partial image updates. |
| 404 | + base64_image: The base64-encoded image. |
| 405 | + partial_image_index: The index of the partial image update, starting from 0. |
| 406 | +
|
| 407 | + Returns: |
| 408 | + A URL string. |
| 409 | + """ |
| 410 | + return f"data:image/png;base64,{base64_image}" |
| 411 | + |
| 412 | + def partial_image_index_to_progress(self, partial_image_index: int) -> float: |
| 413 | + """ |
| 414 | + Convert a partial image index into a normalized progress value. |
| 415 | +
|
| 416 | + Args: |
| 417 | + partial_image_index: The index of the partial image update, starting from 0. |
| 418 | +
|
| 419 | + Returns: |
| 420 | + A float between 0 and 1 representing progress for the image |
| 421 | + generation result. |
| 422 | + """ |
| 423 | + if self.partial_images is None or self.partial_images <= 0: |
| 424 | + return 0.0 |
| 425 | + |
| 426 | + return min(1.0, partial_image_index / self.partial_images) |
| 427 | + |
| 428 | + |
| 429 | +_DEFAULT_RESPONSE_STREAM_CONVERTER = ResponseStreamConverter() |
| 430 | + |
| 431 | + |
359 | 432 | async def stream_agent_response( |
360 | | - context: AgentContext, result: RunResultStreaming |
| 433 | + context: AgentContext, |
| 434 | + result: RunResultStreaming, |
| 435 | + *, |
| 436 | + converter: ResponseStreamConverter = _DEFAULT_RESPONSE_STREAM_CONVERTER, |
361 | 437 | ) -> AsyncIterator[ThreadStreamEvent]: |
362 | | - """Convert a streamed Agents SDK run into ChatKit ThreadStreamEvents.""" |
| 438 | + """ |
| 439 | + Convert a streamed Agents SDK run into ChatKit thread stream events. |
| 440 | +
|
| 441 | + This function consumes a streaming run result and yields `ThreadStreamEvent` |
| 442 | + objects as the run progresses. |
| 443 | +
|
| 444 | + Args: |
| 445 | + context: The AgentContext to use for the stream. |
| 446 | + result: The RunResultStreaming to convert. |
| 447 | + converter: Controls how streamed image generation output |
| 448 | + is converted into URLs and progress updates. The default converter embeds the |
| 449 | + generated base64 image in a data URL and assigns a progress value of 0 to all |
| 450 | + partial image updates. |
| 451 | +
|
| 452 | + Returns: |
| 453 | + An async iterator that yields thread stream events representing the run result. |
| 454 | + """ |
363 | 455 | current_item_id = None |
364 | 456 | current_tool_call = None |
365 | 457 | ctx = context |
@@ -527,6 +619,38 @@ def end_workflow(item: WorkflowItem): |
527 | 619 | created_at=datetime.now(), |
528 | 620 | ), |
529 | 621 | ) |
| 622 | + elif item.type == "image_generation_call": |
| 623 | + ctx.generated_image_item = GeneratedImageItem( |
| 624 | + id=ctx.generate_id("message"), |
| 625 | + thread_id=thread.id, |
| 626 | + created_at=datetime.now(), |
| 627 | + image=None, |
| 628 | + ) |
| 629 | + produced_items.add(ctx.generated_image_item.id) |
| 630 | + yield ThreadItemAddedEvent(item=ctx.generated_image_item) |
| 631 | + elif event.type == "response.image_generation_call.partial_image": |
| 632 | + if not ctx.generated_image_item: |
| 633 | + continue |
| 634 | + |
| 635 | + url = await converter.base64_image_to_url( |
| 636 | + image_id=event.item_id, |
| 637 | + base64_image=event.partial_image_b64, |
| 638 | + partial_image_index=event.partial_image_index, |
| 639 | + ) |
| 640 | + progress = converter.partial_image_index_to_progress( |
| 641 | + event.partial_image_index |
| 642 | + ) |
| 643 | + |
| 644 | + ctx.generated_image_item.image = GeneratedImage( |
| 645 | + id=event.item_id, url=url |
| 646 | + ) |
| 647 | + |
| 648 | + yield ThreadItemUpdatedEvent( |
| 649 | + item_id=ctx.generated_image_item.id, |
| 650 | + update=GeneratedImageUpdated( |
| 651 | + image=ctx.generated_image_item.image, progress=progress |
| 652 | + ), |
| 653 | + ) |
530 | 654 | elif event.type == "response.reasoning_summary_text.delta": |
531 | 655 | if not ctx.workflow_item: |
532 | 656 | continue |
@@ -604,6 +728,20 @@ def end_workflow(item: WorkflowItem): |
604 | 728 | created_at=datetime.now(), |
605 | 729 | ), |
606 | 730 | ) |
| 731 | + elif item.type == "image_generation_call" and item.result: |
| 732 | + if not ctx.generated_image_item: |
| 733 | + continue |
| 734 | + |
| 735 | + url = await converter.base64_image_to_url( |
| 736 | + image_id=item.id, |
| 737 | + base64_image=item.result, |
| 738 | + ) |
| 739 | + image = GeneratedImage(id=item.id, url=url) |
| 740 | + |
| 741 | + ctx.generated_image_item.image = image |
| 742 | + yield ThreadItemDoneEvent(item=ctx.generated_image_item) |
| 743 | + |
| 744 | + ctx.generated_image_item = None |
607 | 745 |
|
608 | 746 | except (InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered): |
609 | 747 | for item_id in produced_items: |
@@ -694,6 +832,33 @@ async def tag_to_message_content( |
694 | 832 | "A Tag was included in a UserMessageItem but Converter.tag_to_message_content is not implemented" |
695 | 833 | ) |
696 | 834 |
|
async def generated_image_to_input(
    self, item: GeneratedImageItem
) -> TResponseInputItem | list[TResponseInputItem] | None:
    """
    Turn a GeneratedImageItem into model input.

    Returns None when the item has no image yet. Otherwise returns a single
    user-role message holding a short explanatory text part followed by the
    image, referenced by its URL.

    Override this method to customize the conversion of generated images, such
    as when your generated image url is not publicly reachable.
    """
    generated = item.image
    if not generated:
        return None

    # Text part first so the model knows where the image that follows came from.
    caption = ResponseInputTextParam(
        type="input_text",
        text="The following image was generated by the agent.",
    )
    picture = ResponseInputImageParam(
        type="input_image",
        detail="auto",
        image_url=generated.url,
    )
    return Message(type="message", content=[caption, picture], role="user")
| 861 | + |
697 | 862 | async def hidden_context_to_input( |
698 | 863 | self, item: HiddenContextItem |
699 | 864 | ) -> TResponseInputItem | list[TResponseInputItem] | None: |
@@ -984,6 +1149,9 @@ async def _thread_item_to_input_item( |
984 | 1149 | case SDKHiddenContextItem(): |
985 | 1150 | out = await self.sdk_hidden_context_to_input(item) or [] |
986 | 1151 | return out if isinstance(out, list) else [out] |
| 1152 | + case GeneratedImageItem(): |
| 1153 | + out = await self.generated_image_to_input(item) or [] |
| 1154 | + return out if isinstance(out, list) else [out] |
987 | 1155 | case _: |
988 | 1156 | assert_never(item) |
989 | 1157 |
|
|
0 commit comments