Skip to content

Commit 0b5b5b5

Browse files
authored
Merge pull request #86 from openai/generated-images-item
Generated image support
2 parents d5565dc + bbc8b1d commit 0b5b5b5

File tree

3 files changed

+423
-2
lines changed

3 files changed

+423
-2
lines changed

chatkit/agents.py

Lines changed: 170 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
EasyInputMessageParam,
2626
ResponseFunctionToolCallParam,
2727
ResponseInputContentParam,
28+
ResponseInputImageParam,
2829
ResponseInputMessageContentListParam,
2930
ResponseInputTextParam,
3031
ResponseOutputText,
@@ -55,6 +56,9 @@
5556
DurationSummary,
5657
EndOfTurnItem,
5758
FileSource,
59+
GeneratedImage,
60+
GeneratedImageItem,
61+
GeneratedImageUpdated,
5862
HiddenContextItem,
5963
SDKHiddenContextItem,
6064
Task,
@@ -105,6 +109,7 @@ class AgentContext(BaseModel, Generic[TContext]):
105109
previous_response_id: str | None = None
106110
client_tool_call: ClientToolCall | None = None
107111
workflow_item: WorkflowItem | None = None
112+
generated_image_item: GeneratedImageItem | None = None
108113
_events: asyncio.Queue[ThreadStreamEvent | _QueueCompleteSentinel] = asyncio.Queue()
109114

110115
def generate_id(
@@ -356,10 +361,97 @@ class StreamingThoughtTracker(BaseModel):
356361
task: ThoughtTask
357362

358363

364+
class ResponseStreamConverter:
365+
"""Used by `stream_agent_response` to convert streamed Agents SDK output
366+
into values used by ChatKit thread items and thread stream events.
367+
368+
Defines overridable methods for adapting streamed data (such as image
369+
generation results and partial updates) into the forms expected by ChatKit.
370+
"""
371+
372+
partial_images: int | None = None
373+
"""
374+
The expected number of partial image updates for an image generation result.
375+
376+
When set, this value is used to normalize partial image indices into a
377+
progress value in the range [0, 1]. If unset, all partial image updates are
378+
assigned a progress value of 0.
379+
"""
380+
381+
def __init__(self, *, partial_images: int | None = None):
382+
"""
383+
Args:
384+
partial_images: The expected number of partial image updates for image
385+
generation results, or None if no progress normalization should
386+
be performed.
387+
"""
388+
self.partial_images = partial_images
389+
390+
async def base64_image_to_url(
391+
self,
392+
image_id: str,
393+
base64_image: str,
394+
partial_image_index: int | None = None,
395+
) -> str:
396+
"""
397+
Convert a base64-encoded image into a URL.
398+
399+
This method is used to produce the URL stored on thread items for image
400+
generation results.
401+
402+
Args:
403+
image_id: The ID of the image generation call. This stays stable across partial image updates.
404+
base64_image: The base64-encoded image.
405+
partial_image_index: The index of the partial image update, starting from 0.
406+
407+
Returns:
408+
A URL string.
409+
"""
410+
return f"data:image/png;base64,{base64_image}"
411+
412+
def partial_image_index_to_progress(self, partial_image_index: int) -> float:
413+
"""
414+
Convert a partial image index into a normalized progress value.
415+
416+
Args:
417+
partial_image_index: The index of the partial image update, starting from 0.
418+
419+
Returns:
420+
A float between 0 and 1 representing progress for the image
421+
generation result.
422+
"""
423+
if self.partial_images is None or self.partial_images <= 0:
424+
return 0.0
425+
426+
return min(1.0, partial_image_index / self.partial_images)
427+
428+
429+
_DEFAULT_RESPONSE_STREAM_CONVERTER = ResponseStreamConverter()
430+
431+
359432
async def stream_agent_response(
360-
context: AgentContext, result: RunResultStreaming
433+
context: AgentContext,
434+
result: RunResultStreaming,
435+
*,
436+
converter: ResponseStreamConverter = _DEFAULT_RESPONSE_STREAM_CONVERTER,
361437
) -> AsyncIterator[ThreadStreamEvent]:
362-
"""Convert a streamed Agents SDK run into ChatKit ThreadStreamEvents."""
438+
"""
439+
Convert a streamed Agents SDK run into ChatKit thread stream events.
440+
441+
This function consumes a streaming run result and yields `ThreadStreamEvent`
442+
objects as the run progresses.
443+
444+
Args:
445+
context: The AgentContext to use for the stream.
446+
result: The RunResultStreaming to convert.
447+
image_generation_stream_converter: Controls how streamed image generation output
448+
is converted into URLs and progress updates. The default converter stores the
449+
generated base64 image and assigns a progress value of 0 to all partial image
450+
updates.
451+
452+
Returns:
453+
An async iterator that yields thread stream events representing the run result.
454+
"""
363455
current_item_id = None
364456
current_tool_call = None
365457
ctx = context
@@ -527,6 +619,38 @@ def end_workflow(item: WorkflowItem):
527619
created_at=datetime.now(),
528620
),
529621
)
622+
elif item.type == "image_generation_call":
623+
ctx.generated_image_item = GeneratedImageItem(
624+
id=ctx.generate_id("message"),
625+
thread_id=thread.id,
626+
created_at=datetime.now(),
627+
image=None,
628+
)
629+
produced_items.add(ctx.generated_image_item.id)
630+
yield ThreadItemAddedEvent(item=ctx.generated_image_item)
631+
elif event.type == "response.image_generation_call.partial_image":
632+
if not ctx.generated_image_item:
633+
continue
634+
635+
url = await converter.base64_image_to_url(
636+
image_id=event.item_id,
637+
base64_image=event.partial_image_b64,
638+
partial_image_index=event.partial_image_index,
639+
)
640+
progress = converter.partial_image_index_to_progress(
641+
event.partial_image_index
642+
)
643+
644+
ctx.generated_image_item.image = GeneratedImage(
645+
id=event.item_id, url=url
646+
)
647+
648+
yield ThreadItemUpdatedEvent(
649+
item_id=ctx.generated_image_item.id,
650+
update=GeneratedImageUpdated(
651+
image=ctx.generated_image_item.image, progress=progress
652+
),
653+
)
530654
elif event.type == "response.reasoning_summary_text.delta":
531655
if not ctx.workflow_item:
532656
continue
@@ -604,6 +728,20 @@ def end_workflow(item: WorkflowItem):
604728
created_at=datetime.now(),
605729
),
606730
)
731+
elif item.type == "image_generation_call" and item.result:
732+
if not ctx.generated_image_item:
733+
continue
734+
735+
url = await converter.base64_image_to_url(
736+
image_id=item.id,
737+
base64_image=item.result,
738+
)
739+
image = GeneratedImage(id=item.id, url=url)
740+
741+
ctx.generated_image_item.image = image
742+
yield ThreadItemDoneEvent(item=ctx.generated_image_item)
743+
744+
ctx.generated_image_item = None
607745

