Skip to content

Commit 9100b7a

Browse files
committed
(2) stream_agent_response can handle image gen output
1 parent 35ee77f commit 9100b7a

File tree

1 file changed

+133
-2
lines changed

1 file changed

+133
-2
lines changed

chatkit/agents.py

Lines changed: 133 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@
5555
DurationSummary,
5656
EndOfTurnItem,
5757
FileSource,
58+
GeneratedImage,
59+
GeneratedImageItem,
60+
GeneratedImageUpdated,
5861
HiddenContextItem,
5962
SDKHiddenContextItem,
6063
Task,
@@ -105,6 +108,7 @@ class AgentContext(BaseModel, Generic[TContext]):
105108
previous_response_id: str | None = None
106109
client_tool_call: ClientToolCall | None = None
107110
workflow_item: WorkflowItem | None = None
111+
generated_image_item: GeneratedImageItem | None = None
108112
_events: asyncio.Queue[ThreadStreamEvent | _QueueCompleteSentinel] = asyncio.Queue()
109113

110114
def generate_id(
@@ -356,10 +360,84 @@ class StreamingThoughtTracker(BaseModel):
356360
task: ThoughtTask
357361

358362

363+
class ResponseStreamConverter:
364+
"""Used by `stream_agent_response` to convert streamed Agents SDK output
365+
into values used by ChatKit thread items and thread stream events.
366+
367+
Defines overridable methods for adapting streamed data (such as image
368+
generation results and partial updates) into the forms expected by ChatKit.
369+
"""
370+
371+
partial_images: int | None = None
372+
"""
373+
The expected number of partial image updates for an image generation result.
374+
375+
When set, this value is used to normalize partial image indices into a
376+
progress value in the range [0, 1]. If unset, all partial image updates are
377+
assigned a progress value of 0.
378+
"""
379+
380+
def __init__(self, partial_images: int | None = None):
381+
"""
382+
Args:
383+
partial_images: The expected number of partial image updates for image
384+
generation results, or None if no progress normalization should
385+
be performed.
386+
"""
387+
self.partial_images = partial_images
388+
389+
async def base64_image_to_url(self, base64_image: str) -> str:
390+
"""
391+
Convert a base64-encoded image into a URL.
392+
393+
This method is used to produce the URL stored on thread items for image
394+
generation results.
395+
"""
396+
return f"data:image/png;base64,{base64_image}"
397+
398+
def partial_image_index_to_progress(self, partial_image_index: int) -> float:
399+
"""
400+
Convert a partial image index into a normalized progress value.
401+
402+
Args:
403+
partial_image_index: The index of the partial image update, starting from 0.
404+
405+
Returns:
406+
A float between 0 and 1 representing progress for the image
407+
generation result.
408+
"""
409+
if self.partial_images is None:
410+
return 0
411+
412+
return partial_image_index / self.partial_images
413+
414+
415+
_DEFAULT_RESPONSE_STREAM_CONVERTER = ResponseStreamConverter()
416+
417+
359418
async def stream_agent_response(
360-
context: AgentContext, result: RunResultStreaming
419+
context: AgentContext,
420+
result: RunResultStreaming,
421+
*,
422+
converter: ResponseStreamConverter = _DEFAULT_RESPONSE_STREAM_CONVERTER,
361423
) -> AsyncIterator[ThreadStreamEvent]:
362-
"""Convert a streamed Agents SDK run into ChatKit ThreadStreamEvents."""
424+
"""
425+
Convert a streamed Agents SDK run into ChatKit thread stream events.
426+
427+
This function consumes a streaming run result and yields `ThreadStreamEvent`
428+
objects as the run progresses.
429+
430+
Args:
431+
context: The AgentContext to use for the stream.
432+
result: The RunResultStreaming to convert.
433+
image_generation_stream_converter: Controls how streamed image generation output
434+
is converted into URLs and progress updates. The default converter stores the
435+
generated base64 image and assigns a progress value of 0 to all partial image
436+
updates.
437+
438+
Returns:
439+
An async iterator that yields thread stream events representing the run result.
440+
"""
363441
current_item_id = None
364442
current_tool_call = None
365443
ctx = context
@@ -527,6 +605,15 @@ def end_workflow(item: WorkflowItem):
527605
created_at=datetime.now(),
528606
),
529607
)
608+
elif item.type == "image_generation_call":
609+
ctx.generated_image_item = GeneratedImageItem(
610+
id=ctx.generate_id("message"),
611+
thread_id=thread.id,
612+
created_at=datetime.now(),
613+
image=None,
614+
)
615+
produced_items.add(ctx.generated_image_item.id)
616+
yield ThreadItemAddedEvent(item=ctx.generated_image_item)
530617
elif event.type == "response.reasoning_summary_text.delta":
531618
if not ctx.workflow_item:
532619
continue
@@ -604,6 +691,36 @@ def end_workflow(item: WorkflowItem):
604691
created_at=datetime.now(),
605692
),
606693
)
694+
elif item.type == "image_generation_call" and item.result:
695+
if not ctx.generated_image_item:
696+
continue
697+
698+
url = await converter.base64_image_to_url(item.result)
699+
image = GeneratedImage(id=item.id, url=url)
700+
701+
ctx.generated_image_item.image = image
702+
yield ThreadItemDoneEvent(item=ctx.generated_image_item)
703+
704+
ctx.generated_image_item = None
705+
elif event.type == "response.image_generation_call.partial_image":
706+
if not ctx.generated_image_item:
707+
continue
708+
709+
url = await converter.base64_image_to_url(event.partial_image_b64)
710+
progress = converter.partial_image_index_to_progress(
711+
event.partial_image_index
712+
)
713+
714+
ctx.generated_image_item.image = GeneratedImage(
715+
id=event.item_id, url=url
716+
)
717+
718+
yield ThreadItemUpdatedEvent(
719+
item_id=ctx.generated_image_item.id,
720+
update=GeneratedImageUpdated(
721+
image=ctx.generated_image_item.image, progress=progress
722+
),
723+
)
607724

608725
except (InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered):
609726
for item_id in produced_items:
@@ -694,6 +811,17 @@ async def tag_to_message_content(
694811
"A Tag was included in a UserMessageItem but Converter.tag_to_message_content is not implemented"
695812
)
696813

814+
async def generated_image_to_input(
815+
self, item: GeneratedImageItem
816+
) -> TResponseInputItem | list[TResponseInputItem] | None:
817+
"""
818+
Convert a GeneratedImageItem into input item(s) to send to the model.
819+
Required when generated images are enabled.
820+
"""
821+
raise NotImplementedError(
822+
"A GeneratedImageItem was included in a UserMessageItem but Converter.generated_image_to_message_content was not implemented"
823+
)
824+
697825
async def hidden_context_to_input(
698826
self, item: HiddenContextItem
699827
) -> TResponseInputItem | list[TResponseInputItem] | None:
@@ -984,6 +1112,9 @@ async def _thread_item_to_input_item(
9841112
case SDKHiddenContextItem():
9851113
out = await self.sdk_hidden_context_to_input(item) or []
9861114
return out if isinstance(out, list) else [out]
1115+
case GeneratedImageItem():
1116+
out = await self.generated_image_to_input(item) or []
1117+
return out if isinstance(out, list) else [out]
9871118
case _:
9881119
assert_never(item)
9891120

0 commit comments

Comments
 (0)