Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 150 additions & 2 deletions chatkit/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@
DurationSummary,
EndOfTurnItem,
FileSource,
GeneratedImage,
GeneratedImageItem,
GeneratedImageUpdated,
HiddenContextItem,
SDKHiddenContextItem,
Task,
Expand Down Expand Up @@ -105,6 +108,7 @@ class AgentContext(BaseModel, Generic[TContext]):
previous_response_id: str | None = None
client_tool_call: ClientToolCall | None = None
workflow_item: WorkflowItem | None = None
generated_image_item: GeneratedImageItem | None = None
_events: asyncio.Queue[ThreadStreamEvent | _QueueCompleteSentinel] = asyncio.Queue()

def generate_id(
Expand Down Expand Up @@ -356,10 +360,97 @@ class StreamingThoughtTracker(BaseModel):
task: ThoughtTask


class ResponseStreamConverter:
"""Used by `stream_agent_response` to convert streamed Agents SDK output
into values used by ChatKit thread items and thread stream events.

Defines overridable methods for adapting streamed data (such as image
generation results and partial updates) into the forms expected by ChatKit.
"""

partial_images: int | None = None
"""
The expected number of partial image updates for an image generation result.

When set, this value is used to normalize partial image indices into a
progress value in the range [0, 1]. If unset, all partial image updates are
assigned a progress value of 0.
"""

def __init__(self, partial_images: int | None = None):
"""
Args:
partial_images: The expected number of partial image updates for image
generation results, or None if no progress normalization should
be performed.
"""
self.partial_images = partial_images

async def base64_image_to_url(
self,
image_id: str,
base64_image: str,
partial_image_index: int | None = None,
) -> str:
"""
Convert a base64-encoded image into a URL.

This method is used to produce the URL stored on thread items for image
generation results.

Args:
image_id: The ID of the image generation call. This stays stable across partial image updates.
base64_image: The base64-encoded image.
partial_image_index: The index of the partial image update, starting from 0.

Returns:
A URL string.
"""
return f"data:image/png;base64,{base64_image}"

def partial_image_index_to_progress(self, partial_image_index: int) -> float:
"""
Convert a partial image index into a normalized progress value.

Args:
partial_image_index: The index of the partial image update, starting from 0.

Returns:
A float between 0 and 1 representing progress for the image
generation result.
"""
if self.partial_images is None:
return 0

return partial_image_index / self.partial_images


_DEFAULT_RESPONSE_STREAM_CONVERTER = ResponseStreamConverter()


async def stream_agent_response(
context: AgentContext, result: RunResultStreaming
context: AgentContext,
result: RunResultStreaming,
*,
converter: ResponseStreamConverter = _DEFAULT_RESPONSE_STREAM_CONVERTER,
) -> AsyncIterator[ThreadStreamEvent]:
"""Convert a streamed Agents SDK run into ChatKit ThreadStreamEvents."""
"""
Convert a streamed Agents SDK run into ChatKit thread stream events.

This function consumes a streaming run result and yields `ThreadStreamEvent`
objects as the run progresses.

Args:
context: The AgentContext to use for the stream.
result: The RunResultStreaming to convert.
image_generation_stream_converter: Controls how streamed image generation output
is converted into URLs and progress updates. The default converter stores the
generated base64 image and assigns a progress value of 0 to all partial image
updates.

Returns:
An async iterator that yields thread stream events representing the run result.
"""
current_item_id = None
current_tool_call = None
ctx = context
Expand Down Expand Up @@ -527,6 +618,15 @@ def end_workflow(item: WorkflowItem):
created_at=datetime.now(),
),
)
elif item.type == "image_generation_call":
ctx.generated_image_item = GeneratedImageItem(
id=ctx.generate_id("message"),
thread_id=thread.id,
created_at=datetime.now(),
image=None,
)
produced_items.add(ctx.generated_image_item.id)
yield ThreadItemAddedEvent(item=ctx.generated_image_item)
elif event.type == "response.reasoning_summary_text.delta":
if not ctx.workflow_item:
continue
Expand Down Expand Up @@ -604,6 +704,40 @@ def end_workflow(item: WorkflowItem):
created_at=datetime.now(),
),
)
elif item.type == "image_generation_call" and item.result:
if not ctx.generated_image_item:
continue

