Skip to content

Commit 51e0454

Browse files
committed
(3) Add tests
1 parent 9100b7a commit 51e0454

File tree

1 file changed

+185
-0
lines changed

1 file changed

+185
-0
lines changed

tests/test_agents.py

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959

6060
from chatkit.agents import (
6161
AgentContext,
62+
ResponseStreamConverter,
6263
ThreadItemConverter,
6364
accumulate_text,
6465
simple_to_agent_input,
@@ -78,6 +79,9 @@
7879
CustomTask,
7980
DurationSummary,
8081
FileSource,
82+
GeneratedImage,
83+
GeneratedImageItem,
84+
GeneratedImageUpdated,
8185
HiddenContextItem,
8286
InferenceOptions,
8387
Page,
@@ -1191,6 +1195,187 @@ async def test_stream_agent_response_assistant_message_content_types():
11911195
assert message.id == "1"
11921196

11931197

1198+
async def test_stream_agent_response_image_generation_events():
1199+
context = AgentContext(
1200+
previous_response_id=None, thread=thread, store=mock_store, request_context=None
1201+
)
1202+
result = make_result()
1203+
1204+
result.add_event(
1205+
RawResponsesStreamEvent(
1206+
type="raw_response_event",
1207+
data=Mock(
1208+
type="response.output_item.added",
1209+
item=Mock(type="image_generation_call", id="img_call_1"),
1210+
output_index=0,
1211+
sequence_number=0,
1212+
),
1213+
)
1214+
)
1215+
result.add_event(
1216+
RawResponsesStreamEvent(
1217+
type="raw_response_event",
1218+
data=Mock(
1219+
type="response.output_item.done",
1220+
item=Mock(
1221+
type="image_generation_call", id="img_call_1", result="dGVzdA=="
1222+
),
1223+
output_index=0,
1224+
sequence_number=1,
1225+
),
1226+
)
1227+
)
1228+
result.done()
1229+
1230+
stream = stream_agent_response(context, result)
1231+
event1 = await stream.__anext__()
1232+
assert isinstance(event1, ThreadItemAddedEvent)
1233+
assert isinstance(event1.item, GeneratedImageItem)
1234+
assert event1.item.type == "generated_image"
1235+
assert event1.item.id == "message_id"
1236+
assert event1.item.image is None
1237+
1238+
event2 = await stream.__anext__()
1239+
assert isinstance(event2, ThreadItemDoneEvent)
1240+
assert isinstance(event2.item, GeneratedImageItem)
1241+
assert event2.item.id == event1.item.id
1242+
assert event2.item.image == GeneratedImage(
1243+
id="img_call_1", url="data:image/png;base64,dGVzdA=="
1244+
)
1245+
1246+
with pytest.raises(StopAsyncIteration):
1247+
await stream.__anext__()
1248+
1249+
1250+
async def test_stream_agent_response_image_generation_events_with_custom_converter():
1251+
context = AgentContext(
1252+
previous_response_id=None, thread=thread, store=mock_store, request_context=None
1253+
)
1254+
result = make_result()
1255+
1256+
result.add_event(
1257+
RawResponsesStreamEvent(
1258+
type="raw_response_event",
1259+
data=Mock(
1260+
type="response.output_item.added",
1261+
item=Mock(type="image_generation_call", id="img_call_1"),
1262+
output_index=0,
1263+
sequence_number=0,
1264+
),
1265+
)
1266+
)
1267+
result.add_event(
1268+
RawResponsesStreamEvent(
1269+
type="raw_response_event",
1270+
data=Mock(
1271+
type="response.output_item.done",
1272+
item=Mock(
1273+
type="image_generation_call", id="img_call_1", result="dGVzdA=="
1274+
),
1275+
output_index=0,
1276+
sequence_number=1,
1277+
),
1278+
)
1279+
)
1280+
result.done()
1281+
1282+
class CustomResponseStreamConverter(ResponseStreamConverter):
1283+
def __init__(self):
1284+
super().__init__()
1285+
self.calls: list[str] = []
1286+
1287+
async def base64_image_to_url(self, base64_image: str) -> str:
1288+
self.calls.append(base64_image)
1289+
return f"https://example.com/{base64_image}"
1290+
1291+
converter = CustomResponseStreamConverter()
1292+
stream = stream_agent_response(context, result, converter=converter)
1293+
event1 = await stream.__anext__()
1294+
assert isinstance(event1, ThreadItemAddedEvent)
1295+
assert isinstance(event1.item, GeneratedImageItem)
1296+
assert event1.item.image is None
1297+
1298+
event2 = await stream.__anext__()
1299+
assert isinstance(event2, ThreadItemDoneEvent)
1300+
assert isinstance(event2.item, GeneratedImageItem)
1301+
assert converter.calls == ["dGVzdA=="]
1302+
assert event2.item.image == GeneratedImage(
1303+
id="img_call_1", url="https://example.com/dGVzdA=="
1304+
)
1305+
with pytest.raises(StopAsyncIteration):
1306+
await stream.__anext__()
1307+
1308+
1309+
async def test_stream_agent_response_image_generation_partial_progress():
1310+
context = AgentContext(
1311+
previous_response_id=None, thread=thread, store=mock_store, request_context=None
1312+
)
1313+
result = make_result()
1314+
1315+
result.add_event(
1316+
RawResponsesStreamEvent(
1317+
type="raw_response_event",
1318+
data=Mock(
1319+
type="response.output_item.added",
1320+
item=Mock(type="image_generation_call", id="img_call_1"),
1321+
output_index=0,
1322+
sequence_number=0,
1323+
),
1324+
)
1325+
)
1326+
result.add_event(
1327+
RawResponsesStreamEvent(
1328+
type="raw_response_event",
1329+
data=Mock(
1330+
type="response.image_generation_call.partial_image",
1331+
partial_image_b64="dGVzdA==",
1332+
partial_image_index=1,
1333+
item_id="img_call_1",
1334+
output_index=0,
1335+
sequence_number=1,
1336+
),
1337+
)
1338+
)
1339+
result.add_event(
1340+
RawResponsesStreamEvent(
1341+
type="raw_response_event",
1342+
data=Mock(
1343+
type="response.output_item.done",
1344+
item=Mock(
1345+
type="image_generation_call", id="img_call_1", result="dGVzdA=="
1346+
),
1347+
output_index=0,
1348+
sequence_number=2,
1349+
),
1350+
)
1351+
)
1352+
result.done()
1353+
1354+
converter = ResponseStreamConverter(partial_images=3)
1355+
events = await all_events(
1356+
stream_agent_response(context, result, converter=converter)
1357+
)
1358+
1359+
assert len(events) == 3
1360+
added_event, partial_event, done_event = events
1361+
1362+
assert isinstance(added_event, ThreadItemAddedEvent)
1363+
assert isinstance(added_event.item, GeneratedImageItem)
1364+
1365+
assert isinstance(partial_event, ThreadItemUpdatedEvent)
1366+
assert isinstance(partial_event.update, GeneratedImageUpdated)
1367+
assert partial_event.update.progress == pytest.approx(1 / 3)
1368+
assert partial_event.update.image == GeneratedImage(
1369+
id="img_call_1", url="data:image/png;base64,dGVzdA=="
1370+
)
1371+
1372+
assert isinstance(done_event, ThreadItemDoneEvent)
1373+
assert isinstance(done_event.item, GeneratedImageItem)
1374+
assert done_event.item.image == GeneratedImage(
1375+
id="img_call_1", url="data:image/png;base64,dGVzdA=="
1376+
)
1377+
1378+
11941379
async def test_workflow_streams_first_thought():
11951380
context = AgentContext(
11961381
previous_response_id=None, thread=thread, store=mock_store, request_context=None

0 commit comments

Comments
 (0)