Skip to content

Commit 800a66c

Browse files
committed
update base64_image_to_url to take image id and partial image index as input to url generation
1 parent 51e0454 commit 800a66c

File tree

2 files changed

+31
-9
lines changed

2 files changed

+31
-9
lines changed

chatkit/agents.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -386,12 +386,25 @@ def __init__(self, partial_images: int | None = None):
386386
"""
387387
self.partial_images = partial_images
388388

389-
async def base64_image_to_url(self, base64_image: str) -> str:
389+
async def base64_image_to_url(
390+
self,
391+
image_id: str,
392+
base64_image: str,
393+
partial_image_index: int | None = None,
394+
) -> str:
390395
"""
391396
Convert a base64-encoded image into a URL.
392397
393398
This method is used to produce the URL stored on thread items for image
394399
generation results.
400+
401+
Args:
402+
image_id: The ID of the image generation call. This stays stable across partial image updates.
403+
base64_image: The base64-encoded image.
404+
partial_image_index: The index of the partial image update, starting from 0.
405+
406+
Returns:
407+
A URL string.
395408
"""
396409
return f"data:image/png;base64,{base64_image}"
397410

@@ -695,7 +708,7 @@ def end_workflow(item: WorkflowItem):
695708
if not ctx.generated_image_item:
696709
continue
697710

698-
url = await converter.base64_image_to_url(item.result)
711+
url = await converter.base64_image_to_url(item.id, item.result)
699712
image = GeneratedImage(id=item.id, url=url)
700713

701714
ctx.generated_image_item.image = image
@@ -706,7 +719,11 @@ def end_workflow(item: WorkflowItem):
706719
if not ctx.generated_image_item:
707720
continue
708721

709-
url = await converter.base64_image_to_url(event.partial_image_b64)
722+
url = await converter.base64_image_to_url(
723+
event.item_id,
724+
event.partial_image_b64,
725+
event.partial_image_index,
726+
)
710727
progress = converter.partial_image_index_to_progress(
711728
event.partial_image_index
712729
)

tests/test_agents.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1282,11 +1282,16 @@ async def test_stream_agent_response_image_generation_events_with_custom_convert
12821282
class CustomResponseStreamConverter(ResponseStreamConverter):
12831283
def __init__(self):
12841284
super().__init__()
1285-
self.calls: list[str] = []
1285+
self.calls: list[tuple[str, str, int | None]] = []
12861286

1287-
async def base64_image_to_url(self, base64_image: str) -> str:
1288-
self.calls.append(base64_image)
1289-
return f"https://example.com/{base64_image}"
1287+
async def base64_image_to_url(
1288+
self,
1289+
image_id: str,
1290+
base64_image: str,
1291+
partial_image_index: int | None = None,
1292+
) -> str:
1293+
self.calls.append((image_id, base64_image, partial_image_index))
1294+
return f"https://example.com/{image_id}"
12901295

12911296
converter = CustomResponseStreamConverter()
12921297
stream = stream_agent_response(context, result, converter=converter)
@@ -1298,9 +1303,9 @@ async def base64_image_to_url(self, base64_image: str) -> str:
12981303
event2 = await stream.__anext__()
12991304
assert isinstance(event2, ThreadItemDoneEvent)
13001305
assert isinstance(event2.item, GeneratedImageItem)
1301-
assert converter.calls == ["dGVzdA=="]
1306+
assert converter.calls == [("img_call_1", "dGVzdA==", None)]
13021307
assert event2.item.image == GeneratedImage(
1303-
id="img_call_1", url="https://example.com/dGVzdA=="
1308+
id="img_call_1", url="https://example.com/img_call_1"
13041309
)
13051310
with pytest.raises(StopAsyncIteration):
13061311
await stream.__anext__()

0 commit comments

Comments
 (0)