Skip to content

Commit 876c31b

Browse files
committed
Add default implementation for generated_image_to_input
1 parent 800a66c commit 876c31b

File tree

2 files changed

+71
-11
lines changed

2 files changed

+71
-11
lines changed

chatkit/agents.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
EasyInputMessageParam,
2626
ResponseFunctionToolCallParam,
2727
ResponseInputContentParam,
28+
ResponseInputImageParam,
2829
ResponseInputMessageContentListParam,
2930
ResponseInputTextParam,
3031
ResponseOutputText,
@@ -377,7 +378,7 @@ class ResponseStreamConverter:
377378
assigned a progress value of 0.
378379
"""
379380

380-
def __init__(self, partial_images: int | None = None):
381+
def __init__(self, *, partial_images: int | None = None):
381382
"""
382383
Args:
383384
partial_images: The expected number of partial image updates for image
@@ -419,10 +420,10 @@ def partial_image_index_to_progress(self, partial_image_index: int) -> float:
419420
A float between 0 and 1 representing progress for the image
420421
generation result.
421422
"""
422-
if self.partial_images is None:
423-
return 0
423+
if self.partial_images is None or self.partial_images <= 0:
424+
return 0.0
424425

425-
return partial_image_index / self.partial_images
426+
return min(1.0, partial_image_index / self.partial_images)
426427

427428

428429
_DEFAULT_RESPONSE_STREAM_CONVERTER = ResponseStreamConverter()
@@ -708,7 +709,10 @@ def end_workflow(item: WorkflowItem):
708709
if not ctx.generated_image_item:
709710
continue
710711

711-
url = await converter.base64_image_to_url(item.id, item.result)
712+
url = await converter.base64_image_to_url(
713+
image_id=item.id,
714+
base64_image=item.result,
715+
)
712716
image = GeneratedImage(id=item.id, url=url)
713717

714718
ctx.generated_image_item.image = image
@@ -720,9 +724,9 @@ def end_workflow(item: WorkflowItem):
720724
continue
721725

722726
url = await converter.base64_image_to_url(
723-
event.item_id,
724-
event.partial_image_b64,
725-
event.partial_image_index,
727+
image_id=event.item_id,
728+
base64_image=event.partial_image_b64,
729+
partial_image_index=event.partial_image_index,
726730
)
727731
progress = converter.partial_image_index_to_progress(
728732
event.partial_image_index
@@ -833,10 +837,27 @@ async def generated_image_to_input(
833837
) -> TResponseInputItem | list[TResponseInputItem] | None:
834838
"""
835839
Convert a GeneratedImageItem into input item(s) to send to the model.
836-
Required when generated images are enabled.
840+
Override this method to customize the conversion of generated images, such as when your
841+
generated image url is not publicly reachable.
837842
"""
838-
raise NotImplementedError(
839-
"A GeneratedImageItem was included in a UserMessageItem but Converter.generated_image_to_message_content was not implemented"
843+
if not item.image:
844+
return None
845+
846+
return Message(
847+
type="message",
848+
content=[
849+
ResponseInputTextParam(
850+
type="input_text",
851+
text="The following image was generated by the agent.",
852+
),
853+
ResponseInputImageParam(
854+
type="input_image",
855+
detail="auto",
856+
file_id=item.image.id,
857+
image_url=item.image.url,
858+
),
859+
],
860+
role="user",
840861
)
841862

842863
async def hidden_context_to_input(

tests/test_agents.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -548,6 +548,45 @@ async def test_input_item_converter_user_input_with_tags_throws_by_default():
548548
await simple_to_agent_input(items)
549549

550550

551+
async def test_input_item_converter_generated_image_item():
552+
items = [
553+
GeneratedImageItem(
554+
id="img_item_1",
555+
thread_id=thread.id,
556+
created_at=datetime.now(),
557+
image=GeneratedImage(id="img_1", url="https://example.com/img.png"),
558+
)
559+
]
560+
561+
input_items = await simple_to_agent_input(items)
562+
assert len(input_items) == 1
563+
564+
message = cast(dict, input_items[0])
565+
assert message.get("type") == "message"
566+
assert message.get("role") == "user"
567+
568+
content = cast(list, message.get("content"))
569+
assert content[0].get("type") == "input_text"
570+
assert content[0].get("text") == "The following image was generated by the agent."
571+
assert content[1].get("type") == "input_image"
572+
assert content[1].get("file_id") == "img_1"
573+
assert content[1].get("image_url") == "https://example.com/img.png"
574+
assert content[1].get("detail") == "auto"
575+
576+
577+
async def test_input_item_converter_generated_image_item_without_image():
578+
items = [
579+
GeneratedImageItem(
580+
id="img_item_1",
581+
thread_id=thread.id,
582+
created_at=datetime.now(),
583+
)
584+
]
585+
586+
input_items = await simple_to_agent_input(items)
587+
assert input_items == []
588+
589+
551590
async def test_input_item_converter_for_hidden_context_with_string_content():
552591
items = [
553592
HiddenContextItem(

0 commit comments

Comments
 (0)