Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/packages/core/agent_framework/_workflows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
EdgeDuplicationError,
GraphConnectivityError,
TypeCompatibilityError,
ValidationTypeEnum,
ValidationType,
WorkflowValidationError,
validate_workflow_graph,
)
Expand Down Expand Up @@ -102,7 +102,7 @@
"SwitchCaseEdgeGroupCase",
"SwitchCaseEdgeGroupDefault",
"TypeCompatibilityError",
"ValidationTypeEnum",
"ValidationType",
"Workflow",
"WorkflowAgent",
"WorkflowBuilder",
Expand Down
36 changes: 18 additions & 18 deletions python/packages/core/agent_framework/_workflows/_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
import types
from collections import defaultdict
from collections.abc import Sequence
from enum import Enum
from typing import Any
from typing import Any, Literal

from ._edge import Edge, EdgeGroup, FanInEdgeGroup, InternalEdgeGroup
from ._executor import Executor
Expand All @@ -15,27 +14,28 @@


# region Enums and Base Classes
class ValidationTypeEnum(Enum):
"""Enumeration of workflow validation types."""

EDGE_DUPLICATION = "EDGE_DUPLICATION"
EXECUTOR_DUPLICATION = "EXECUTOR_DUPLICATION"
TYPE_COMPATIBILITY = "TYPE_COMPATIBILITY"
GRAPH_CONNECTIVITY = "GRAPH_CONNECTIVITY"
HANDLER_OUTPUT_ANNOTATION = "HANDLER_OUTPUT_ANNOTATION"
OUTPUT_VALIDATION = "OUTPUT_VALIDATION"
ValidationType = Literal[
"edge_duplication",
"executor_duplication",
"type_compatibility",
"graph_connectivity",
"handler_output_annotation",
"output_validation",
"checkpoint_configuration",
]


class WorkflowValidationError(Exception):
"""Base exception for workflow validation errors."""

def __init__(self, message: str, validation_type: ValidationTypeEnum):
def __init__(self, message: str, type: ValidationType):
super().__init__(message)
self.message = message
self.validation_type = validation_type
self.type: ValidationType = type

def __str__(self) -> str:
return f"[{self.validation_type.value}] {self.message}"
return f"[{self.type}] {self.message}"


class EdgeDuplicationError(WorkflowValidationError):
Expand All @@ -44,7 +44,7 @@ class EdgeDuplicationError(WorkflowValidationError):
def __init__(self, edge_id: str):
super().__init__(
message=f"Duplicate edge detected: {edge_id}. Each edge in the workflow must be unique.",
validation_type=ValidationTypeEnum.EDGE_DUPLICATION,
type="edge_duplication",
)
self.edge_id = edge_id

Expand All @@ -64,7 +64,7 @@ def __init__(
message=f"Type incompatibility between executors '{source_executor_id}' -> '{target_executor_id}'. "
f"Source executor outputs types {[str(t) for t in source_types]} but target executor "
f"can only handle types {[str(t) for t in target_types]}.",
validation_type=ValidationTypeEnum.TYPE_COMPATIBILITY,
type="type_compatibility",
)
self.source_executor_id = source_executor_id
self.target_executor_id = target_executor_id
Expand All @@ -76,7 +76,7 @@ class GraphConnectivityError(WorkflowValidationError):
"""Exception raised when graph connectivity issues are detected."""

def __init__(self, message: str):
super().__init__(message, validation_type=ValidationTypeEnum.GRAPH_CONNECTIVITY)
super().__init__(message, type="graph_connectivity")


# endregion
Expand Down Expand Up @@ -361,14 +361,14 @@ def _output_validation(self, output_executors: list[str]) -> None:
if output_id not in self._executors:
raise WorkflowValidationError(
f"Output executor '{output_id}' is not present in the workflow graph",
validation_type=ValidationTypeEnum.OUTPUT_VALIDATION,
type="output_validation",
)

output_executor = self._executors[output_id]
if not output_executor.workflow_output_types:
raise WorkflowValidationError(
f"Output executor '{output_id}' must have output type annotations defined.",
validation_type=ValidationTypeEnum.OUTPUT_VALIDATION,
type="output_validation",
)

