diff --git a/python/packages/core/agent_framework/_workflows/__init__.py b/python/packages/core/agent_framework/_workflows/__init__.py index c5666f7b26..4505b21cc1 100644 --- a/python/packages/core/agent_framework/_workflows/__init__.py +++ b/python/packages/core/agent_framework/_workflows/__init__.py @@ -59,7 +59,7 @@ EdgeDuplicationError, GraphConnectivityError, TypeCompatibilityError, - ValidationTypeEnum, + ValidationType, WorkflowValidationError, validate_workflow_graph, ) @@ -102,7 +102,7 @@ "SwitchCaseEdgeGroupCase", "SwitchCaseEdgeGroupDefault", "TypeCompatibilityError", - "ValidationTypeEnum", + "ValidationType", "Workflow", "WorkflowAgent", "WorkflowBuilder", diff --git a/python/packages/core/agent_framework/_workflows/_validation.py b/python/packages/core/agent_framework/_workflows/_validation.py index 6c08f60099..78b9dd4384 100644 --- a/python/packages/core/agent_framework/_workflows/_validation.py +++ b/python/packages/core/agent_framework/_workflows/_validation.py @@ -4,8 +4,7 @@ import types from collections import defaultdict from collections.abc import Sequence -from enum import Enum -from typing import Any +from typing import Any, Literal from ._edge import Edge, EdgeGroup, FanInEdgeGroup, InternalEdgeGroup from ._executor import Executor @@ -15,27 +14,28 @@ # region Enums and Base Classes -class ValidationTypeEnum(Enum): - """Enumeration of workflow validation types.""" - EDGE_DUPLICATION = "EDGE_DUPLICATION" - EXECUTOR_DUPLICATION = "EXECUTOR_DUPLICATION" - TYPE_COMPATIBILITY = "TYPE_COMPATIBILITY" - GRAPH_CONNECTIVITY = "GRAPH_CONNECTIVITY" - HANDLER_OUTPUT_ANNOTATION = "HANDLER_OUTPUT_ANNOTATION" - OUTPUT_VALIDATION = "OUTPUT_VALIDATION" +ValidationType = Literal[ + "edge_duplication", + "executor_duplication", + "type_compatibility", + "graph_connectivity", + "handler_output_annotation", + "output_validation", + "checkpoint_configuration", +] class WorkflowValidationError(Exception): """Base exception for workflow validation errors.""" - def 
__init__(self, message: str, validation_type: ValidationTypeEnum): + def __init__(self, message: str, type: ValidationType): super().__init__(message) self.message = message - self.validation_type = validation_type + self.type: ValidationType = type def __str__(self) -> str: - return f"[{self.validation_type.value}] {self.message}" + return f"[{self.type}] {self.message}" class EdgeDuplicationError(WorkflowValidationError): @@ -44,7 +44,7 @@ class EdgeDuplicationError(WorkflowValidationError): def __init__(self, edge_id: str): super().__init__( message=f"Duplicate edge detected: {edge_id}. Each edge in the workflow must be unique.", - validation_type=ValidationTypeEnum.EDGE_DUPLICATION, + type="edge_duplication", ) self.edge_id = edge_id @@ -64,7 +64,7 @@ def __init__( message=f"Type incompatibility between executors '{source_executor_id}' -> '{target_executor_id}'. " f"Source executor outputs types {[str(t) for t in source_types]} but target executor " f"can only handle types {[str(t) for t in target_types]}.", - validation_type=ValidationTypeEnum.TYPE_COMPATIBILITY, + type="type_compatibility", ) self.source_executor_id = source_executor_id self.target_executor_id = target_executor_id @@ -76,7 +76,7 @@ class GraphConnectivityError(WorkflowValidationError): """Exception raised when graph connectivity issues are detected.""" def __init__(self, message: str): - super().__init__(message, validation_type=ValidationTypeEnum.GRAPH_CONNECTIVITY) + super().__init__(message, type="graph_connectivity") # endregion @@ -361,14 +361,14 @@ def _output_validation(self, output_executors: list[str]) -> None: if output_id not in self._executors: raise WorkflowValidationError( f"Output executor '{output_id}' is not present in the workflow graph", - validation_type=ValidationTypeEnum.OUTPUT_VALIDATION, + type="output_validation", ) output_executor = self._executors[output_id] if not output_executor.workflow_output_types: raise WorkflowValidationError( f"Output executor '{output_id}' 
must have output type annotations defined.", - validation_type=ValidationTypeEnum.OUTPUT_VALIDATION, + type="output_validation", ) # endregion diff --git a/python/packages/core/agent_framework/_workflows/_workflow.py b/python/packages/core/agent_framework/_workflows/_workflow.py index 5f93644035..92db0f501a 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow.py +++ b/python/packages/core/agent_framework/_workflows/_workflow.py @@ -228,6 +228,11 @@ def __init__( self._graph_signature_hash = self._hash_graph_signature(self._graph_signature) self._runner.graph_signature_hash = self._graph_signature_hash + @property + def has_checkpointing(self) -> bool: + """Whether this workflow has checkpoint storage configured (build-time or runtime).""" + return self._runner_context.has_checkpointing() + def _ensure_not_running(self) -> None: """Ensure the workflow is not already running.""" if self._is_running: @@ -238,6 +243,25 @@ def _reset_running_flag(self) -> None: """Reset the running flag.""" self._is_running = False + def _validate_sub_workflow_checkpointing(self) -> None: + """Validate that all sub-workflows have checkpointing configured. + + Raises: + WorkflowValidationError: If a sub-workflow is missing checkpoint configuration. + """ + from ._validation import WorkflowValidationError + from ._workflow_executor import WorkflowExecutor + + for executor in self.executors.values(): + if isinstance(executor, WorkflowExecutor) and not executor.workflow.has_checkpointing: + raise WorkflowValidationError( + f"Parent workflow has checkpointing enabled, but sub-workflow in executor " + f"'{executor.id}' does not. 
When checkpointing is enabled on a parent workflow, " + f"all sub-workflows must also have checkpoint_storage configured in their " + f"WorkflowBuilder.", + "checkpoint_configuration", + ) + def to_dict(self) -> dict[str, Any]: """Serialize the workflow definition into a JSON-ready dictionary.""" data: dict[str, Any] = { @@ -552,6 +576,7 @@ async def _run_core( # Enable runtime checkpointing if storage provided if checkpoint_storage is not None: self._runner.context.set_runtime_checkpoint_storage(checkpoint_storage) + self._validate_sub_workflow_checkpointing() initial_executor_fn, reset_context = self._resolve_execution_mode( message, responses, checkpoint_id, checkpoint_storage diff --git a/python/packages/core/agent_framework/_workflows/_workflow_builder.py b/python/packages/core/agent_framework/_workflows/_workflow_builder.py index 14fd512e17..735e47adfa 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_builder.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_builder.py @@ -28,7 +28,7 @@ ) from ._executor import Executor from ._runner_context import InProcRunnerContext -from ._validation import validate_workflow_graph +from ._validation import WorkflowValidationError, validate_workflow_graph from ._workflow import Workflow if sys.version_info >= (3, 11): @@ -1100,6 +1100,20 @@ async def process(self, text: str, ctx: WorkflowContext[Never, str]) -> None: output_executors, ) + # Validate checkpoint configuration for sub-workflows + if self._checkpoint_storage is not None: + from ._workflow_executor import WorkflowExecutor + + for executor in executors.values(): + if isinstance(executor, WorkflowExecutor) and not executor.workflow.has_checkpointing: + raise WorkflowValidationError( + f"Parent workflow has checkpointing enabled, but sub-workflow in executor " + f"'{executor.id}' does not. 
When checkpointing is enabled on a parent workflow, " + f"all sub-workflows must also have checkpoint_storage configured in their " + f"WorkflowBuilder.", + "checkpoint_configuration", + ) + # Add validation completed event span.add_event(OtelAttr.BUILD_VALIDATION_COMPLETED) diff --git a/python/packages/core/tests/workflow/test_checkpoint_configuration.py b/python/packages/core/tests/workflow/test_checkpoint_configuration.py new file mode 100644 index 0000000000..d2ccfef940 --- /dev/null +++ b/python/packages/core/tests/workflow/test_checkpoint_configuration.py @@ -0,0 +1,146 @@ +# Copyright (c) Microsoft. All rights reserved. + +import pytest +from typing_extensions import Never + +from agent_framework import ( + Executor, + WorkflowBuilder, + WorkflowContext, + WorkflowExecutor, + WorkflowValidationError, + handler, +) +from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage + + +class StartExecutor(Executor): + @handler + async def run(self, message: str, ctx: WorkflowContext[str]) -> None: + await ctx.send_message(message, target_id="finish") + + +class FinishExecutor(Executor): + @handler + async def finish(self, message: str, ctx: WorkflowContext[Never, str]) -> None: + await ctx.yield_output(message) + + +def build_sub_workflow(checkpoint_storage: InMemoryCheckpointStorage | None = None) -> WorkflowExecutor: + sub_workflow = ( + WorkflowBuilder(start_executor="start", checkpoint_storage=checkpoint_storage) + .register_executor(lambda: StartExecutor(id="start"), name="start") + .register_executor(lambda: FinishExecutor(id="finish"), name="finish") + .add_edge("start", "finish") + .build() + ) + return WorkflowExecutor(sub_workflow, id="sub") + + +def test_build_fails_when_parent_has_checkpoint_but_sub_does_not() -> None: + """Parent has checkpoint_storage, sub-workflow does not -> error at build time.""" + storage = InMemoryCheckpointStorage() + + with pytest.raises(WorkflowValidationError, match="sub-workflow in executor 'sub'") as 
exc_info: + WorkflowBuilder(start_executor="start", checkpoint_storage=storage).register_executor( + lambda: StartExecutor(id="start"), name="start" + ).register_executor(build_sub_workflow, name="sub").add_edge("start", "sub").build() + + assert exc_info.value.type == "checkpoint_configuration" + + +def test_build_succeeds_when_both_have_checkpoint() -> None: + """Both parent and sub-workflow have checkpoint_storage -> no error.""" + storage = InMemoryCheckpointStorage() + + workflow = ( + WorkflowBuilder(start_executor="start", checkpoint_storage=storage) + .register_executor(lambda: StartExecutor(id="start"), name="start") + .register_executor(lambda: build_sub_workflow(checkpoint_storage=storage), name="sub") + .add_edge("start", "sub") + .build() + ) + assert workflow is not None + + +def test_build_succeeds_when_neither_has_checkpoint() -> None: + """Neither parent nor sub-workflow has checkpoint_storage -> no validation needed.""" + workflow = ( + WorkflowBuilder(start_executor="start") + .register_executor(lambda: StartExecutor(id="start"), name="start") + .register_executor(build_sub_workflow, name="sub") + .add_edge("start", "sub") + .build() + ) + assert workflow is not None + + +async def test_runtime_checkpoint_validates_sub_workflows() -> None: + """Runtime checkpoint_storage on run() triggers validation of sub-workflows.""" + storage = InMemoryCheckpointStorage() + + # Build without checkpoint_storage on either - succeeds + workflow = ( + WorkflowBuilder(start_executor="start") + .register_executor(lambda: StartExecutor(id="start"), name="start") + .register_executor(build_sub_workflow, name="sub") + .add_edge("start", "sub") + .build() + ) + + # Run with runtime checkpoint_storage - should fail because sub has none + with pytest.raises(WorkflowValidationError, match="sub-workflow in executor 'sub'") as exc_info: + await workflow.run("hello", checkpoint_storage=storage) + + assert exc_info.value.type == "checkpoint_configuration" + + +def 
test_nested_sub_workflows_all_require_checkpoint() -> None: + """A -> B -> C: if A has checkpoint, B must too, and B's build validates C.""" + storage = InMemoryCheckpointStorage() + + # Inner sub-workflow without checkpoint + inner_sub = build_sub_workflow() + + # Middle workflow wrapping the inner sub - this should fail because + # middle has checkpoint but inner doesn't + with pytest.raises(WorkflowValidationError, match="sub-workflow in executor 'sub'") as exc_info: + WorkflowBuilder(start_executor="start", checkpoint_storage=storage).register_executor( + lambda: StartExecutor(id="start"), name="start" + ).register_executor(lambda: inner_sub, name="sub").add_edge("start", "sub").build() + + assert exc_info.value.type == "checkpoint_configuration" + + +def test_error_message_identifies_executor() -> None: + """Error message includes the executor ID of the offending sub-workflow.""" + storage = InMemoryCheckpointStorage() + custom_id_sub = WorkflowExecutor( + WorkflowBuilder(start_executor="start") + .register_executor(lambda: StartExecutor(id="start"), name="start") + .register_executor(lambda: FinishExecutor(id="finish"), name="finish") + .add_edge("start", "finish") + .build(), + id="my_custom_executor_name", + ) + + with pytest.raises(WorkflowValidationError, match="my_custom_executor_name"): + WorkflowBuilder(start_executor="start", checkpoint_storage=storage).register_executor( + lambda: StartExecutor(id="start"), name="start" + ).register_executor(lambda: custom_id_sub, name="my_custom_executor_name").add_edge( + "start", "my_custom_executor_name" + ).build() + + +def test_sub_workflow_with_checkpoint_parent_without_checkpoint_is_fine() -> None: + """Sub-workflow has checkpoint but parent doesn't -> no error (sub manages its own checkpoints).""" + storage = InMemoryCheckpointStorage() + + workflow = ( + WorkflowBuilder(start_executor="start") + .register_executor(lambda: StartExecutor(id="start"), name="start") + .register_executor(lambda: 
build_sub_workflow(checkpoint_storage=storage), name="sub") + .add_edge("start", "sub") + .build() + ) + assert workflow is not None diff --git a/python/packages/core/tests/workflow/test_sub_workflow.py b/python/packages/core/tests/workflow/test_sub_workflow.py index 55afad880f..6ec43f33fa 100644 --- a/python/packages/core/tests/workflow/test_sub_workflow.py +++ b/python/packages/core/tests/workflow/test_sub_workflow.py @@ -559,7 +559,7 @@ async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: def _build_checkpoint_test_workflow(storage: InMemoryCheckpointStorage) -> Workflow: """Build the main workflow with checkpointing for testing.""" two_step_executor = TwoStepSubWorkflowExecutor() - sub_workflow = WorkflowBuilder(start_executor=two_step_executor).build() + sub_workflow = WorkflowBuilder(start_executor=two_step_executor, checkpoint_storage=storage).build() sub_workflow_executor = WorkflowExecutor(sub_workflow, id="sub_workflow_executor") coordinator = CheckpointTestCoordinator() diff --git a/python/packages/core/tests/workflow/test_validation.py b/python/packages/core/tests/workflow/test_validation.py index ae694c8354..b63d3ce741 100644 --- a/python/packages/core/tests/workflow/test_validation.py +++ b/python/packages/core/tests/workflow/test_validation.py @@ -10,7 +10,6 @@ Executor, GraphConnectivityError, TypeCompatibilityError, - ValidationTypeEnum, WorkflowBuilder, WorkflowContext, WorkflowValidationError, @@ -95,7 +94,7 @@ def test_edge_duplication_validation_fails(): WorkflowBuilder(start_executor=executor1).add_edge(executor1, executor2).add_edge(executor1, executor2).build() assert "executor1->executor2" in str(exc_info.value) - assert exc_info.value.validation_type == ValidationTypeEnum.EDGE_DUPLICATION + assert exc_info.value.type == "edge_duplication" def test_type_compatibility_validation_fails(): @@ -108,7 +107,7 @@ def test_type_compatibility_validation_fails(): error = exc_info.value assert error.source_executor_id == 
"string_executor" assert error.target_executor_id == "int_executor" - assert error.validation_type == ValidationTypeEnum.TYPE_COMPATIBILITY + assert error.type == "type_compatibility" def test_type_compatibility_with_any_type_passes(): @@ -151,7 +150,7 @@ def test_graph_connectivity_unreachable_executors(): assert "unreachable" in str(exc_info.value).lower() assert "executor3" in str(exc_info.value) - assert exc_info.value.validation_type == ValidationTypeEnum.GRAPH_CONNECTIVITY + assert exc_info.value.type == "graph_connectivity" def test_graph_connectivity_isolated_executors(): @@ -191,10 +190,10 @@ def test_missing_start_executor(): def test_workflow_validation_error_base_class(): - error = WorkflowValidationError("Test message", ValidationTypeEnum.EDGE_DUPLICATION) - assert str(error) == "[EDGE_DUPLICATION] Test message" + error = WorkflowValidationError("Test message", "edge_duplication") + assert str(error) == "[edge_duplication] Test message" assert error.message == "Test message" - assert error.validation_type == ValidationTypeEnum.EDGE_DUPLICATION + assert error.type == "edge_duplication" def test_complex_workflow_validation(): @@ -464,20 +463,19 @@ async def handle_message(self, message: list[str], ctx: WorkflowContext[str]) -> assert workflow is not None -def test_validation_enum_usage() -> None: - # Test that all validation types use the enum correctly +def test_validation_type_usage() -> None: + # Test that all validation types are stored as string literals edge_error = EdgeDuplicationError("test->test") - assert edge_error.validation_type == ValidationTypeEnum.EDGE_DUPLICATION + assert edge_error.type == "edge_duplication" type_error = TypeCompatibilityError("source", "target", [str], [int]) - assert type_error.validation_type == ValidationTypeEnum.TYPE_COMPATIBILITY + assert type_error.type == "type_compatibility" graph_error = GraphConnectivityError("test message") - assert graph_error.validation_type == ValidationTypeEnum.GRAPH_CONNECTIVITY + 
assert graph_error.type == "graph_connectivity" - # Test enum string representation - assert str(ValidationTypeEnum.EDGE_DUPLICATION) == "ValidationTypeEnum.EDGE_DUPLICATION" - assert ValidationTypeEnum.EDGE_DUPLICATION.value == "EDGE_DUPLICATION" + error = WorkflowValidationError("test", "output_validation") + assert error.type == "output_validation" def test_handler_ctx_missing_annotation_raises() -> None: @@ -588,7 +586,7 @@ def test_output_validation_fails_for_nonexistent_executor(): assert "not present in the workflow graph" in str(exc_info.value) assert "nonexistent_executor" in str(exc_info.value) - assert exc_info.value.validation_type == ValidationTypeEnum.OUTPUT_VALIDATION + assert exc_info.value.type == "output_validation" def test_output_validation_fails_for_executor_without_output_types(): @@ -605,7 +603,7 @@ def test_output_validation_fails_for_executor_without_output_types(): assert "must have output type annotations defined" in str(exc_info.value) assert "no_output" in str(exc_info.value) - assert exc_info.value.validation_type == ValidationTypeEnum.OUTPUT_VALIDATION + assert exc_info.value.type == "output_validation" def test_output_validation_empty_list_passes(): @@ -635,7 +633,7 @@ def test_output_validation_with_direct_validate_workflow_graph(): validate_workflow_graph(edge_groups, executors, executor1, ["nonexistent"]) assert "not present in the workflow graph" in str(exc_info.value) - assert exc_info.value.validation_type == ValidationTypeEnum.OUTPUT_VALIDATION + assert exc_info.value.type == "output_validation" def test_output_validation_with_no_output_types_via_direct_validation(): @@ -650,7 +648,7 @@ def test_output_validation_with_no_output_types_via_direct_validation(): validate_workflow_graph(edge_groups, executors, executor1, ["no_output"]) assert "must have output type annotations defined" in str(exc_info.value) - assert exc_info.value.validation_type == ValidationTypeEnum.OUTPUT_VALIDATION + assert exc_info.value.type == 
"output_validation" def test_output_validation_partial_invalid_list(): @@ -668,10 +666,4 @@ def test_output_validation_partial_invalid_list(): assert "nonexistent" in str(exc_info.value) -def test_output_validation_type_enum_value(): - """Test that OUTPUT_VALIDATION is properly defined in ValidationTypeEnum.""" - assert hasattr(ValidationTypeEnum, "OUTPUT_VALIDATION") - assert ValidationTypeEnum.OUTPUT_VALIDATION.value == "OUTPUT_VALIDATION" - - # endregion diff --git a/python/samples/getting_started/workflows/checkpoint/sub_workflow_checkpoint.py b/python/samples/getting_started/workflows/checkpoint/sub_workflow_checkpoint.py index c975a10ae1..f398e06c3d 100644 --- a/python/samples/getting_started/workflows/checkpoint/sub_workflow_checkpoint.py +++ b/python/samples/getting_started/workflows/checkpoint/sub_workflow_checkpoint.py @@ -295,10 +295,10 @@ async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: # --------------------------------------------------------------------------- -def build_sub_workflow() -> WorkflowExecutor: +def build_sub_workflow(storage: FileCheckpointStorage) -> WorkflowExecutor: """Assemble the sub-workflow used by the parent workflow executor.""" sub_workflow = ( - WorkflowBuilder(start_executor="writer") + WorkflowBuilder(start_executor="writer", checkpoint_storage=storage) .register_executor(DraftWriter, name="writer") .register_executor(DraftReviewRouter, name="router") .register_executor(DraftFinaliser, name="finaliser") @@ -316,7 +316,7 @@ def build_parent_workflow(storage: FileCheckpointStorage) -> Workflow: return ( WorkflowBuilder(start_executor="coordinator", checkpoint_storage=storage) .register_executor(LaunchCoordinator, name="coordinator") - .register_executor(build_sub_workflow, name="sub_executor") + .register_executor(lambda: build_sub_workflow(storage), name="sub_executor") .add_edge("coordinator", "sub_executor") .add_edge("sub_executor", "coordinator") .build()