Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions src/bub/builtin/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from bub.builtin.shell_manager import shell_manager
from bub.skills import discover_skills
from bub.tools import REGISTRY, tool
from bub.tools import resolve_tool_names, tool

if TYPE_CHECKING:
from bub.builtin.agent import Agent
Expand Down Expand Up @@ -263,10 +263,7 @@ async def run_subagent(param: SubAgentInput, *, context: ToolContext) -> str:
else:
subagent_session = param.session
state = {**context.state, "session_id": subagent_session}
if param.allowed_tools:
allowed_tools = set(param.allowed_tools) - {"subagent"}
else:
allowed_tools = set(REGISTRY.keys()) - {"subagent"}
allowed_tools = resolve_tool_names(param.allowed_tools or None, exclude={"subagent"})
return await agent.run(
session_id=subagent_session,
prompt=param.prompt,
Expand Down
45 changes: 45 additions & 0 deletions src/bub/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,51 @@ def _to_model_name(name: str) -> str:
return name.replace(".", "_")


def _tool_name_index() -> dict[str, str]:
real_names = {tool_name.casefold(): tool_name for tool_name in REGISTRY}
alias_names = {_to_model_name(tool_name).casefold(): tool_name for tool_name in REGISTRY}
return {**alias_names, **real_names}


def resolve_tool_name(name: str) -> str | None:
"""Resolve a user/model-provided tool name to the runtime registry name."""
key = name.strip().casefold()
if not key:
return None
return _tool_name_index().get(key)


def _resolve_explicit_tool_names(names: Iterable[str]) -> tuple[set[str], set[str]]:
resolved: set[str] = set()
unknown: set[str] = set()
for name in names:
normalized_name = name.strip()
if resolved_name := resolve_tool_name(normalized_name):
resolved.add(resolved_name)
else:
unknown.add(normalized_name)
return resolved, unknown


def _raise_unknown_tool_names(names: set[str]) -> None:
formatted = ", ".join(sorted(repr(name) for name in names))
raise ValueError(f"unknown tool name(s): {formatted}")


def resolve_tool_names(names: Iterable[str] | None = None, *, exclude: Iterable[str] = ()) -> set[str]:
"""Resolve tool names from either runtime names or model-facing aliases."""
excluded, unknown_excluded = _resolve_explicit_tool_names(exclude)
if unknown_excluded:
_raise_unknown_tool_names(unknown_excluded)
if names is None:
return set(REGISTRY) - excluded

resolved, unknown = _resolve_explicit_tool_names(names)
if unknown:
_raise_unknown_tool_names(unknown)
return resolved - excluded


def model_tools(tools: Iterable[Tool]) -> list[Tool]:
"""Helper to convert a list of Tool instances into a format accepted by LLMs."""
return [replace(tool, name=_to_model_name(tool.name)) for tool in tools]
Expand Down
48 changes: 48 additions & 0 deletions tests/test_subagent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

from bub.builtin.tools import run_subagent
from bub.tools import REGISTRY, tool


class FakeContext:
Expand Down Expand Up @@ -94,3 +95,50 @@ async def test_subagent_default_session_when_missing() -> None:

call_kwargs = agent.run.call_args.kwargs
assert call_kwargs["session_id"] == "temp/unknown"


@pytest.mark.asyncio
async def test_subagent_empty_allowed_tools_defaults_to_all_non_subagent_tools() -> None:
tool_name = "tests.allowed_tool_default"
REGISTRY.pop(tool_name, None)

@tool(name=tool_name)
def allowed_tool_default() -> str:
return "ok"

agent = FakeAgent()
ctx = FakeContext({"_runtime_agent": agent, "session_id": "user/abc"})

await run_subagent.run(prompt="task", allowed_tools=[], context=ctx)

allowed_tools = agent.run.call_args.kwargs["allowed_tools"]
assert tool_name in allowed_tools
assert "subagent" not in allowed_tools


@pytest.mark.asyncio
async def test_subagent_resolves_model_tool_aliases_to_runtime_names() -> None:
tool_name = "tests.resolve_subagent"
REGISTRY.pop(tool_name, None)

@tool(name=tool_name)
def resolve_subagent() -> str:
return "ok"

agent = FakeAgent()
ctx = FakeContext({"_runtime_agent": agent, "session_id": "user/abc"})

await run_subagent.run(prompt="task", allowed_tools=[" tests_resolve_subagent "], context=ctx)

assert agent.run.call_args.kwargs["allowed_tools"] == {tool_name}


@pytest.mark.asyncio
async def test_subagent_rejects_unknown_allowed_tools() -> None:
agent = FakeAgent()
ctx = FakeContext({"_runtime_agent": agent, "session_id": "user/abc"})

with pytest.raises(ValueError, match="tests_missing_tool"):
await run_subagent.run(prompt="task", allowed_tools=[" tests_missing_tool "], context=ctx)

agent.run.assert_not_called()
32 changes: 31 additions & 1 deletion tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from loguru import logger
from pydantic import BaseModel

from bub.tools import REGISTRY, model_tools, render_tools_prompt, tool
from bub.tools import REGISTRY, model_tools, render_tools_prompt, resolve_tool_names, tool


class EchoInput(BaseModel):
Expand Down Expand Up @@ -123,3 +123,33 @@ def prompt_two() -> str:

def test_render_tools_prompt_returns_empty_string_for_empty_input() -> None:
assert render_tools_prompt([]) == ""


def test_resolve_tool_names_accepts_runtime_names_and_model_aliases() -> None:
dotted_name = "tests.resolve_alias"
underscored_name = "tests_with_underscore"
REGISTRY.pop(dotted_name, None)
REGISTRY.pop(underscored_name, None)

@tool(name=dotted_name)
def resolve_alias() -> str:
return "alias"

@tool(name=underscored_name)
def resolve_runtime_name() -> str:
return "runtime"

assert resolve_tool_names([" tests_resolve_alias ", " tests_with_underscore "], exclude={" subagent "}) == {
dotted_name,
underscored_name,
}
assert dotted_name not in resolve_tool_names(None, exclude={" tests_resolve_alias "})
assert resolve_tool_names(None, exclude={" tests_resolve_alias "}) >= {underscored_name}


def test_resolve_tool_names_rejects_unknown_names() -> None:
with pytest.raises(ValueError, match="tests_missing_tool"):
resolve_tool_names([" tests_missing_tool "])

with pytest.raises(ValueError, match="tests_missing_tool"):
resolve_tool_names(None, exclude={" tests_missing_tool "})
Loading