Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 35 additions & 4 deletions python/packages/core/agent_framework/_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,37 @@ def _parse_inputs( # pyright: ignore[reportUnusedFunction]
return parsed_inputs


def _model_dump_preserving_explicit_none(model: BaseModel) -> dict[str, Any]:
"""Dump a model without dropping fields that were explicitly set to None."""
dumped = model.model_dump(exclude_none=True)
_restore_explicit_none_fields(model, dumped)
return dumped
Comment on lines +194 to +198
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. Updated the helper to recursively restore explicitly provided null values in nested Pydantic models, mappings, and lists/tuples, and added nested regression coverage for both direct invoke and automatic function calling.



def _restore_explicit_none_fields(value: Any, dumped: Any) -> None:
if isinstance(value, BaseModel) and isinstance(dumped, dict):
for field_name in value.model_fields_set:
if not isinstance(field_name, str):
continue

field_value = getattr(value, field_name, None)
if field_value is None:
dumped[field_name] = None
elif field_name in dumped:
_restore_explicit_none_fields(field_value, dumped[field_name])
return

if isinstance(value, Mapping) and isinstance(dumped, Mapping):
for key, item in value.items():
if key in dumped:
_restore_explicit_none_fields(item, dumped[key])
return

if isinstance(value, list | tuple) and isinstance(dumped, list):
for item, dumped_item in zip(value, dumped):
_restore_explicit_none_fields(item, dumped_item)


# region Tools


Expand Down Expand Up @@ -635,8 +666,8 @@ async def invoke(
if isinstance(arguments, Mapping):
parsed_arguments = dict(arguments)
if self.input_model is not None and not self._schema_supplied:
parsed_arguments = self.input_model.model_validate(parsed_arguments).model_dump(
exclude_none=True
parsed_arguments = _model_dump_preserving_explicit_none(
self.input_model.model_validate(parsed_arguments)
)
elif isinstance(arguments, BaseModel):
if (
Expand All @@ -645,7 +676,7 @@ async def invoke(
and not isinstance(arguments, self.input_model)
):
raise TypeError(f"Expected {self.input_model.__name__}, got {type(arguments).__name__}")
parsed_arguments = arguments.model_dump(exclude_none=True)
parsed_arguments = _model_dump_preserving_explicit_none(arguments)
else:
raise TypeError(
f"Expected mapping-like arguments for tool '{self.name}', got {type(arguments).__name__}"
Expand Down Expand Up @@ -1492,7 +1523,7 @@ async def _auto_invoke_function(
runtime_kwargs["session"] = invocation_session
try:
if not cast(bool, getattr(tool, "_schema_supplied", False)) and tool.input_model is not None:
args = tool.input_model.model_validate(parsed_args).model_dump(exclude_none=True)
args = _model_dump_preserving_explicit_none(tool.input_model.model_validate(parsed_args))
else:
args = dict(parsed_args)
args = _validate_arguments_against_schema(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,17 @@

import asyncio
from collections.abc import Awaitable, Callable
from typing import Any
from typing import Any, Literal

import pytest
from pydantic import BaseModel

from agent_framework import (
Agent,
ChatResponse,
ChatResponseUpdate,
Content,
FunctionTool,
Message,
SupportsChatGetResponse,
chat_middleware,
Expand Down Expand Up @@ -87,6 +89,79 @@ def ai_func(arg1: str) -> str:
assert response.messages[2].text == "done"


async def test_auto_function_calling_preserves_explicit_null_arguments(chat_client_base: SupportsChatGetResponse):
captured_args: list[tuple[str, str | None]] = []

@tool(name="get_weather", approval_mode="never_require")
def get_weather(location: str, unit: Literal["C", "F"] | None) -> str:
captured_args.append((location, unit))
return f"{location}:{unit or 'C'}"

chat_client_base.run_responses = [
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(
call_id="1",
name="get_weather",
arguments='{"location": "Seattle", "unit": null}',
)
],
)
),
ChatResponse(messages=Message(role="assistant", contents=["done"])),
]

await chat_client_base.get_response(
[Message(role="user", contents=["weather in Seattle"])],
options={"tool_choice": "auto", "tools": [get_weather]},
)

assert captured_args == [("Seattle", None)]


async def test_auto_function_calling_preserves_nested_explicit_null_arguments(
chat_client_base: SupportsChatGetResponse,
):
class WeatherOptions(BaseModel):
unit: Literal["C", "F"] | None

class WeatherArgs(BaseModel):
location: str
options: WeatherOptions

captured_options: list[dict[str, Any]] = []

def get_weather(location: str, options: dict[str, Any]) -> str:
captured_options.append(options)
return f"{location}:{options['unit'] or 'C'}"

weather_tool = FunctionTool(name="get_weather", func=get_weather, input_model=WeatherArgs)
chat_client_base.run_responses = [
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(
call_id="1",
name="get_weather",
arguments='{"location": "Seattle", "options": {"unit": null}}',
)
],
)
),
ChatResponse(messages=Message(role="assistant", contents=["done"])),
]

await chat_client_base.get_response(
[Message(role="user", contents=["weather in Seattle"])],
options={"tool_choice": "auto", "tools": [weather_tool]},
)

assert captured_options == [{"unit": None}]


async def test_base_client_with_function_calling_string_input(chat_client_base: SupportsChatGetResponse):
exec_counter = 0

Expand Down
30 changes: 30 additions & 0 deletions python/packages/core/tests/core/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,36 @@ def search(query: str, max_results: int = 10) -> str:
await search.invoke(arguments={"query": "hello", "max_results": "three"})


async def test_tool_invoke_preserves_explicit_null_for_required_nullable_argument() -> None:
"""Explicit null values are valid for required nullable tool parameters."""

@tool
def get_weather(location: str, unit: Literal["C", "F"] | None) -> str:
return f"{location}:{unit or 'C'}"

result = await get_weather.invoke(arguments={"location": "Seattle", "unit": None})

assert result[0].text == "Seattle:C"


async def test_tool_invoke_preserves_nested_explicit_null_argument() -> None:
class WeatherOptions(BaseModel):
unit: Literal["C", "F"] | None

class WeatherArgs(BaseModel):
location: str
options: WeatherOptions

def get_weather(location: str, options: dict[str, Any]) -> str:
return f"{location}:{options['unit'] or 'C'}"

weather_tool = FunctionTool(name="get_weather", func=get_weather, input_model=WeatherArgs)

result = await weather_tool.invoke(arguments={"location": "Seattle", "options": {"unit": None}})

assert result[0].text == "Seattle:C"


def test_tool_decorator_with_json_schema_preserves_custom_properties():
"""Test schema passthrough keeps custom JSON schema properties."""

Expand Down