608746
except (InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered):
609747
for item_id in produced_items:
@@ -694,6 +832,33 @@ async def tag_to_message_content(
694832
"A Tag was included in a UserMessageItem but Converter.tag_to_message_content is not implemented"
695833
)
696834

835+
async def generated_image_to_input(
836+
self, item: GeneratedImageItem
837+
) -> TResponseInputItem | list[TResponseInputItem] | None:
838+
"""
839+
Convert a GeneratedImageItem into input item(s) to send to the model.
840+
Override this method to customize the conversion of generated images, such as when your
841+
generated image url is not publicly reachable.
842+
"""
843+
if not item.image:
844+
return None
845+
846+
return Message(
847+
type="message",
848+
content=[
849+
ResponseInputTextParam(
850+
type="input_text",
851+
text="The following image was generated by the agent.",
852+
),
853+
ResponseInputImageParam(
854+
type="input_image",
855+
detail="auto",
856+
image_url=item.image.url,
857+
),
858+
],
859+
role="user",
860+
)
861+
697862
async def hidden_context_to_input(
698863
self, item: HiddenContextItem
699864
) -> TResponseInputItem | list[TResponseInputItem] | None:
@@ -984,6 +1149,9 @@ async def _thread_item_to_input_item(
9841149
case SDKHiddenContextItem():
9851150
out = await self.sdk_hidden_context_to_input(item) or []
9861151
return out if isinstance(out, list) else [out]
1152+
case GeneratedImageItem():
1153+
out = await self.generated_image_to_input(item) or []
1154+
return out if isinstance(out, list) else [out]
9871155
case _:
9881156
assert_never(item)
9891157

chatkit/types.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,14 @@ class WorkflowTaskUpdated(BaseModel):
471471
task: Task
472472

473473

474+
class GeneratedImageUpdated(BaseModel):
475+
"""Event emitted when a generated image is updated."""
476+
477+
type: Literal["generated_image.updated"] = "generated_image.updated"
478+
image: GeneratedImage
479+
progress: float | None = None
480+
481+
474482
ThreadItemUpdate = (
475483
AssistantMessageContentPartAdded
476484
| AssistantMessageContentPartTextDelta
@@ -481,6 +489,7 @@ class WorkflowTaskUpdated(BaseModel):
481489
| WidgetRootUpdated
482490
| WorkflowTaskAdded
483491
| WorkflowTaskUpdated
492+
| GeneratedImageUpdated
484493
)
485494
"""Union of possible updates applied to thread items."""
486495

@@ -579,6 +588,20 @@ class WidgetItem(ThreadItemBase):
579588
copy_text: str | None = None
580589

581590

591+
class GeneratedImage(BaseModel):
592+
"""Generated image."""
593+
594+
id: str
595+
url: str
596+
597+
598+
class GeneratedImageItem(ThreadItemBase):
599+
"""Thread item containing a generated image."""
600+
601+
type: Literal["generated_image"] = "generated_image"
602+
image: GeneratedImage | None = None
603+
604+
582605
class TaskItem(ThreadItemBase):
583606
"""Thread item containing a task."""
584607

@@ -624,6 +647,7 @@ class SDKHiddenContextItem(ThreadItemBase):
624647
| AssistantMessageItem
625648
| ClientToolCallItem
626649
| WidgetItem
650+
| GeneratedImageItem
627651
| WorkflowItem
628652
| TaskItem
629653
| HiddenContextItem

0 commit comments

Comments
 (0)