Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 52 additions & 10 deletions src/google/adk/flows/llm_flows/_output_schema_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,25 @@
from __future__ import annotations

import json
import logging
from typing import AsyncGenerator

from google.genai import types
from typing_extensions import override

from ...agents.invocation_context import InvocationContext
from ...events.event import Event
from ...models.llm_request import LlmRequest
from ...tools.set_model_response_tool import SetModelResponseTool
from ...utils._schema_utils import is_basemodel_schema
from ...utils.output_schema_utils import can_use_output_schema_with_tools
from ._base_llm_processor import BaseLlmRequestProcessor

logger = logging.getLogger(__name__)

# Max tool rounds before forcing set_model_response (N-1) or terminating (N).
_MAX_TOOL_ROUNDS = 25


class _OutputSchemaRequestProcessor(BaseLlmRequestProcessor):
"""Processor that handles output schema for agents with tools."""
Expand All @@ -36,8 +44,6 @@ class _OutputSchemaRequestProcessor(BaseLlmRequestProcessor):
async def run_async(
self, invocation_context: InvocationContext, llm_request: LlmRequest
) -> AsyncGenerator[Event, None]:
from ...agents.llm_agent import LlmAgent

agent = invocation_context.agent

# Check if we need the processor: output_schema + tools + cannot use output
Expand All @@ -49,20 +55,56 @@ async def run_async(
):
return

# Count how many tool rounds have occurred in this invocation.
tool_rounds = sum(
1
for e in invocation_context._get_events(
current_invocation=True, current_branch=True
)
if e.get_function_responses()
)

# Terminate the invocation if the model never calls set_model_response.
if tool_rounds >= _MAX_TOOL_ROUNDS:
logger.error(
'Tool execution reached %d rounds without producing structured'
' output via set_model_response. Breaking loop to prevent'
' runaway API costs.',
tool_rounds,
)
invocation_context.end_invocation = True
return

# Add the set_model_response tool to handle structured output
set_response_tool = SetModelResponseTool(agent.output_schema)
llm_request.append_tools([set_response_tool])

# Add instruction about using the set_model_response tool
instruction = (
'IMPORTANT: You have access to other tools, but you must provide '
'your final response using the set_model_response tool with the '
'required structured format. After using any other tools needed '
'to complete the task, always call set_model_response with your '
'final answer in the specified schema format.'
)
# Primitive types (str, int, etc.) produce a trivial tool signature
# that flash models tend to ignore; use a stronger instruction.
if is_basemodel_schema(agent.output_schema):
instruction = (
'After completing any needed tool calls, provide your final'
' response by calling set_model_response with the required'
' fields.'
)
else:
instruction = (
'IMPORTANT: After using any needed tools, you MUST call'
' set_model_response to provide your final answer.'
' This is required to complete the task.'
)
llm_request.append_instructions([instruction])

# On round N-1, restrict the model to only call set_model_response.
if tool_rounds >= _MAX_TOOL_ROUNDS - 1:
llm_request.config = llm_request.config or types.GenerateContentConfig()
llm_request.config.tool_config = types.ToolConfig(
function_calling_config=types.FunctionCallingConfig(
mode=types.FunctionCallingConfigMode.ANY,
allowed_function_names=['set_model_response'],
)
)

return
yield # Generator requires yield statement in function body.

Expand Down
2 changes: 2 additions & 0 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,6 +1045,7 @@ async def _postprocess_live(
)
)
yield final_event
return # Skip further processing after set_model_response.

