Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions src/bub/builtin/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,16 @@
import shlex
import time
from collections.abc import Collection
from dataclasses import dataclass
from dataclasses import dataclass, replace
from datetime import UTC, datetime
from functools import cached_property
from pathlib import Path
from typing import Any

from loguru import logger
from republic import LLM, AsyncTapeStore, ToolAutoResult, ToolContext
from republic import LLM, AsyncTapeStore, TapeContext, ToolAutoResult, ToolContext
from republic.tape import InMemoryTapeStore, Tape

from bub.builtin.context import default_tape_context
from bub.builtin.settings import AgentSettings
from bub.builtin.store import ForkTapeStore
from bub.builtin.tape import TapeService
Expand Down Expand Up @@ -46,7 +45,7 @@ def tapes(self) -> TapeService:
if tape_store is None:
tape_store = InMemoryTapeStore()
tape_store = ForkTapeStore(tape_store)
llm = _build_llm(self.settings, tape_store)
llm = _build_llm(self.settings, tape_store, self.framework.build_tape_context())
return TapeService(llm, self.settings.home / "tapes", tape_store)

async def run(
Expand All @@ -62,7 +61,7 @@ async def run(
if not prompt:
return "error: empty prompt"
tape = self.tapes.session_tape(session_id, workspace_from_state(state))
tape.context.state.update(state)
tape.context = replace(tape.context, state=state)
merge_back = not session_id.startswith("temp/")
async with self.tapes.fork_tape(tape.name, merge_back=merge_back):
await self.tapes.ensure_bootstrap_anchor(tape.name)
Expand Down Expand Up @@ -123,6 +122,16 @@ async def _agent_loop(
) -> str:
next_prompt: str | list[dict] = prompt
display_model = model or self.settings.model
await self.tapes.append_event(
tape.name,
"loop.start",
{
"model": display_model,
"prompt": prompt,
"allowed_skills": list(allowed_skills) if allowed_skills else None,
"allowed_tools": list(allowed_tools) if allowed_tools else None,
},
)
for step in range(1, self.settings.max_steps + 1):
start = time.monotonic()
logger.info("loop.step step={} tape={} model={}", step, tape.name, display_model)
Expand Down Expand Up @@ -265,7 +274,7 @@ def _resolve_tool_auto_result(output: ToolAutoResult) -> _ToolAutoOutcome:
return _ToolAutoOutcome(kind="error", error=f"{error_kind}: {output.error.message}")


def _build_llm(settings: AgentSettings, tape_store: AsyncTapeStore) -> LLM:
def _build_llm(settings: AgentSettings, tape_store: AsyncTapeStore, tape_context: TapeContext) -> LLM:
from republic.auth.openai_codex import openai_codex_oauth_resolver

return LLM(
Expand All @@ -276,7 +285,7 @@ def _build_llm(settings: AgentSettings, tape_store: AsyncTapeStore) -> LLM:
api_key_resolver=openai_codex_oauth_resolver(),
tape_store=tape_store,
api_format=settings.api_format,
context=default_tape_context(),
context=tape_context,
verbose=settings.verbose,
)

Expand Down
4 changes: 2 additions & 2 deletions src/bub/builtin/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
from republic import TapeContext, TapeEntry


def default_tape_context(state: dict[str, Any] | None = None) -> TapeContext:
def default_tape_context() -> TapeContext:
"""Return the default context selection for Bub."""

return TapeContext(select=_select_messages, state=state or {})
return TapeContext(select=_select_messages)


def _select_messages(entries: Iterable[TapeEntry], _context: TapeContext) -> list[dict[str, Any]]:
Expand Down
6 changes: 6 additions & 0 deletions src/bub/builtin/hook_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@

import typer
from loguru import logger
from republic import TapeContext
from republic.tape import TapeStore

from bub.builtin.agent import Agent
from bub.builtin.context import default_tape_context
from bub.channels.base import Channel
from bub.channels.message import ChannelMessage, MediaItem
from bub.envelope import content_of, field_of
Expand Down Expand Up @@ -187,3 +189,7 @@ def provide_tape_store(self) -> TapeStore:
from bub.builtin.store import FileTapeStore

return FileTapeStore(directory=self.agent.settings.home / "tapes")

@hookimpl
def build_tape_context(self) -> TapeContext:
return default_tape_context()
5 changes: 4 additions & 1 deletion src/bub/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pluggy
import typer
from loguru import logger
from republic import AsyncTapeStore
from republic import AsyncTapeStore, TapeContext
from republic.tape import TapeStore

from bub.envelope import content_of, field_of, unpack_batch
Expand Down Expand Up @@ -209,3 +209,6 @@ def get_system_prompt(self, prompt: str | list[dict], state: dict[str, Any]) ->
for result in reversed(self._hook_runtime.call_many_sync("system_prompt", prompt=prompt, state=state))
if result
)

def build_tape_context(self) -> TapeContext:
return self._hook_runtime.call_first_sync("build_tape_context")
7 changes: 6 additions & 1 deletion src/bub/hookspecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import TYPE_CHECKING, Any

import pluggy
from republic import AsyncTapeStore
from republic import AsyncTapeStore, TapeContext
from republic.tape import TapeStore

from bub.types import Envelope, MessageHandler, State
Expand Down Expand Up @@ -93,3 +93,8 @@ def provide_tape_store(self) -> TapeStore | AsyncTapeStore:
def provide_channels(self, message_handler: MessageHandler) -> list[Channel]:
"""Provide a list of channels for receiving messages."""
raise NotImplementedError

@hookspec(firstresult=True)
def build_tape_context(self) -> TapeContext:
"""Build a tape context for the current session, to be used to build context messages."""
raise NotImplementedError
7 changes: 3 additions & 4 deletions tests/test_builtin_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import pytest
import republic.auth.openai_codex as openai_codex
from republic import ToolAutoResult
from republic import TapeContext, ToolAutoResult

import bub.builtin.agent as agent_module
from bub.builtin.agent import Agent
Expand All @@ -25,12 +25,11 @@ def __init__(self, *args: object, **kwargs: object) -> None:

monkeypatch.setattr(agent_module, "LLM", FakeLLM)
monkeypatch.setattr(openai_codex, "openai_codex_oauth_resolver", lambda: resolver)
monkeypatch.setattr(agent_module, "default_tape_context", lambda: "ctx")

settings = AgentSettings(model="openai:gpt-5-codex", api_key=None, api_base=None)
tape_store = object()

agent_module._build_llm(settings, tape_store)
agent_module._build_llm(settings, tape_store, "ctx")

assert captured["args"] == ("openai:gpt-5-codex",)
assert captured["kwargs"]["api_key"] is None
Expand Down Expand Up @@ -81,7 +80,7 @@ def __init__(self, fork_capture: _ForkCapture) -> None:
def session_tape(self, session_id: str, workspace: Any) -> MagicMock:
tape = MagicMock()
tape.name = "test-tape"
tape.context.state = {}
tape.context = TapeContext(state={})

async def fake_run_tools_async(**kwargs: Any) -> ToolAutoResult:
self.run_tools_model = kwargs.get("model")
Expand Down
Loading