|
55 | 55 | DurationSummary, |
56 | 56 | EndOfTurnItem, |
57 | 57 | FileSource, |
| 58 | + GeneratedImage, |
| 59 | + GeneratedImageItem, |
| 60 | + GeneratedImageUpdated, |
58 | 61 | HiddenContextItem, |
59 | 62 | SDKHiddenContextItem, |
60 | 63 | Task, |
@@ -105,6 +108,7 @@ class AgentContext(BaseModel, Generic[TContext]): |
105 | 108 | previous_response_id: str | None = None |
106 | 109 | client_tool_call: ClientToolCall | None = None |
107 | 110 | workflow_item: WorkflowItem | None = None |
| 111 | + generated_image_item: GeneratedImageItem | None = None |
108 | 112 | _events: asyncio.Queue[ThreadStreamEvent | _QueueCompleteSentinel] = asyncio.Queue() |
109 | 113 |
|
110 | 114 | def generate_id( |
@@ -356,10 +360,84 @@ class StreamingThoughtTracker(BaseModel): |
356 | 360 | task: ThoughtTask |
357 | 361 |
|
358 | 362 |
|
| 363 | +class ResponseStreamConverter: |
| 364 | + """Used by `stream_agent_response` to convert streamed Agents SDK output |
| 365 | + into values used by ChatKit thread items and thread stream events. |
| 366 | +
|
| 367 | + Defines overridable methods for adapting streamed data (such as image |
| 368 | + generation results and partial updates) into the forms expected by ChatKit. |
| 369 | + """ |
| 370 | + |
| 371 | + partial_images: int | None = None |
| 372 | + """ |
| 373 | + The expected number of partial image updates for an image generation result. |
| 374 | +
|
| 375 | + When set, this value is used to normalize partial image indices into a |
| 376 | + progress value in the range [0, 1]. If unset, all partial image updates are |
| 377 | + assigned a progress value of 0. |
| 378 | + """ |
| 379 | + |
| 380 | + def __init__(self, partial_images: int | None = None): |
| 381 | + """ |
| 382 | + Args: |
| 383 | + partial_images: The expected number of partial image updates for image |
| 384 | + generation results, or None if no progress normalization should |
| 385 | + be performed. |
| 386 | + """ |
| 387 | + self.partial_images = partial_images |
| 388 | + |
| 389 | + async def base64_image_to_url(self, base64_image: str) -> str: |
| 390 | + """ |
| 391 | + Convert a base64-encoded image into a URL. |
| 392 | +
|
| 393 | + This method is used to produce the URL stored on thread items for image |
| 394 | + generation results. |
| 395 | + """ |
| 396 | + return f"data:image/png;base64,{base64_image}" |
| 397 | + |
| 398 | + def partial_image_index_to_progress(self, partial_image_index: int) -> float: |
| 399 | + """ |
| 400 | + Convert a partial image index into a normalized progress value. |
| 401 | +
|
| 402 | + Args: |
| 403 | + partial_image_index: The index of the partial image update, starting from 0. |
| 404 | +
|
| 405 | + Returns: |
| 406 | + A float between 0 and 1 representing progress for the image |
| 407 | + generation result. |
| 408 | + """ |
| 409 | + if self.partial_images is None: |
| 410 | + return 0 |
| 411 | + |
| 412 | + return partial_image_index / self.partial_images |
| 413 | + |
| 414 | + |
| 415 | +_DEFAULT_RESPONSE_STREAM_CONVERTER = ResponseStreamConverter() |
| 416 | + |
| 417 | + |
359 | 418 | async def stream_agent_response( |
360 | | - context: AgentContext, result: RunResultStreaming |
| 419 | + context: AgentContext, |
| 420 | + result: RunResultStreaming, |
| 421 | + *, |
| 422 | + converter: ResponseStreamConverter = _DEFAULT_RESPONSE_STREAM_CONVERTER, |
361 | 423 | ) -> AsyncIterator[ThreadStreamEvent]: |
362 | | - """Convert a streamed Agents SDK run into ChatKit ThreadStreamEvents.""" |
| 424 | + """ |
| 425 | + Convert a streamed Agents SDK run into ChatKit thread stream events. |
| 426 | +
|
| 427 | + This function consumes a streaming run result and yields `ThreadStreamEvent` |
| 428 | + objects as the run progresses. |
| 429 | +
|
| 430 | + Args: |
| 431 | + context: The AgentContext to use for the stream. |
| 432 | + result: The RunResultStreaming to convert. |
| 433 | + converter: Controls how streamed image generation output |
| 434 | + is converted into URLs and progress updates. The default converter stores the |
| 435 | + generated base64 image and assigns a progress value of 0 to all partial image |
| 436 | + updates. |
| 437 | +
|
| 438 | + Returns: |
| 439 | + An async iterator that yields thread stream events representing the run result. |
| 440 | + """ |
363 | 441 | current_item_id = None |
364 | 442 | current_tool_call = None |
365 | 443 | ctx = context |
@@ -527,6 +605,15 @@ def end_workflow(item: WorkflowItem): |
527 | 605 | created_at=datetime.now(), |
528 | 606 | ), |
529 | 607 | ) |
| 608 | + elif item.type == "image_generation_call": |
| 609 | + ctx.generated_image_item = GeneratedImageItem( |
| 610 | + id=ctx.generate_id("message"), |
| 611 | + thread_id=thread.id, |
| 612 | + created_at=datetime.now(), |
| 613 | + image=None, |
| 614 | + ) |
| 615 | + produced_items.add(ctx.generated_image_item.id) |
| 616 | + yield ThreadItemAddedEvent(item=ctx.generated_image_item) |
530 | 617 | elif event.type == "response.reasoning_summary_text.delta": |
531 | 618 | if not ctx.workflow_item: |
532 | 619 | continue |
@@ -604,6 +691,36 @@ def end_workflow(item: WorkflowItem): |
604 | 691 | created_at=datetime.now(), |
605 | 692 | ), |
606 | 693 | ) |
| 694 | + elif item.type == "image_generation_call" and item.result: |
| 695 | + if not ctx.generated_image_item: |
| 696 | + continue |
| 697 | + |
| 698 | + url = await converter.base64_image_to_url(item.result) |
| 699 | + image = GeneratedImage(id=item.id, url=url) |
| 700 | + |
| 701 | + ctx.generated_image_item.image = image |
| 702 | + yield ThreadItemDoneEvent(item=ctx.generated_image_item) |
| 703 | + |
| 704 | + ctx.generated_image_item = None |
| 705 | + elif event.type == "response.image_generation_call.partial_image": |
| 706 | + if not ctx.generated_image_item: |
| 707 | + continue |
| 708 | + |
| 709 | + url = await converter.base64_image_to_url(event.partial_image_b64) |
| 710 | + progress = converter.partial_image_index_to_progress( |
| 711 | + event.partial_image_index |
| 712 | + ) |
| 713 | + |
| 714 | + ctx.generated_image_item.image = GeneratedImage( |
| 715 | + id=event.item_id, url=url |
| 716 | + ) |
| 717 | + |
| 718 | + yield ThreadItemUpdatedEvent( |
| 719 | + item_id=ctx.generated_image_item.id, |
| 720 | + update=GeneratedImageUpdated( |
| 721 | + image=ctx.generated_image_item.image, progress=progress |
| 722 | + ), |
| 723 | + ) |
607 | 724 |
|
608 | 725 | except (InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered): |
609 | 726 | for item_id in produced_items: |
@@ -694,6 +811,17 @@ async def tag_to_message_content( |
694 | 811 | "A Tag was included in a UserMessageItem but Converter.tag_to_message_content is not implemented" |
695 | 812 | ) |
696 | 813 |
|
async def generated_image_to_input(
    self, item: GeneratedImageItem
) -> TResponseInputItem | list[TResponseInputItem] | None:
    """
    Convert a GeneratedImageItem into input item(s) to send to the model.
    Required when generated images are enabled.

    Args:
        item: The generated image thread item to convert.

    Returns:
        A single input item, a list of input items, or None.

    Raises:
        NotImplementedError: Always, in this base implementation; override
            this method to support generated images.
    """
    # The message names this method so the error points at the hook that
    # actually needs overriding.
    raise NotImplementedError(
        "A GeneratedImageItem was included in the thread but Converter.generated_image_to_input was not implemented"
    )
| 824 | + |
697 | 825 | async def hidden_context_to_input( |
698 | 826 | self, item: HiddenContextItem |
699 | 827 | ) -> TResponseInputItem | list[TResponseInputItem] | None: |
@@ -984,6 +1112,9 @@ async def _thread_item_to_input_item( |
984 | 1112 | case SDKHiddenContextItem(): |
985 | 1113 | out = await self.sdk_hidden_context_to_input(item) or [] |
986 | 1114 | return out if isinstance(out, list) else [out] |
| 1115 | + case GeneratedImageItem(): |
| 1116 | + out = await self.generated_image_to_input(item) or [] |
| 1117 | + return out if isinstance(out, list) else [out] |
987 | 1118 | case _: |
988 | 1119 | assert_never(item) |
989 | 1120 |
|
|
0 commit comments