-
Notifications
You must be signed in to change notification settings - Fork 95
Generated image support #86
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
35ee77f
9100b7a
51e0454
800a66c
f3a2b1f
dbf2538
bbc8b1d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -55,6 +55,9 @@ | |
| DurationSummary, | ||
| EndOfTurnItem, | ||
| FileSource, | ||
| GeneratedImage, | ||
| GeneratedImageItem, | ||
| GeneratedImageUpdated, | ||
| HiddenContextItem, | ||
| SDKHiddenContextItem, | ||
| Task, | ||
|
|
@@ -105,6 +108,7 @@ class AgentContext(BaseModel, Generic[TContext]): | |
| previous_response_id: str | None = None | ||
| client_tool_call: ClientToolCall | None = None | ||
| workflow_item: WorkflowItem | None = None | ||
| generated_image_item: GeneratedImageItem | None = None | ||
| _events: asyncio.Queue[ThreadStreamEvent | _QueueCompleteSentinel] = asyncio.Queue() | ||
|
|
||
| def generate_id( | ||
|
|
@@ -356,10 +360,97 @@ class StreamingThoughtTracker(BaseModel): | |
| task: ThoughtTask | ||
|
|
||
|
|
||
| class ResponseStreamConverter: | ||
| """Used by `stream_agent_response` to convert streamed Agents SDK output | ||
| into values used by ChatKit thread items and thread stream events. | ||
|
|
||
| Defines overridable methods for adapting streamed data (such as image | ||
| generation results and partial updates) into the forms expected by ChatKit. | ||
| """ | ||
|
|
||
| partial_images: int | None = None | ||
| """ | ||
| The expected number of partial image updates for an image generation result. | ||
|
|
||
| When set, this value is used to normalize partial image indices into a | ||
| progress value in the range [0, 1]. If unset, all partial image updates are | ||
| assigned a progress value of 0. | ||
| """ | ||
|
|
||
| def __init__(self, partial_images: int | None = None): | ||
| """ | ||
| Args: | ||
| partial_images: The expected number of partial image updates for image | ||
| generation results, or None if no progress normalization should | ||
| be performed. | ||
| """ | ||
| self.partial_images = partial_images | ||
|
|
||
| async def base64_image_to_url( | ||
| self, | ||
| image_id: str, | ||
| base64_image: str, | ||
| partial_image_index: int | None = None, | ||
| ) -> str: | ||
| """ | ||
| Convert a base64-encoded image into a URL. | ||
|
|
||
| This method is used to produce the URL stored on thread items for image | ||
| generation results. | ||
|
|
||
| Args: | ||
| image_id: The ID of the image generation call. This stays stable across partial image updates. | ||
| base64_image: The base64-encoded image. | ||
| partial_image_index: The index of the partial image update, starting from 0. | ||
|
|
||
| Returns: | ||
| A URL string. | ||
| """ | ||
| return f"data:image/png;base64,{base64_image}" | ||
|
|
||
| def partial_image_index_to_progress(self, partial_image_index: int) -> float: | ||
| """ | ||
| Convert a partial image index into a normalized progress value. | ||
|
|
||
| Args: | ||
| partial_image_index: The index of the partial image update, starting from 0. | ||
|
|
||
| Returns: | ||
| A float between 0 and 1 representing progress for the image | ||
| generation result. | ||
| """ | ||
| if self.partial_images is None: | ||
| return 0 | ||
|
|
||
| return partial_image_index / self.partial_images | ||
|
|
||
|
|
||
| _DEFAULT_RESPONSE_STREAM_CONVERTER = ResponseStreamConverter() | ||
|
|
||
|
|
||
| async def stream_agent_response( | ||
| context: AgentContext, result: RunResultStreaming | ||
| context: AgentContext, | ||
| result: RunResultStreaming, | ||
| *, | ||
| converter: ResponseStreamConverter = _DEFAULT_RESPONSE_STREAM_CONVERTER, | ||
| ) -> AsyncIterator[ThreadStreamEvent]: | ||
| """Convert a streamed Agents SDK run into ChatKit ThreadStreamEvents.""" | ||
| """ | ||
| Convert a streamed Agents SDK run into ChatKit thread stream events. | ||
|
|
||
| This function consumes a streaming run result and yields `ThreadStreamEvent` | ||
| objects as the run progresses. | ||
|
|
||
| Args: | ||
| context: The AgentContext to use for the stream. | ||
| result: The RunResultStreaming to convert. | ||
| image_generation_stream_converter: Controls how streamed image generation output | ||
| is converted into URLs and progress updates. The default converter stores the | ||
| generated base64 image and assigns a progress value of 0 to all partial image | ||
| updates. | ||
|
|
||
| Returns: | ||
| An async iterator that yields thread stream events representing the run result. | ||
| """ | ||
| current_item_id = None | ||
| current_tool_call = None | ||
| ctx = context | ||
|
|
@@ -527,6 +618,15 @@ def end_workflow(item: WorkflowItem): | |
| created_at=datetime.now(), | ||
| ), | ||
| ) | ||
| elif item.type == "image_generation_call": | ||
| ctx.generated_image_item = GeneratedImageItem( | ||
| id=ctx.generate_id("message"), | ||
| thread_id=thread.id, | ||
| created_at=datetime.now(), | ||
| image=None, | ||
| ) | ||
| produced_items.add(ctx.generated_image_item.id) | ||
| yield ThreadItemAddedEvent(item=ctx.generated_image_item) | ||
| elif event.type == "response.reasoning_summary_text.delta": | ||
| if not ctx.workflow_item: | ||
| continue | ||
|
|
@@ -604,6 +704,40 @@ def end_workflow(item: WorkflowItem): | |
| created_at=datetime.now(), | ||
| ), | ||
| ) | ||
| elif item.type == "image_generation_call" and item.result: | ||
| if not ctx.generated_image_item: | ||
| continue | ||
|
|
||
| url = await converter.base64_image_to_url(item.id, item.result) | ||
| image = GeneratedImage(id=item.id, url=url) | ||
|
|
||
| ctx.generated_image_item.image = image | ||
| yield ThreadItemDoneEvent(item=ctx.generated_image_item) | ||
|
|
||
| ctx.generated_image_item = None | ||
| elif event.type == "response.image_generation_call.partial_image": | ||
| if not ctx.generated_image_item: | ||
| continue | ||
|
|
||
| url = await converter.base64_image_to_url( | ||
| event.item_id, | ||
| event.partial_image_b64, | ||
| event.partial_image_index, | ||
| ) | ||
| progress = converter.partial_image_index_to_progress( | ||
| event.partial_image_index | ||
| ) | ||
|
|
||
| ctx.generated_image_item.image = GeneratedImage( | ||
| id=event.item_id, url=url | ||
| ) | ||
|
|
||
| yield ThreadItemUpdatedEvent( | ||
| item_id=ctx.generated_image_item.id, | ||
| update=GeneratedImageUpdated( | ||
| image=ctx.generated_image_item.image, progress=progress | ||
| ), | ||
| ) | ||
|
|
||
| except (InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered): | ||
| for item_id in produced_items: | ||
|
|
@@ -694,6 +828,17 @@ async def tag_to_message_content( | |
| "A Tag was included in a UserMessageItem but Converter.tag_to_message_content is not implemented" | ||
| ) | ||
|
|
||
| async def generated_image_to_input( | ||
| self, item: GeneratedImageItem | ||
| ) -> TResponseInputItem | list[TResponseInputItem] | None: | ||
| """ | ||
| Convert a GeneratedImageItem into input item(s) to send to the model. | ||
| Required when generated images are enabled. | ||
| """ | ||
| raise NotImplementedError( | ||
| "A GeneratedImageItem was included in a UserMessageItem but Converter.generated_image_to_message_content was not implemented" | ||
| ) | ||
|
Comment on lines
838
to
860
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm is there no way for us to represent images in a response? I'm not actually sure what the right thing to do here is, but can you just add image content by URL which we have?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good idea, I'll add a default that returns a ResponseInputImageParam |
||
|
|
||
| async def hidden_context_to_input( | ||
| self, item: HiddenContextItem | ||
| ) -> TResponseInputItem | list[TResponseInputItem] | None: | ||
|
|
@@ -984,6 +1129,9 @@ async def _thread_item_to_input_item( | |
| case SDKHiddenContextItem(): | ||
| out = await self.sdk_hidden_context_to_input(item) or [] | ||
| return out if isinstance(out, list) else [out] | ||
| case GeneratedImageItem(): | ||
| out = await self.generated_image_to_input(item) or [] | ||
| return out if isinstance(out, list) else [out] | ||
| case _: | ||
| assert_never(item) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for readability should this block be above the done handlers?