url = await converter.base64_image_to_url(item.id, item.result)
image = GeneratedImage(id=item.id, url=url)

ctx.generated_image_item.image = image
yield ThreadItemDoneEvent(item=ctx.generated_image_item)

ctx.generated_image_item = None
elif event.type == "response.image_generation_call.partial_image":
if not ctx.generated_image_item:
continue

url = await converter.base64_image_to_url(
event.item_id,
event.partial_image_b64,
event.partial_image_index,
)
progress = converter.partial_image_index_to_progress(
event.partial_image_index
)

ctx.generated_image_item.image = GeneratedImage(
id=event.item_id, url=url
)

yield ThreadItemUpdatedEvent(
item_id=ctx.generated_image_item.id,
update=GeneratedImageUpdated(
image=ctx.generated_image_item.image, progress=progress
),
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for readability should this block be above the done handlers?


except (InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered):
for item_id in produced_items:
Expand Down Expand Up @@ -694,6 +828,17 @@ async def tag_to_message_content(
"A Tag was included in a UserMessageItem but Converter.tag_to_message_content is not implemented"
)

async def generated_image_to_input(
self, item: GeneratedImageItem
) -> TResponseInputItem | list[TResponseInputItem] | None:
"""
Convert a GeneratedImageItem into input item(s) to send to the model.
Required when generated images are enabled.
"""
raise NotImplementedError(
"A GeneratedImageItem was included in a UserMessageItem but Converter.generated_image_to_message_content was not implemented"
)
Comment on lines 838 to 860
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm is there no way for us to represent images in a response? I'm not actually sure what the right thing to do here is, but can you just add image content by URL which we have?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea, I'll add a default that returns a ResponseInputImageParam


async def hidden_context_to_input(
self, item: HiddenContextItem
) -> TResponseInputItem | list[TResponseInputItem] | None:
Expand Down Expand Up @@ -984,6 +1129,9 @@ async def _thread_item_to_input_item(
case SDKHiddenContextItem():
out = await self.sdk_hidden_context_to_input(item) or []
return out if isinstance(out, list) else [out]
case GeneratedImageItem():
out = await self.generated_image_to_input(item) or []
return out if isinstance(out, list) else [out]
case _:
assert_never(item)

Expand Down
24 changes: 24 additions & 0 deletions chatkit/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,14 @@ class WorkflowTaskUpdated(BaseModel):
task: Task


class GeneratedImageUpdated(BaseModel):
"""Event emitted when a generated image is updated."""

type: Literal["generated_image.updated"] = "generated_image.updated"
image: GeneratedImage
progress: float | None = None


ThreadItemUpdate = (
AssistantMessageContentPartAdded
| AssistantMessageContentPartTextDelta
Expand All @@ -481,6 +489,7 @@ class WorkflowTaskUpdated(BaseModel):
| WidgetRootUpdated
| WorkflowTaskAdded
| WorkflowTaskUpdated
| GeneratedImageUpdated
)
"""Union of possible updates applied to thread items."""

Expand Down Expand Up @@ -579,6 +588,20 @@ class WidgetItem(ThreadItemBase):
copy_text: str | None = None


class GeneratedImage(BaseModel):
"""Generated image."""

id: str
url: str


class GeneratedImageItem(ThreadItemBase):
"""Thread item containing a generated image."""

type: Literal["generated_image"] = "generated_image"
image: GeneratedImage | None = None


class TaskItem(ThreadItemBase):
"""Thread item containing a task."""

Expand Down Expand Up @@ -624,6 +647,7 @@ class SDKHiddenContextItem(ThreadItemBase):
| AssistantMessageItem
| ClientToolCallItem
| WidgetItem
| GeneratedImageItem
| WorkflowItem
| TaskItem
| HiddenContextItem
Expand Down
Loading