Python: Fix: AgentThread re-submits prior tool outputs on subsequent turns with @tool functions #3722
Changes from all commits
```diff
@@ -399,7 +399,7 @@ async def _prepare_options(
         **kwargs: Any,
     ) -> dict[str, Any]:
         """Take ChatOptions and create the specific options for Azure AI."""
-        prepared_messages, instructions = self._prepare_messages_for_azure_ai(messages)
+        prepared_messages, instructions = self._prepare_messages_for_azure_ai(messages, options, **kwargs)
         run_options = await super()._prepare_options(prepared_messages, options, **kwargs)

         # WORKAROUND: Azure AI Projects 'create responses' API has schema divergence from OpenAI's
@@ -487,17 +487,46 @@ def _get_current_conversation_id(self, options: Mapping[str, Any], **kwargs: Any
         """Get the current conversation ID from chat options or kwargs."""
         return options.get("conversation_id") or kwargs.get("conversation_id") or self.conversation_id

-    def _prepare_messages_for_azure_ai(self, messages: Sequence[ChatMessage]) -> tuple[list[ChatMessage], str | None]:
-        """Prepare input from messages and convert system/developer messages to instructions."""
+    def _prepare_messages_for_azure_ai(
+        self, messages: Sequence[ChatMessage], options: Mapping[str, Any], **kwargs: Any
+    ) -> tuple[list[ChatMessage], str | None]:
+        """Prepare input from messages and convert system/developer messages to instructions.
+
+        When using previous_response_id (response chaining), filters out old function results
+        and assistant messages since they're already in the server-side conversation history.
+        Only NEW user messages should be sent.
+        """
+        # Check if we're using previous_response_id (response chaining pattern)
+        conversation_id = self._get_current_conversation_id(options, **kwargs)
+        use_response_chaining = conversation_id is not None and conversation_id.startswith("resp_")
+
         result: list[ChatMessage] = []
         instructions_list: list[str] = []
         instructions: str | None = None

+        # When using response chaining, find the index of the last assistant message
+        # Messages after that are "new" and should be included
+        last_assistant_idx = -1
+        if use_response_chaining:
+            for i in range(len(messages) - 1, -1, -1):
+                if messages[i].role == "assistant":
+                    last_assistant_idx = i
+                    break
+
         # System/developer messages are turned into instructions, since there is no such message roles in Azure AI.
-        for message in messages:
+        for idx, message in enumerate(messages):
             if message.role in ["system", "developer"]:
                 for text_content in [content for content in message.contents if content.type == "text"]:
                     instructions_list.append(text_content.text)  # type: ignore[arg-type]
+            elif use_response_chaining:
+                # When using response chaining, only include messages after the last assistant message
+                # These are the "new" messages from the current turn
+                if idx > last_assistant_idx:
+                    # Also filter out function result messages
+                    has_function_result = any(content.type == "function_result" for content in message.contents)
+                    if not has_function_result:
+                        result.append(message)
+                # Skip all messages at or before the last assistant message (already in server history)
             else:
                 result.append(message)
```
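To see what the new filtering rule does in practice, here is a minimal, self-contained sketch of the same selection logic applied to a simplified message shape. `Msg` and `filter_for_response_chaining` are illustrative stand-ins, not framework APIs; the real implementation operates on `ChatMessage` objects inside `_prepare_messages_for_azure_ai` as shown above, and the `resp_` prefix check mirrors the `use_response_chaining` condition.

```python
from dataclasses import dataclass, field


@dataclass
class Msg:
    # Simplified stand-in for ChatMessage: a role plus a list of content-type tags.
    role: str
    content_types: list[str] = field(default_factory=list)


def filter_for_response_chaining(messages: list[Msg], conversation_id: str | None) -> list[Msg]:
    """Mirror of the PR's selection rule: with a resp_* conversation id, keep only
    messages after the last assistant message and drop function results."""
    chaining = conversation_id is not None and conversation_id.startswith("resp_")
    if not chaining:
        # No response chaining: send every non-system/developer message as before.
        return [m for m in messages if m.role not in ("system", "developer")]

    # Find the last assistant message; everything at or before it is already server-side.
    last_assistant_idx = -1
    for i in range(len(messages) - 1, -1, -1):
        if messages[i].role == "assistant":
            last_assistant_idx = i
            break

    return [
        m
        for idx, m in enumerate(messages)
        if idx > last_assistant_idx
        and m.role not in ("system", "developer")
        and "function_result" not in m.content_types
    ]


# Turn-1 history plus one new turn-2 user message.
history = [
    Msg("user", ["text"]),                # turn 1 question
    Msg("assistant", ["function_call"]),  # turn 1 tool call
    Msg("user", ["function_result"]),     # turn 1 tool output
    Msg("assistant", ["text"]),           # turn 1 answer
    Msg("user", ["text"]),                # turn 2 question (NEW)
]

# With response chaining, only the new turn-2 user message survives.
assert len(filter_for_response_chaining(history, "resp_turn1")) == 1
# Without chaining, all five non-system messages are sent.
assert len(filter_for_response_chaining(history, None)) == 5
print("filtering matches the behavior exercised by the PR's unit tests")
```

The rule hinges on the `resp_` prefix of the conversation id: only when the thread is chained to a previous response are older messages assumed to already live in the server-side history.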
```diff
@@ -303,7 +303,7 @@ async def test_prepare_messages_for_azure_ai_with_system_messages(
         ChatMessage(role="assistant", contents=[Content.from_text(text="System response")]),
     ]

-    result_messages, instructions = client._prepare_messages_for_azure_ai(messages)  # type: ignore
+    result_messages, instructions = client._prepare_messages_for_azure_ai(messages, {})  # type: ignore

     assert len(result_messages) == 2
     assert result_messages[0].role == "user"
@@ -322,12 +322,89 @@ async def test_prepare_messages_for_azure_ai_no_system_messages(
         ChatMessage(role="assistant", contents=[Content.from_text(text="Hi there!")]),
     ]

-    result_messages, instructions = client._prepare_messages_for_azure_ai(messages)  # type: ignore
+    result_messages, instructions = client._prepare_messages_for_azure_ai(messages, {})  # type: ignore

     assert len(result_messages) == 2
     assert instructions is None


+async def test_prepare_messages_filters_old_function_results_with_previous_response_id(
+    mock_project_client: MagicMock,
+) -> None:
+    """Test _prepare_messages_for_azure_ai filters old function results when using previous_response_id."""
+    client = create_test_azure_ai_client(mock_project_client)
+
+    # Simulate a multi-turn conversation with function calls
+    messages = [
+        # Turn 1 - user asks a question
+        ChatMessage(role="user", contents=[Content.from_text(text="Calculate 15% tip on $85")]),
+        # Turn 1 - assistant makes a function call
+        ChatMessage(
+            role="assistant",
+            contents=[
+                Content.from_function_call(
+                    call_id="call_123", name="calculate_tip", arguments='{"bill_amount": 85, "tip_percent": 15}'
+                )
+            ],
+        ),
+        # Turn 1 - function result
+        ChatMessage(
+            role="user",
+            contents=[Content.from_function_result(call_id="call_123", result="Tip: $12.75, Total: $97.75")],
+        ),
+        # Turn 1 - assistant responds with text
+        ChatMessage(role="assistant", contents=[Content.from_text(text="The tip is $12.75")]),
+        # Turn 2 - NEW user message
+        ChatMessage(role="user", contents=[Content.from_text(text="Now calculate 20% tip on $85")]),
+    ]
+
+    # Test WITH previous_response_id (should filter to only new user message)
+    options = {"conversation_id": "resp_turn1"}
+    result_messages, instructions = client._prepare_messages_for_azure_ai(messages, options)  # type: ignore
+
+    # Should only have the NEW user message from turn 2
+    assert len(result_messages) == 1
+    assert result_messages[0].role == "user"
+    assert any(c.type == "text" for c in result_messages[0].contents)
+    # Should not have function results
+    assert not any(c.type == "function_result" for c in result_messages[0].contents)
+    assert instructions is None
+
+
+async def test_prepare_messages_includes_all_without_previous_response_id(
+    mock_project_client: MagicMock,
+) -> None:
+    """Test _prepare_messages_for_azure_ai includes all messages without previous_response_id."""
+    client = create_test_azure_ai_client(mock_project_client)
+
+    # Same messages as previous test
+    messages = [
+        ChatMessage(role="user", contents=[Content.from_text(text="Calculate 15% tip on $85")]),
+        ChatMessage(
+            role="assistant",
+            contents=[
+                Content.from_function_call(
+                    call_id="call_123", name="calculate_tip", arguments='{"bill_amount": 85, "tip_percent": 15}'
+                )
+            ],
+        ),
+        ChatMessage(
+            role="user",
+            contents=[Content.from_function_result(call_id="call_123", result="Tip: $12.75, Total: $97.75")],
+        ),
+        ChatMessage(role="assistant", contents=[Content.from_text(text="The tip is $12.75")]),
+        ChatMessage(role="user", contents=[Content.from_text(text="Now calculate 20% tip on $85")]),
+    ]
+
+    # Test WITHOUT previous_response_id (should include all messages)
+    options: dict[str, Any] = {}
+    result_messages, instructions = client._prepare_messages_for_azure_ai(messages, options)  # type: ignore
+
+    # Should have all non-system messages (5 in this case)
+    assert len(result_messages) == 5
+    assert instructions is None
+
+
 def test_transform_input_for_azure_ai(mock_project_client: MagicMock) -> None:
     """Test _transform_input_for_azure_ai adds required fields for Azure AI schema.
```
@@ -0,0 +1,223 @@

```python
# Copyright (c) Microsoft. All rights reserved.

"""Test multi-turn conversations with function tools in Azure AI."""

from typing import Annotated
from unittest.mock import AsyncMock, MagicMock

import pytest
from agent_framework import tool
from azure.ai.projects.models import PromptAgentDefinition
from pydantic import Field

from agent_framework_azure_ai import AzureAIProjectAgentProvider


@tool(approval_mode="never_require")
def calculate_tip(
    bill_amount: Annotated[float, Field(description="Bill amount in dollars")],
    tip_percent: Annotated[float, Field(description="Tip percentage")],
) -> str:
    """Calculate tip amount for a bill."""
    tip = bill_amount * (tip_percent / 100)
    return f"Tip: ${tip:.2f}, Total: ${bill_amount + tip:.2f}"


@pytest.mark.asyncio
async def test_multi_turn_function_tools_does_not_resubmit_old_results():
    """Test that multi-turn conversations don't re-submit old function call results."""
    # Setup mock project client
    mock_project_client = AsyncMock()
    mock_agents = AsyncMock()
    mock_project_client.agents = mock_agents

    # Mock agent creation
    mock_agent_version = MagicMock()
    mock_agent_version.id = "agent_id_123"
    mock_agent_version.name = "tip-calculator"
    mock_agent_version.version = "v1"
    mock_agent_version.description = None
    mock_agent_version.definition = PromptAgentDefinition(
        model="gpt-4",
        instructions="Use the calculate_tip tool to help with calculations.",
        tools=[],
    )
    mock_agents.create_version = AsyncMock(return_value=mock_agent_version)

    # Mock OpenAI client that tracks requests
    requests_made = []

    def mock_create_response(**kwargs):
        """Mock response creation that tracks inputs."""
        requests_made.append(kwargs)

        # Simulate a response with function call on turn 1
        if len(requests_made) == 1:
            mock_response = MagicMock()
            mock_response.id = "resp_turn1"
            mock_response.created_at = 1234567890
            mock_response.model = "gpt-4"
            mock_response.usage = None
            mock_response.metadata = {}

            # Return a function call
            mock_function_call = MagicMock()
            mock_function_call.type = "function_call"
            mock_function_call.id = "fc_call_123"
            mock_function_call.call_id = "call_123"
            mock_function_call.name = "calculate_tip"
            mock_function_call.arguments = '{"bill_amount": 85, "tip_percent": 15}'

            mock_response.output = [mock_function_call]
            return mock_response
        # Turn 2: Return a text response
        mock_response = MagicMock()
        mock_response.id = "resp_turn2"
        mock_response.created_at = 1234567891
        mock_response.model = "gpt-4"
        mock_response.usage = None
        mock_response.metadata = {}

        mock_message = MagicMock()
        mock_message.type = "message"
        mock_text = MagicMock()
        mock_text.type = "output_text"
        mock_text.text = "The 20% tip is calculated."
        mock_message.content = [mock_text]

        mock_response.output = [mock_message]
        return mock_response

    mock_openai_client = MagicMock()
    mock_openai_client.responses = MagicMock()
    mock_openai_client.responses.create = AsyncMock(side_effect=mock_create_response)
    mock_project_client.get_openai_client = MagicMock(return_value=mock_openai_client)

    # Create provider and agent
    provider = AzureAIProjectAgentProvider(project_client=mock_project_client, model="gpt-4")
    agent = await provider.create_agent(
        name="tip-calculator",
        instructions="Use the calculate_tip tool to help with calculations.",
        tools=[calculate_tip],
    )

    # Single thread for multi-turn (BUG TRIGGER)
    thread = agent.get_new_thread()

    # Turn 1: Should work fine
    result1 = await agent.run("Calculate 15% tip on an $85 bill", thread=thread)
    assert result1 is not None

    # Check Turn 1 request - should have the user message
    turn1_request = requests_made[0]
    turn1_input = turn1_request["input"]
    assert any(item.get("role") == "user" for item in turn1_input if isinstance(item, dict))

    # Turn 2: Should NOT re-submit function call results from Turn 1
    result2 = await agent.run("Now calculate 20% tip on the same $85 bill", thread=thread)
    assert result2 is not None

    # Check Turn 2 request - should NOT have function_call_output from Turn 1
    turn2_request = requests_made[-1]  # Last request made (after function execution)
    turn2_input = turn2_request["input"]

    # The key assertion: Turn 2 should only have NEW function outputs (from turn 2's function calls)
    # If it has function outputs from turn 1, that's the bug we're fixing
    # Since turn 2 likely also has a function call, we need to check that old outputs aren't there

    # A more robust check: verify that turn 2's input doesn't contain the call_id from turn 1
    turn1_call_id = "call_123"
    has_old_function_output = any(
        item.get("type") == "function_call_output" and item.get("call_id") == turn1_call_id
        for item in turn2_input
        if isinstance(item, dict)
    )

    assert not has_old_function_output, (
        "Turn 2 should not re-submit function_call_output from Turn 1. "
        "Found old function output with call_id from Turn 1."
    )


@pytest.mark.asyncio
async def test_multi_turn_with_previous_response_id_filters_old_messages():
    """Test that when using previous_response_id, old function results are filtered."""
    # Setup mock project client
    mock_project_client = AsyncMock()
    mock_agents = AsyncMock()
    mock_project_client.agents = mock_agents

    # Mock agent creation
    mock_agent_version = MagicMock()
    mock_agent_version.id = "agent_id_123"
    mock_agent_version.name = "test-agent"
    mock_agent_version.version = "v1"
    mock_agent_version.description = None
    mock_agent_version.definition = PromptAgentDefinition(
        model="gpt-4",
        instructions="You are a helpful assistant.",
        tools=[],
    )
    mock_agents.create_version = AsyncMock(return_value=mock_agent_version)

    # Mock OpenAI client
    requests_made = []

    def mock_create_response(**kwargs):
        """Mock response creation."""
        requests_made.append(kwargs)
        mock_response = MagicMock()
        mock_response.id = f"resp_turn{len(requests_made)}"
        mock_response.created_at = 1234567890 + len(requests_made)
        mock_response.model = "gpt-4"
        mock_response.usage = None
        mock_response.metadata = {}
        mock_message = MagicMock()
        mock_message.type = "message"
        mock_text = MagicMock()
        mock_text.type = "output_text"
        mock_text.text = f"Response {len(requests_made)}"
        mock_message.content = [mock_text]
        mock_response.output = [mock_message]
        return mock_response

    mock_openai_client = MagicMock()
    mock_openai_client.responses = MagicMock()
    mock_openai_client.responses.create = AsyncMock(side_effect=mock_create_response)
    mock_project_client.get_openai_client = MagicMock(return_value=mock_openai_client)

    # Create provider and agent
    provider = AzureAIProjectAgentProvider(project_client=mock_project_client, model="gpt-4")
    agent = await provider.create_agent(
        name="test-agent",
        instructions="You are a helpful assistant.",
        tools=[calculate_tip],
    )

    # Create a thread starting with a service_thread_id (simulating a previous response)
    # This avoids the message_store/service_thread_id conflict
    thread = agent.get_new_thread()
    # Simulate that turn 1 has already completed and returned resp_turn1
    # We manually set the internal state to simulate this

    # Use the internal property to bypass the setter validation
    thread._service_thread_id = "resp_turn1"

    # Turn 2: New user message
    # This turn should only send the new user message, not any messages from turn 1
    result2 = await agent.run("Now calculate 20% tip", thread=thread)
    assert result2 is not None

    # Check that turn 2 request has previous_response_id set
    turn2_request = requests_made[0]
    assert "previous_response_id" in turn2_request
    assert turn2_request["previous_response_id"] == "resp_turn1"

    # Check that turn 2 input doesn't contain old function results
    # Since we're using service_thread_id, the messages are managed server-side
    # and only the new user message should be in the request
    turn2_input = turn2_request["input"]

    # Turn 2 should only have the NEW user message
    user_messages = [item for item in turn2_input if isinstance(item, dict) and item.get("role") == "user"]
    assert len(user_messages) == 1, "Turn 2 should only have the NEW user message"
```
Review comment:
When a message contains both function_result content AND other content types (e.g., text), the entire message is filtered out if any content is a function_result. This could lead to loss of non-function-result content. Consider filtering at the content level rather than the message level, or creating a new message with only the non-function-result content items.
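One way to act on this suggestion is sketched below. This is illustrative only and not part of the PR: it assumes the `ChatMessage(role=..., contents=[...])` constructor and `content.type` attribute used in the tests above, and `_strip_function_results` is a hypothetical helper name.

```python
# Sketch of content-level filtering: keep the message but drop only its
# function_result content items, so mixed messages are not lost wholesale.
# Import ChatMessage from the framework package as in the tests above
# (the exact module path is not shown in this diff).
def _strip_function_results(message: "ChatMessage") -> "ChatMessage | None":
    kept = [content for content in message.contents if content.type != "function_result"]
    if not kept:
        return None  # nothing left to send for this message
    if len(kept) == len(message.contents):
        return message  # no function results; reuse the original message as-is
    return ChatMessage(role=message.role, contents=kept)


# Inside the response-chaining branch, the message-level check would become:
#     if idx > last_assistant_idx:
#         filtered = _strip_function_results(message)
#         if filtered is not None:
#             result.append(filtered)
```

With this approach, a message that mixes text with a stale tool output keeps its text instead of being dropped entirely.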