|
59 | 59 |
|
60 | 60 | from chatkit.agents import ( |
61 | 61 | AgentContext, |
| 62 | + ResponseStreamConverter, |
62 | 63 | ThreadItemConverter, |
63 | 64 | accumulate_text, |
64 | 65 | simple_to_agent_input, |
|
78 | 79 | CustomTask, |
79 | 80 | DurationSummary, |
80 | 81 | FileSource, |
| 82 | + GeneratedImage, |
| 83 | + GeneratedImageItem, |
| 84 | + GeneratedImageUpdated, |
81 | 85 | HiddenContextItem, |
82 | 86 | InferenceOptions, |
83 | 87 | Page, |
@@ -1191,6 +1195,187 @@ async def test_stream_agent_response_assistant_message_content_types(): |
1191 | 1195 | assert message.id == "1" |
1192 | 1196 |
|
1193 | 1197 |
|
| 1198 | +async def test_stream_agent_response_image_generation_events(): |
| 1199 | + context = AgentContext( |
| 1200 | + previous_response_id=None, thread=thread, store=mock_store, request_context=None |
| 1201 | + ) |
| 1202 | + result = make_result() |
| 1203 | + |
| 1204 | + result.add_event( |
| 1205 | + RawResponsesStreamEvent( |
| 1206 | + type="raw_response_event", |
| 1207 | + data=Mock( |
| 1208 | + type="response.output_item.added", |
| 1209 | + item=Mock(type="image_generation_call", id="img_call_1"), |
| 1210 | + output_index=0, |
| 1211 | + sequence_number=0, |
| 1212 | + ), |
| 1213 | + ) |
| 1214 | + ) |
| 1215 | + result.add_event( |
| 1216 | + RawResponsesStreamEvent( |
| 1217 | + type="raw_response_event", |
| 1218 | + data=Mock( |
| 1219 | + type="response.output_item.done", |
| 1220 | + item=Mock( |
| 1221 | + type="image_generation_call", id="img_call_1", result="dGVzdA==" |
| 1222 | + ), |
| 1223 | + output_index=0, |
| 1224 | + sequence_number=1, |
| 1225 | + ), |
| 1226 | + ) |
| 1227 | + ) |
| 1228 | + result.done() |
| 1229 | + |
| 1230 | + stream = stream_agent_response(context, result) |
| 1231 | + event1 = await stream.__anext__() |
| 1232 | + assert isinstance(event1, ThreadItemAddedEvent) |
| 1233 | + assert isinstance(event1.item, GeneratedImageItem) |
| 1234 | + assert event1.item.type == "generated_image" |
| 1235 | + assert event1.item.id == "message_id" |
| 1236 | + assert event1.item.image is None |
| 1237 | + |
| 1238 | + event2 = await stream.__anext__() |
| 1239 | + assert isinstance(event2, ThreadItemDoneEvent) |
| 1240 | + assert isinstance(event2.item, GeneratedImageItem) |
| 1241 | + assert event2.item.id == event1.item.id |
| 1242 | + assert event2.item.image == GeneratedImage( |
| 1243 | + id="img_call_1", url="" |
| 1244 | + ) |
| 1245 | + |
| 1246 | + with pytest.raises(StopAsyncIteration): |
| 1247 | + await stream.__anext__() |
| 1248 | + |
| 1249 | + |
| 1250 | +async def test_stream_agent_response_image_generation_events_with_custom_converter(): |
| 1251 | + context = AgentContext( |
| 1252 | + previous_response_id=None, thread=thread, store=mock_store, request_context=None |
| 1253 | + ) |
| 1254 | + result = make_result() |
| 1255 | + |
| 1256 | + result.add_event( |
| 1257 | + RawResponsesStreamEvent( |
| 1258 | + type="raw_response_event", |
| 1259 | + data=Mock( |
| 1260 | + type="response.output_item.added", |
| 1261 | + item=Mock(type="image_generation_call", id="img_call_1"), |
| 1262 | + output_index=0, |
| 1263 | + sequence_number=0, |
| 1264 | + ), |
| 1265 | + ) |
| 1266 | + ) |
| 1267 | + result.add_event( |
| 1268 | + RawResponsesStreamEvent( |
| 1269 | + type="raw_response_event", |
| 1270 | + data=Mock( |
| 1271 | + type="response.output_item.done", |
| 1272 | + item=Mock( |
| 1273 | + type="image_generation_call", id="img_call_1", result="dGVzdA==" |
| 1274 | + ), |
| 1275 | + output_index=0, |
| 1276 | + sequence_number=1, |
| 1277 | + ), |
| 1278 | + ) |
| 1279 | + ) |
| 1280 | + result.done() |
| 1281 | + |
| 1282 | + class CustomResponseStreamConverter(ResponseStreamConverter): |
| 1283 | + def __init__(self): |
| 1284 | + super().__init__() |
| 1285 | + self.calls: list[str] = [] |
| 1286 | + |
| 1287 | + async def base64_image_to_url(self, base64_image: str) -> str: |
| 1288 | + self.calls.append(base64_image) |
| 1289 | + return f"https://example.com/{base64_image}" |
| 1290 | + |
| 1291 | + converter = CustomResponseStreamConverter() |
| 1292 | + stream = stream_agent_response(context, result, converter=converter) |
| 1293 | + event1 = await stream.__anext__() |
| 1294 | + assert isinstance(event1, ThreadItemAddedEvent) |
| 1295 | + assert isinstance(event1.item, GeneratedImageItem) |
| 1296 | + assert event1.item.image is None |
| 1297 | + |
| 1298 | + event2 = await stream.__anext__() |
| 1299 | + assert isinstance(event2, ThreadItemDoneEvent) |
| 1300 | + assert isinstance(event2.item, GeneratedImageItem) |
| 1301 | + assert converter.calls == ["dGVzdA=="] |
| 1302 | + assert event2.item.image == GeneratedImage( |
| 1303 | + id="img_call_1", url="https://example.com/dGVzdA==" |
| 1304 | + ) |
| 1305 | + with pytest.raises(StopAsyncIteration): |
| 1306 | + await stream.__anext__() |
| 1307 | + |
| 1308 | + |
| 1309 | +async def test_stream_agent_response_image_generation_partial_progress(): |
| 1310 | + context = AgentContext( |
| 1311 | + previous_response_id=None, thread=thread, store=mock_store, request_context=None |
| 1312 | + ) |
| 1313 | + result = make_result() |
| 1314 | + |
| 1315 | + result.add_event( |
| 1316 | + RawResponsesStreamEvent( |
| 1317 | + type="raw_response_event", |
| 1318 | + data=Mock( |
| 1319 | + type="response.output_item.added", |
| 1320 | + item=Mock(type="image_generation_call", id="img_call_1"), |
| 1321 | + output_index=0, |
| 1322 | + sequence_number=0, |
| 1323 | + ), |
| 1324 | + ) |
| 1325 | + ) |
| 1326 | + result.add_event( |
| 1327 | + RawResponsesStreamEvent( |
| 1328 | + type="raw_response_event", |
| 1329 | + data=Mock( |
| 1330 | + type="response.image_generation_call.partial_image", |
| 1331 | + partial_image_b64="dGVzdA==", |
| 1332 | + partial_image_index=1, |
| 1333 | + item_id="img_call_1", |
| 1334 | + output_index=0, |
| 1335 | + sequence_number=1, |
| 1336 | + ), |
| 1337 | + ) |
| 1338 | + ) |
| 1339 | + result.add_event( |
| 1340 | + RawResponsesStreamEvent( |
| 1341 | + type="raw_response_event", |
| 1342 | + data=Mock( |
| 1343 | + type="response.output_item.done", |
| 1344 | + item=Mock( |
| 1345 | + type="image_generation_call", id="img_call_1", result="dGVzdA==" |
| 1346 | + ), |
| 1347 | + output_index=0, |
| 1348 | + sequence_number=2, |
| 1349 | + ), |
| 1350 | + ) |
| 1351 | + ) |
| 1352 | + result.done() |
| 1353 | + |
| 1354 | + converter = ResponseStreamConverter(partial_images=3) |
| 1355 | + events = await all_events( |
| 1356 | + stream_agent_response(context, result, converter=converter) |
| 1357 | + ) |
| 1358 | + |
| 1359 | + assert len(events) == 3 |
| 1360 | + added_event, partial_event, done_event = events |
| 1361 | + |
| 1362 | + assert isinstance(added_event, ThreadItemAddedEvent) |
| 1363 | + assert isinstance(added_event.item, GeneratedImageItem) |
| 1364 | + |
| 1365 | + assert isinstance(partial_event, ThreadItemUpdatedEvent) |
| 1366 | + assert isinstance(partial_event.update, GeneratedImageUpdated) |
| 1367 | + assert partial_event.update.progress == pytest.approx(1 / 3) |
| 1368 | + assert partial_event.update.image == GeneratedImage( |
| 1369 | + id="img_call_1", url="" |
| 1370 | + ) |
| 1371 | + |
| 1372 | + assert isinstance(done_event, ThreadItemDoneEvent) |
| 1373 | + assert isinstance(done_event.item, GeneratedImageItem) |
| 1374 | + assert done_event.item.image == GeneratedImage( |
| 1375 | + id="img_call_1", url="" |
| 1376 | + ) |
| 1377 | + |
| 1378 | + |
1194 | 1379 | async def test_workflow_streams_first_thought(): |
1195 | 1380 | context = AgentContext( |
1196 | 1381 | previous_response_id=None, thread=thread, store=mock_store, request_context=None |
|
0 commit comments