diff --git a/src/bub/builtin/tools.py b/src/bub/builtin/tools.py index f6dd035a..701d8d32 100644 --- a/src/bub/builtin/tools.py +++ b/src/bub/builtin/tools.py @@ -11,7 +11,7 @@ from bub.builtin.shell_manager import shell_manager from bub.skills import discover_skills -from bub.tools import REGISTRY, tool +from bub.tools import resolve_tool_names, tool if TYPE_CHECKING: from bub.builtin.agent import Agent @@ -263,10 +263,7 @@ async def run_subagent(param: SubAgentInput, *, context: ToolContext) -> str: else: subagent_session = param.session state = {**context.state, "session_id": subagent_session} - if param.allowed_tools: - allowed_tools = set(param.allowed_tools) - {"subagent"} - else: - allowed_tools = set(REGISTRY.keys()) - {"subagent"} + allowed_tools = resolve_tool_names(param.allowed_tools or None, exclude={"subagent"}) return await agent.run( session_id=subagent_session, prompt=param.prompt, diff --git a/src/bub/tools.py b/src/bub/tools.py index a57c9663..046cc4b1 100644 --- a/src/bub/tools.py +++ b/src/bub/tools.py @@ -136,6 +136,51 @@ def _to_model_name(name: str) -> str: return name.replace(".", "_") +def _tool_name_index() -> dict[str, str]: + real_names = {tool_name.casefold(): tool_name for tool_name in REGISTRY} + alias_names = {_to_model_name(tool_name).casefold(): tool_name for tool_name in REGISTRY} + return {**alias_names, **real_names} + + +def resolve_tool_name(name: str) -> str | None: + """Resolve a user/model-provided tool name to the runtime registry name.""" + key = name.strip().casefold() + if not key: + return None + return _tool_name_index().get(key) + + +def _resolve_explicit_tool_names(names: Iterable[str]) -> tuple[set[str], set[str]]: + resolved: set[str] = set() + unknown: set[str] = set() + for name in names: + normalized_name = name.strip() + if resolved_name := resolve_tool_name(normalized_name): + resolved.add(resolved_name) + else: + unknown.add(normalized_name) + return resolved, unknown + + +def _raise_unknown_tool_names(names: set[str]) -> None: + formatted = ", ".join(sorted(repr(name) for name in names)) + raise ValueError(f"unknown tool name(s): {formatted}") + + +def resolve_tool_names(names: Iterable[str] | None = None, *, exclude: Iterable[str] = ()) -> set[str]: + """Resolve tool names from either runtime names or model-facing aliases.""" + excluded, unknown_excluded = _resolve_explicit_tool_names(exclude) + if unknown_excluded: + _raise_unknown_tool_names(unknown_excluded) + if names is None: + return set(REGISTRY) - excluded + + resolved, unknown = _resolve_explicit_tool_names(names) + if unknown: + _raise_unknown_tool_names(unknown) + return resolved - excluded + + def model_tools(tools: Iterable[Tool]) -> list[Tool]: """Helper to convert a list of Tool instances into a format accepted by LLMs.""" return [replace(tool, name=_to_model_name(tool.name)) for tool in tools] diff --git a/tests/test_subagent_tool.py b/tests/test_subagent_tool.py index d1ceeb2f..5991f58d 100644 --- a/tests/test_subagent_tool.py +++ b/tests/test_subagent_tool.py @@ -6,6 +6,7 @@ import pytest from bub.builtin.tools import run_subagent +from bub.tools import REGISTRY, tool class FakeContext: @@ -94,3 +95,50 @@ async def test_subagent_default_session_when_missing() -> None: call_kwargs = agent.run.call_args.kwargs assert call_kwargs["session_id"] == "temp/unknown" + + +@pytest.mark.asyncio +async def test_subagent_empty_allowed_tools_defaults_to_all_non_subagent_tools() -> None: + tool_name = "tests.allowed_tool_default" + REGISTRY.pop(tool_name, None) + + @tool(name=tool_name) + def allowed_tool_default() -> str: + return "ok" + + agent = FakeAgent() + ctx = FakeContext({"_runtime_agent": agent, "session_id": "user/abc"}) + + await run_subagent.run(prompt="task", allowed_tools=[], context=ctx) + + allowed_tools = agent.run.call_args.kwargs["allowed_tools"] + assert tool_name in allowed_tools + assert "subagent" not in allowed_tools + + +@pytest.mark.asyncio +async def test_subagent_resolves_model_tool_aliases_to_runtime_names() -> None: + tool_name = "tests.resolve_subagent" + REGISTRY.pop(tool_name, None) + + @tool(name=tool_name) + def resolve_subagent() -> str: + return "ok" + + agent = FakeAgent() + ctx = FakeContext({"_runtime_agent": agent, "session_id": "user/abc"}) + + await run_subagent.run(prompt="task", allowed_tools=[" tests_resolve_subagent "], context=ctx) + + assert agent.run.call_args.kwargs["allowed_tools"] == {tool_name} + + +@pytest.mark.asyncio +async def test_subagent_rejects_unknown_allowed_tools() -> None: + agent = FakeAgent() + ctx = FakeContext({"_runtime_agent": agent, "session_id": "user/abc"}) + + with pytest.raises(ValueError, match="tests_missing_tool"): + await run_subagent.run(prompt="task", allowed_tools=[" tests_missing_tool "], context=ctx) + + agent.run.assert_not_called() diff --git a/tests/test_tools.py b/tests/test_tools.py index 4d97e1e2..e30588fa 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -6,7 +6,7 @@ from loguru import logger from pydantic import BaseModel -from bub.tools import REGISTRY, model_tools, render_tools_prompt, tool +from bub.tools import REGISTRY, model_tools, render_tools_prompt, resolve_tool_names, tool class EchoInput(BaseModel): @@ -123,3 +123,33 @@ def prompt_two() -> str: def test_render_tools_prompt_returns_empty_string_for_empty_input() -> None: assert render_tools_prompt([]) == "" + + +def test_resolve_tool_names_accepts_runtime_names_and_model_aliases() -> None: + dotted_name = "tests.resolve_alias" + underscored_name = "tests_with_underscore" + REGISTRY.pop(dotted_name, None) + REGISTRY.pop(underscored_name, None) + + @tool(name=dotted_name) + def resolve_alias() -> str: + return "alias" + + @tool(name=underscored_name) + def resolve_runtime_name() -> str: + return "runtime" + + assert resolve_tool_names([" tests_resolve_alias ", " tests_with_underscore "], exclude={" subagent "}) == { + dotted_name, + underscored_name, + } + assert dotted_name not in resolve_tool_names(None, exclude={" tests_resolve_alias "}) + assert resolve_tool_names(None, exclude={" tests_resolve_alias "}) >= {underscored_name} + + +def test_resolve_tool_names_rejects_unknown_names() -> None: + with pytest.raises(ValueError, match="tests_missing_tool"): + resolve_tool_names([" tests_missing_tool "]) + + with pytest.raises(ValueError, match="tests_missing_tool"): + resolve_tool_names(None, exclude={" tests_missing_tool "})