Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,10 @@ def decode_checkpoint_value(value: Any) -> Any:
cls = None

if cls is not None:
# Verify the class actually supports the model protocol
if not _class_supports_model_protocol(cls):
logger.debug(f"Class {type_key} does not support model protocol; returning raw value")
return decoded_payload
if strategy == "to_dict" and hasattr(cls, "from_dict"):
with contextlib.suppress(Exception):
return cls.from_dict(decoded_payload)
Expand All @@ -169,6 +173,10 @@ def decode_checkpoint_value(value: Any) -> Any:
if module is None:
module = importlib.import_module(module_name)
cls_dc: Any = getattr(module, class_name)
# Verify the class is actually a dataclass type (not an instance)
if not isinstance(cls_dc, type) or not is_dataclass(cls_dc):
logger.debug(f"Class {type_key_dc} is not a dataclass type; returning raw value")
return decoded_raw
constructed = _instantiate_checkpoint_dataclass(cls_dc, decoded_raw)
if constructed is not None:
return constructed
Expand All @@ -188,20 +196,30 @@ def decode_checkpoint_value(value: Any) -> Any:
return value


def _class_supports_model_protocol(cls: type[Any]) -> bool:
"""Check if a class type supports the model serialization protocol.

Checks for pairs of serialization/deserialization methods:
- to_dict/from_dict
- to_json/from_json
"""
has_to_dict = hasattr(cls, "to_dict") and callable(getattr(cls, "to_dict", None))
has_from_dict = hasattr(cls, "from_dict") and callable(getattr(cls, "from_dict", None))

has_to_json = hasattr(cls, "to_json") and callable(getattr(cls, "to_json", None))
has_from_json = hasattr(cls, "from_json") and callable(getattr(cls, "from_json", None))

return (has_to_dict and has_from_dict) or (has_to_json and has_from_json)


def _supports_model_protocol(obj: object) -> bool:
"""Detect objects that expose dictionary serialization hooks."""
try:
obj_type: type[Any] = type(obj)
except Exception:
return False

has_to_dict = hasattr(obj, "to_dict") and callable(getattr(obj, "to_dict", None)) # type: ignore[arg-type]
has_from_dict = hasattr(obj_type, "from_dict") and callable(getattr(obj_type, "from_dict", None))

has_to_json = hasattr(obj, "to_json") and callable(getattr(obj, "to_json", None)) # type: ignore[arg-type]
has_from_json = hasattr(obj_type, "from_json") and callable(getattr(obj_type, "from_json", None))

return (has_to_dict and has_from_dict) or (has_to_json and has_from_json)
return _class_supports_model_protocol(obj_type)


def _import_qualified_name(qualname: str) -> type[Any] | None:
Expand Down
110 changes: 110 additions & 0 deletions python/packages/core/tests/workflow/test_checkpoint_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
from dataclasses import dataclass # noqa: I001
from typing import Any, cast


from agent_framework._workflows._checkpoint_encoding import (
DATACLASS_MARKER,
MODEL_MARKER,
decode_checkpoint_value,
encode_checkpoint_value,
)
Expand Down Expand Up @@ -126,3 +129,110 @@ def test_encode_decode_nested_structures() -> None:
assert response.data == "first response"
assert isinstance(response.original_request, SampleRequest)
assert response.original_request.request_id == "req-1"


def test_encode_allows_marker_key_without_value_key() -> None:
"""Test that encoding a dict with only the marker key (no 'value') is allowed."""
dict_with_marker_only = {
MODEL_MARKER: "some.module:FakeClass",
"other_key": "test",
}
encoded = encode_checkpoint_value(dict_with_marker_only)
assert MODEL_MARKER in encoded
assert "other_key" in encoded


def test_encode_allows_value_key_without_marker_key() -> None:
"""Test that encoding a dict with only 'value' key (no marker) is allowed."""
dict_with_value_only = {
"value": {"data": "test"},
"other_key": "test",
}
encoded = encode_checkpoint_value(dict_with_value_only)
assert "value" in encoded
assert "other_key" in encoded


def test_encode_allows_marker_with_value_key() -> None:
"""Test that encoding a dict with marker and 'value' keys is allowed.

This is allowed because legitimate encoded data may contain these keys,
and security is enforced at deserialization time by validating class types.
"""
dict_with_both = {
MODEL_MARKER: "some.module:SomeClass",
"value": {"data": "test"},
"strategy": "to_dict",
}
encoded = encode_checkpoint_value(dict_with_both)
assert MODEL_MARKER in encoded
assert "value" in encoded


class NotADataclass:
"""A regular class that is not a dataclass."""

def __init__(self, value: str) -> None:
self.value = value

def get_value(self) -> str:
return self.value


class NotAModel:
"""A regular class that does not support the model protocol."""

def __init__(self, value: str) -> None:
self.value = value

def get_value(self) -> str:
return self.value


def test_decode_rejects_non_dataclass_with_dataclass_marker() -> None:
"""Test that decode returns raw value when marked class is not a dataclass."""
# Manually construct a payload that claims NotADataclass is a dataclass
fake_payload = {
DATACLASS_MARKER: f"{NotADataclass.__module__}:{NotADataclass.__name__}",
"value": {"value": "test_value"},
}

decoded = decode_checkpoint_value(fake_payload)

# Should return the raw decoded value, not an instance of NotADataclass
assert isinstance(decoded, dict)
assert decoded["value"] == "test_value"


def test_decode_rejects_non_model_with_model_marker() -> None:
"""Test that decode returns raw value when marked class doesn't support model protocol."""
# Manually construct a payload that claims NotAModel supports the model protocol
fake_payload = {
MODEL_MARKER: f"{NotAModel.__module__}:{NotAModel.__name__}",
"strategy": "to_dict",
"value": {"value": "test_value"},
}

decoded = decode_checkpoint_value(fake_payload)

# Should return the raw decoded value, not an instance of NotAModel
assert isinstance(decoded, dict)
assert decoded["value"] == "test_value"


def test_encode_allows_nested_dict_with_marker_keys() -> None:
"""Test that encoding allows nested dicts containing marker patterns.

Security is enforced at deserialization time, not serialization time,
so legitimate encoded data can contain markers at any nesting level.
"""
nested_data = {
"outer": {
MODEL_MARKER: "some.module:SomeClass",
"value": {"data": "test"},
}
}

encoded = encode_checkpoint_value(nested_data)
assert "outer" in encoded
assert MODEL_MARKER in encoded["outer"]
Loading
Loading