async def _postprocess_run_processors_async(
self, invocation_context: InvocationContext, llm_response: LlmResponse
Expand Down Expand Up @@ -1091,6 +1092,7 @@ async def _postprocess_handle_function_calls_async(
)
)
yield final_event
return # Skip transfer_to_agent after set_model_response.
transfer_to_agent = function_response_event.actions.transfer_to_agent
if transfer_to_agent:
agent_to_run = self._get_agent_to_run(
Expand Down
131 changes: 131 additions & 0 deletions tests/integration/test_output_schema_with_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
"""Integration test for output_schema + tools behavior.

Requires GOOGLE_API_KEY or Vertex AI credentials.
Run with: python -m pytest tests/integration/test_output_schema_with_tools.py -v -s
"""

import os
import time

from google.adk.agents.llm_agent import LlmAgent
from google.adk.runners import Runner
from google.adk.sessions.in_memory_session_service import InMemorySessionService
from google.genai import types
from pydantic import BaseModel
from pydantic import Field
import pytest


class AnalysisResult(BaseModel):
  """Structured output schema the agent must return via set_model_response."""

  # summary: brief summary of the analysis (free text).
  summary: str = Field(description='Brief summary of the analysis')
  # confidence: score between 0 and 1; range is documented, not validated.
  confidence: float = Field(description='Confidence score between 0 and 1')


def search_data(query: str) -> str:
  """Search for data based on the query.

  Stub tool: echoes the query with canned revenue figures so the tests can
  exercise tool-calling without any external dependency.
  """
  canned_facts = 'Revenue is $1M, growth is 15%.'
  return 'Found data for: ' + query + '. ' + canned_facts


def calculate_metric(metric_name: str, value: float) -> str:
  """Calculate a business metric.

  Stub tool: applies a fixed 10% upward adjustment and formats the result
  to two decimal places.
  """
  adjusted = value * 1.1
  return f'{metric_name}: {adjusted:.2f} (adjusted)'


# Skip if no API key is configured (neither Gemini API key nor Vertex AI).
skip_no_api_key = pytest.mark.skipif(
    not (
        os.environ.get('GOOGLE_API_KEY')
        or os.environ.get('GOOGLE_GENAI_USE_VERTEXAI')
    ),
    reason='No Gemini API key or Vertex AI configured',
)


@skip_no_api_key
@pytest.mark.asyncio
async def test_basemodel_schema_with_tools():
  """Test that BaseModel output_schema + tools produces structured output."""
  analyst = LlmAgent(
      name='analyst',
      model='gemini-2.5-flash',
      instruction=(
          'Analyze the query using the available tools, then return'
          ' structured output.'
      ),
      output_schema=AnalysisResult,
      tools=[search_data, calculate_metric],
  )

  session_service = InMemorySessionService()
  runner = Runner(
      agent=analyst, app_name='test_app', session_service=session_service
  )
  session = await session_service.create_session(
      app_name='test_app', user_id='test_user'
  )

  user_message = types.Content(
      role='user',
      parts=[types.Part(text='Analyze Q1 revenue performance')],
  )

  collected = []
  t0 = time.time()
  async for event in runner.run_async(
      user_id='test_user', session_id=session.id, new_message=user_message
  ):
    collected.append(event)
  elapsed = time.time() - t0

  # Should complete within a reasonable time (not infinite loop).
  assert elapsed < 120, f'Took {elapsed:.1f}s — possible infinite loop'

  # Should have at least one event with structured output.
  final_texts = []
  for event in collected:
    if event.content and event.content.parts and event.content.parts[0].text:
      final_texts.append(event.content.parts[0].text)
  assert final_texts, 'No text output produced'
  print(f'\nCompleted in {elapsed:.1f}s with {len(collected)} events')
  print(f'Final output: {final_texts[-1][:200]}')


@skip_no_api_key
@pytest.mark.asyncio
async def test_str_schema_with_tools():
  """Test that str output_schema + tools produces output (not infinite loop)."""
  analyst = LlmAgent(
      name='analyst',
      model='gemini-2.5-flash',
      instruction='Search for the data, then provide a brief text summary.',
      output_schema=str,
      tools=[search_data],
  )

  session_service = InMemorySessionService()
  runner = Runner(
      agent=analyst, app_name='test_app', session_service=session_service
  )
  session = await session_service.create_session(
      app_name='test_app', user_id='test_user'
  )

  user_message = types.Content(
      role='user',
      parts=[types.Part(text='What is the Q1 revenue?')],
  )

  collected = []
  t0 = time.time()
  async for event in runner.run_async(
      user_id='test_user', session_id=session.id, new_message=user_message
  ):
    collected.append(event)
  elapsed = time.time() - t0

  # Bounded wall-clock time guards against the tool-calling loop never
  # terminating when the schema is a primitive type.
  assert elapsed < 120, f'Took {elapsed:.1f}s — possible infinite loop'
  assert collected, 'No events produced'
  print(f'\nCompleted in {elapsed:.1f}s with {len(collected)} events')
Loading
Loading