# endregion
Expand Down
25 changes: 25 additions & 0 deletions python/packages/core/agent_framework/_workflows/_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,11 @@ def __init__(
self._graph_signature_hash = self._hash_graph_signature(self._graph_signature)
self._runner.graph_signature_hash = self._graph_signature_hash

@property
def has_checkpointing(self) -> bool:
"""Whether this workflow has checkpoint storage configured (build-time or runtime).

This is a read-only view: it forwards to ``has_checkpointing()`` on the
workflow's runner context and returns that result unchanged.

Returns:
The boolean reported by ``self._runner_context.has_checkpointing()``.
"""
return self._runner_context.has_checkpointing()

def _ensure_not_running(self) -> None:
"""Ensure the workflow is not already running."""
if self._is_running:
Expand All @@ -238,6 +243,25 @@ def _reset_running_flag(self) -> None:
"""Reset the running flag."""
self._is_running = False

def _validate_sub_workflow_checkpointing(self) -> None:
    """Reject the workflow if any nested sub-workflow lacks checkpointing.

    Walks this workflow's executors and, for every ``WorkflowExecutor``,
    checks the wrapped workflow's ``has_checkpointing`` flag. The first
    executor whose sub-workflow reports no checkpointing aborts validation.

    Raises:
        WorkflowValidationError: With type ``"checkpoint_configuration"`` when a
            sub-workflow has no checkpointing configured.
    """
    # Imported at call time, as in the original implementation.
    from ._validation import WorkflowValidationError
    from ._workflow_executor import WorkflowExecutor

    for candidate in self.executors.values():
        # Only executors that wrap a sub-workflow are relevant here.
        if not isinstance(candidate, WorkflowExecutor):
            continue
        if candidate.workflow.has_checkpointing:
            continue

        detail = (
            "Parent workflow has checkpointing enabled, but sub-workflow in executor "
            f"'{candidate.id}' does not. When checkpointing is enabled on a parent workflow, "
            "all sub-workflows must also have checkpoint_storage configured in their "
            "WorkflowBuilder."
        )
        raise WorkflowValidationError(detail, "checkpoint_configuration")

def to_dict(self) -> dict[str, Any]:
"""Serialize the workflow definition into a JSON-ready dictionary."""
data: dict[str, Any] = {
Expand Down Expand Up @@ -552,6 +576,7 @@ async def _run_core(
# Enable runtime checkpointing if storage provided
if checkpoint_storage is not None:
self._runner.context.set_runtime_checkpoint_storage(checkpoint_storage)
self._validate_sub_workflow_checkpointing()

initial_executor_fn, reset_context = self._resolve_execution_mode(
message, responses, checkpoint_id, checkpoint_storage
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
)
from ._executor import Executor
from ._runner_context import InProcRunnerContext
from ._validation import validate_workflow_graph
from ._validation import WorkflowValidationError, validate_workflow_graph
from ._workflow import Workflow

if sys.version_info >= (3, 11):
Expand Down Expand Up @@ -1100,6 +1100,20 @@ async def process(self, text: str, ctx: WorkflowContext[Never, str]) -> None:
output_executors,
)

# Validate checkpoint configuration for sub-workflows
if self._checkpoint_storage is not None:
from ._workflow_executor import WorkflowExecutor

for executor in executors.values():
if isinstance(executor, WorkflowExecutor) and not executor.workflow.has_checkpointing:
raise WorkflowValidationError(
f"Parent workflow has checkpointing enabled, but sub-workflow in executor "
f"'{executor.id}' does not. When checkpointing is enabled on a parent workflow, "
f"all sub-workflows must also have checkpoint_storage configured in their "
f"WorkflowBuilder.",
"checkpoint_configuration",
)

# Add validation completed event
span.add_event(OtelAttr.BUILD_VALIDATION_COMPLETED)

Expand Down
146 changes: 146 additions & 0 deletions python/packages/core/tests/workflow/test_checkpoint_configuration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Copyright (c) Microsoft. All rights reserved.

import pytest
from typing_extensions import Never

from agent_framework import (
Executor,
WorkflowBuilder,
WorkflowContext,
WorkflowExecutor,
WorkflowValidationError,
handler,
)
from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage


class StartExecutor(Executor):
    """Test executor that forwards each received string to the executor with id "finish"."""

    @handler
    async def run(self, message: str, ctx: WorkflowContext[str]) -> None:
        """Send ``message`` on to the "finish" executor."""
        destination = "finish"
        await ctx.send_message(message, target_id=destination)


class FinishExecutor(Executor):
    """Test executor that yields each received string as workflow output."""

    @handler
    async def finish(self, message: str, ctx: WorkflowContext[Never, str]) -> None:
        """Yield ``message`` unchanged as an output of the workflow."""
        output = message
        await ctx.yield_output(output)


def build_sub_workflow(checkpoint_storage: InMemoryCheckpointStorage | None = None) -> WorkflowExecutor:
    """Build a start -> finish workflow and wrap it in a ``WorkflowExecutor`` with id "sub".

    Args:
        checkpoint_storage: Passed straight through to the inner ``WorkflowBuilder``.
            Defaults to ``None``.

    Returns:
        A ``WorkflowExecutor`` with id "sub" wrapping the built workflow.
    """
    # Reassign at every step so this works whether the builder mutates in place
    # or returns a new builder from each call.
    builder = WorkflowBuilder(start_executor="start", checkpoint_storage=checkpoint_storage)
    builder = builder.register_executor(lambda: StartExecutor(id="start"), name="start")
    builder = builder.register_executor(lambda: FinishExecutor(id="finish"), name="finish")
    builder = builder.add_edge("start", "finish")
    inner_workflow = builder.build()
    return WorkflowExecutor(inner_workflow, id="sub")


def test_build_fails_when_parent_has_checkpoint_but_sub_does_not() -> None:
    """Building a checkpointed parent around a sub-workflow without storage must raise."""
    parent_storage = InMemoryCheckpointStorage()

    # The whole builder chain stays inside the context manager, as in the
    # original, so the point at which the error may surface is unchanged.
    with pytest.raises(WorkflowValidationError, match="sub-workflow in executor 'sub'") as raised:
        parent = WorkflowBuilder(start_executor="start", checkpoint_storage=parent_storage)
        parent = parent.register_executor(lambda: StartExecutor(id="start"), name="start")
        parent = parent.register_executor(build_sub_workflow, name="sub")
        parent = parent.add_edge("start", "sub")
        parent.build()

    assert raised.value.type == "checkpoint_configuration"


def test_build_succeeds_when_both_have_checkpoint() -> None:
    """Parent and sub-workflow both receive checkpoint storage, so build() must not raise."""
    shared_storage = InMemoryCheckpointStorage()

    parent = WorkflowBuilder(start_executor="start", checkpoint_storage=shared_storage)
    parent = parent.register_executor(lambda: StartExecutor(id="start"), name="start")
    parent = parent.register_executor(
        lambda: build_sub_workflow(checkpoint_storage=shared_storage),
        name="sub",
    )
    parent = parent.add_edge("start", "sub")
    built = parent.build()

    assert built is not None


def test_build_succeeds_when_neither_has_checkpoint() -> None:
    """With no checkpoint storage on parent or sub-workflow, build() must not raise."""
    parent = WorkflowBuilder(start_executor="start")
    parent = parent.register_executor(lambda: StartExecutor(id="start"), name="start")
    parent = parent.register_executor(build_sub_workflow, name="sub")
    parent = parent.add_edge("start", "sub")
    built = parent.build()

    assert built is not None


async def test_runtime_checkpoint_validates_sub_workflows() -> None:
    """Passing checkpoint_storage to run() must trigger sub-workflow validation."""
    runtime_storage = InMemoryCheckpointStorage()

    # Neither the parent nor the sub-workflow gets storage at build time,
    # so building is expected to succeed.
    parent = WorkflowBuilder(start_executor="start")
    parent = parent.register_executor(lambda: StartExecutor(id="start"), name="start")
    parent = parent.register_executor(build_sub_workflow, name="sub")
    parent = parent.add_edge("start", "sub")
    workflow = parent.build()

    # Supplying storage only at run time should be rejected, because the
    # sub-workflow still has none.
    with pytest.raises(WorkflowValidationError, match="sub-workflow in executor 'sub'") as raised:
        await workflow.run("hello", checkpoint_storage=runtime_storage)

    assert raised.value.type == "checkpoint_configuration"


def test_nested_sub_workflows_all_require_checkpoint() -> None:
    """A checkpointed workflow wrapping a pre-built sub-workflow without storage must fail to build.

    This covers one link of an A -> B -> C nesting: the workflow built here plays
    the role of B (with checkpoint storage) and ``inner_sub`` plays the role of C
    (without), so B's own build is what rejects C.

    NOTE(review): only two levels are constructed; no outer "A" workflow is built.
    """
    middle_storage = InMemoryCheckpointStorage()

    # The innermost sub-workflow is built up front, without checkpoint storage.
    inner_sub = build_sub_workflow()

    # The wrapping workflow has storage but the inner one does not, so build()
    # is expected to raise.
    with pytest.raises(WorkflowValidationError, match="sub-workflow in executor 'sub'") as raised:
        middle = WorkflowBuilder(start_executor="start", checkpoint_storage=middle_storage)
        middle = middle.register_executor(lambda: StartExecutor(id="start"), name="start")
        middle = middle.register_executor(lambda: inner_sub, name="sub")
        middle = middle.add_edge("start", "sub")
        middle.build()

    assert raised.value.type == "checkpoint_configuration"


def test_error_message_identifies_executor() -> None:
    """The validation error text must contain the id of the offending WorkflowExecutor."""
    parent_storage = InMemoryCheckpointStorage()

    # Sub-workflow built without checkpoint storage and wrapped under a
    # distinctive executor id.
    inner = WorkflowBuilder(start_executor="start")
    inner = inner.register_executor(lambda: StartExecutor(id="start"), name="start")
    inner = inner.register_executor(lambda: FinishExecutor(id="finish"), name="finish")
    inner = inner.add_edge("start", "finish")
    custom_id_sub = WorkflowExecutor(inner.build(), id="my_custom_executor_name")

    with pytest.raises(WorkflowValidationError, match="my_custom_executor_name"):
        parent = WorkflowBuilder(start_executor="start", checkpoint_storage=parent_storage)
        parent = parent.register_executor(lambda: StartExecutor(id="start"), name="start")
        parent = parent.register_executor(lambda: custom_id_sub, name="my_custom_executor_name")
        parent = parent.add_edge("start", "my_custom_executor_name")
        parent.build()


def test_sub_workflow_without_checkpoint_parent_without_checkpoint_is_fine() -> None:
    """A sub-workflow with checkpoint storage under a parent without any must build cleanly.

    NOTE(review): the function name says the sub-workflow is "without" checkpoint,
    but the scenario gives the sub-workflow storage and leaves the parent without.
    Consider renaming, e.g. ``test_sub_with_checkpoint_parent_without_is_fine``.
    """
    sub_storage = InMemoryCheckpointStorage()

    parent = WorkflowBuilder(start_executor="start")
    parent = parent.register_executor(lambda: StartExecutor(id="start"), name="start")
    parent = parent.register_executor(
        lambda: build_sub_workflow(checkpoint_storage=sub_storage),
        name="sub",
    )
    parent = parent.add_edge("start", "sub")
    built = parent.build()

    assert built is not None
2 changes: 1 addition & 1 deletion python/packages/core/tests/workflow/test_sub_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
def _build_checkpoint_test_workflow(storage: InMemoryCheckpointStorage) -> Workflow:
"""Build the main workflow with checkpointing for testing."""
two_step_executor = TwoStepSubWorkflowExecutor()
sub_workflow = WorkflowBuilder(start_executor=two_step_executor).build()
sub_workflow = WorkflowBuilder(start_executor=two_step_executor, checkpoint_storage=storage).build()
sub_workflow_executor = WorkflowExecutor(sub_workflow, id="sub_workflow_executor")

coordinator = CheckpointTestCoordinator()
Expand Down
Loading
Loading