Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions python/packages/core/agent_framework/_workflows/_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]:
logger.info(f"Starting superstep {self._iteration + 1}")
yield WorkflowEvent.superstep_started(iteration=self._iteration + 1)

# Reset per-superstep tracking for HITL detection
self._ctx.reset_superstep_request_info_tracking()

# Run iteration concurrently with live event streaming: we poll
# for new events while the iteration coroutine progresses.
iteration_task = asyncio.create_task(self._run_iteration())
Expand Down Expand Up @@ -149,6 +152,18 @@ async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]:

yield WorkflowEvent.superstep_completed(iteration=self._iteration)

# Check for HITL pause: if any request_info events were emitted
# during this superstep, pause the workflow even if there are pending
# messages. This prevents parallel nodes from continuing to run while
# HITL input is needed. Pending messages are preserved in memory and
# will be delivered alongside HITL responses in the next run.
if self._ctx.had_request_info_in_superstep():
logger.info(
f"Pausing workflow after superstep {self._iteration}: "
"request_info event(s) emitted during this superstep"
)
break

# Check for convergence: no more messages to process
if not await self._ctx.has_messages():
break
Expand Down
32 changes: 32 additions & 0 deletions python/packages/core/agent_framework/_workflows/_runner_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,18 @@ async def get_pending_request_info_events(self) -> dict[str, WorkflowEvent[Any]]
"""
...

def reset_superstep_request_info_tracking(self) -> None:
"""Reset tracking of new request_info events for the current superstep."""
...

def had_request_info_in_superstep(self) -> bool:
"""Check if any request_info events were emitted during the current superstep.

Returns:
True if at least one request_info event was emitted since the last reset.
"""
...


class InProcRunnerContext:
"""In-process execution context for local execution and optional checkpointing."""
Expand All @@ -306,6 +318,9 @@ def __init__(self, checkpoint_storage: CheckpointStorage | None = None):
# Streaming flag - set by workflow's run(..., stream=True) vs run(..., stream=False)
self._streaming: bool = False

# Track whether new request_info events were emitted during the current superstep
self._new_request_info_in_superstep: bool = False

# region Messaging and Events
async def send_message(self, message: Message) -> None:
self._messages.setdefault(message.source_id, [])
Expand Down Expand Up @@ -415,6 +430,7 @@ def reset_for_new_run(self) -> None:
# Clear any pending events (best-effort) by recreating the queue
self._event_queue = asyncio.Queue()
self._streaming = False # Reset streaming flag
self._new_request_info_in_superstep = False

async def apply_checkpoint(self, checkpoint: WorkflowCheckpoint) -> None:
"""Apply a checkpoint to the current context, mutating its state."""
Expand All @@ -435,6 +451,9 @@ async def apply_checkpoint(self, checkpoint: WorkflowCheckpoint) -> None:
# Restore workflow ID
self._workflow_id = checkpoint.workflow_id

# Reset superstep tracking - restored events are pre-existing, not new
self._new_request_info_in_superstep = False

# endregion Checkpointing

def set_workflow_id(self, workflow_id: str) -> None:
Expand Down Expand Up @@ -481,6 +500,7 @@ async def add_request_info_event(self, event: WorkflowEvent[Any]) -> None:
if event.request_id is None:
raise ValueError("request_info event must have a request_id")
self._pending_request_info_events[event.request_id] = event
self._new_request_info_in_superstep = True
await self.add_event(event)

async def send_request_info_response(self, request_id: str, response: Any) -> None:
Expand Down Expand Up @@ -521,3 +541,15 @@ async def get_pending_request_info_events(self) -> dict[str, WorkflowEvent[Any]]
A dictionary mapping request IDs to their corresponding WorkflowEvent (type='request_info').
"""
return dict(self._pending_request_info_events)

def reset_superstep_request_info_tracking(self) -> None:
"""Reset tracking of new request_info events for the current superstep."""
self._new_request_info_in_superstep = False

def had_request_info_in_superstep(self) -> bool:
"""Check if any request_info events were emitted during the current superstep.

Returns:
True if at least one request_info event was emitted since the last reset.
"""
return self._new_request_info_in_superstep
Loading
Loading