From c0fedbb85d5eb67d9166b37745c2f439612a354b Mon Sep 17 00:00:00 2001 From: shusingh Date: Mon, 18 May 2026 18:57:26 -0700 Subject: [PATCH] Preserve null tool arguments --- .../packages/core/agent_framework/_tools.py | 39 +++++++++- .../core/test_function_invocation_logic.py | 77 ++++++++++++++++++- python/packages/core/tests/core/test_tools.py | 30 ++++++++ 3 files changed, 141 insertions(+), 5 deletions(-) diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index 93722a8987..1577e5fafc 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -191,6 +191,37 @@ def _parse_inputs( # pyright: ignore[reportUnusedFunction] return parsed_inputs +def _model_dump_preserving_explicit_none(model: BaseModel) -> dict[str, Any]: + """Dump a model without dropping fields that were explicitly set to None.""" + dumped = model.model_dump(exclude_none=True) + _restore_explicit_none_fields(model, dumped) + return dumped + + +def _restore_explicit_none_fields(value: Any, dumped: Any) -> None: + if isinstance(value, BaseModel) and isinstance(dumped, dict): + for field_name in value.model_fields_set: + if not isinstance(field_name, str): + continue + + field_value = getattr(value, field_name, None) + if field_value is None: + dumped[field_name] = None + elif field_name in dumped: + _restore_explicit_none_fields(field_value, dumped[field_name]) + return + + if isinstance(value, Mapping) and isinstance(dumped, Mapping): + for key, item in value.items(): + if key in dumped: + _restore_explicit_none_fields(item, dumped[key]) + return + + if isinstance(value, list | tuple) and isinstance(dumped, list): + for item, dumped_item in zip(value, dumped): + _restore_explicit_none_fields(item, dumped_item) + + # region Tools @@ -635,8 +666,8 @@ async def invoke( if isinstance(arguments, Mapping): parsed_arguments = dict(arguments) if self.input_model is not None and not self._schema_supplied: - parsed_arguments = self.input_model.model_validate(parsed_arguments).model_dump( - exclude_none=True + parsed_arguments = _model_dump_preserving_explicit_none( + self.input_model.model_validate(parsed_arguments) ) elif isinstance(arguments, BaseModel): if ( @@ -645,7 +676,7 @@ async def invoke( and not isinstance(arguments, self.input_model) ): raise TypeError(f"Expected {self.input_model.__name__}, got {type(arguments).__name__}") - parsed_arguments = arguments.model_dump(exclude_none=True) + parsed_arguments = _model_dump_preserving_explicit_none(arguments) else: raise TypeError( f"Expected mapping-like arguments for tool '{self.name}', got {type(arguments).__name__}" @@ -1492,7 +1523,7 @@ async def _auto_invoke_function( runtime_kwargs["session"] = invocation_session try: if not cast(bool, getattr(tool, "_schema_supplied", False)) and tool.input_model is not None: - args = tool.input_model.model_validate(parsed_args).model_dump(exclude_none=True) + args = _model_dump_preserving_explicit_none(tool.input_model.model_validate(parsed_args)) else: args = dict(parsed_args) args = _validate_arguments_against_schema( diff --git a/python/packages/core/tests/core/test_function_invocation_logic.py b/python/packages/core/tests/core/test_function_invocation_logic.py index 3d20a26080..b804043635 100644 --- a/python/packages/core/tests/core/test_function_invocation_logic.py +++ b/python/packages/core/tests/core/test_function_invocation_logic.py @@ -2,15 +2,17 @@ import asyncio from collections.abc import Awaitable, Callable -from typing import Any +from typing import Any, Literal import pytest +from pydantic import BaseModel from agent_framework import ( Agent, ChatResponse, ChatResponseUpdate, Content, + FunctionTool, Message, SupportsChatGetResponse, chat_middleware, @@ -87,6 +89,79 @@ def ai_func(arg1: str) -> str: assert response.messages[2].text == "done" +async def test_auto_function_calling_preserves_explicit_null_arguments(chat_client_base: SupportsChatGetResponse): + captured_args: list[tuple[str, str | None]] = [] + + @tool(name="get_weather", approval_mode="never_require") + def get_weather(location: str, unit: Literal["C", "F"] | None) -> str: + captured_args.append((location, unit)) + return f"{location}:{unit or 'C'}" + + chat_client_base.run_responses = [ + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call( + call_id="1", + name="get_weather", + arguments='{"location": "Seattle", "unit": null}', + ) + ], + ) + ), + ChatResponse(messages=Message(role="assistant", contents=["done"])), + ] + + await chat_client_base.get_response( + [Message(role="user", contents=["weather in Seattle"])], + options={"tool_choice": "auto", "tools": [get_weather]}, + ) + + assert captured_args == [("Seattle", None)] + + +async def test_auto_function_calling_preserves_nested_explicit_null_arguments( + chat_client_base: SupportsChatGetResponse, +): + class WeatherOptions(BaseModel): + unit: Literal["C", "F"] | None + + class WeatherArgs(BaseModel): + location: str + options: WeatherOptions + + captured_options: list[dict[str, Any]] = [] + + def get_weather(location: str, options: dict[str, Any]) -> str: + captured_options.append(options) + return f"{location}:{options['unit'] or 'C'}" + + weather_tool = FunctionTool(name="get_weather", func=get_weather, input_model=WeatherArgs) + chat_client_base.run_responses = [ + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call( + call_id="1", + name="get_weather", + arguments='{"location": "Seattle", "options": {"unit": null}}', + ) + ], + ) + ), + ChatResponse(messages=Message(role="assistant", contents=["done"])), + ] + + await chat_client_base.get_response( + [Message(role="user", contents=["weather in Seattle"])], + options={"tool_choice": "auto", "tools": [weather_tool]}, + ) + + assert captured_options == [{"unit": None}] + + async def test_base_client_with_function_calling_string_input(chat_client_base: SupportsChatGetResponse): exec_counter = 0 diff --git a/python/packages/core/tests/core/test_tools.py b/python/packages/core/tests/core/test_tools.py index b3762bf4ef..6fbbbd8b2d 100644 --- a/python/packages/core/tests/core/test_tools.py +++ b/python/packages/core/tests/core/test_tools.py @@ -183,6 +183,36 @@ def search(query: str, max_results: int = 10) -> str: await search.invoke(arguments={"query": "hello", "max_results": "three"}) +async def test_tool_invoke_preserves_explicit_null_for_required_nullable_argument() -> None: + """Explicit null values are valid for required nullable tool parameters.""" + + @tool + def get_weather(location: str, unit: Literal["C", "F"] | None) -> str: + return f"{location}:{unit or 'C'}" + + result = await get_weather.invoke(arguments={"location": "Seattle", "unit": None}) + + assert result[0].text == "Seattle:C" + + +async def test_tool_invoke_preserves_nested_explicit_null_argument() -> None: + class WeatherOptions(BaseModel): + unit: Literal["C", "F"] | None + + class WeatherArgs(BaseModel): + location: str + options: WeatherOptions + + def get_weather(location: str, options: dict[str, Any]) -> str: + return f"{location}:{options['unit'] or 'C'}" + + weather_tool = FunctionTool(name="get_weather", func=get_weather, input_model=WeatherArgs) + + result = await weather_tool.invoke(arguments={"location": "Seattle", "options": {"unit": None}}) + + assert result[0].text == "Seattle:C" + + def test_tool_decorator_with_json_schema_preserves_custom_properties(): """Test schema passthrough keeps custom JSON schema properties."""