Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions src/bub/builtin/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from bub.utils import get_entry_text

current_store: contextvars.ContextVar[TapeStore] = contextvars.ContextVar("current_store")
current_fork_tape: contextvars.ContextVar[str | None] = contextvars.ContextVar("current_fork_tape", default=None)
current_tape_was_reset: contextvars.ContextVar[bool] = contextvars.ContextVar("current_tape_was_reset", default=False)
WORD_PATTERN = re.compile(r"[a-z0-9_/-]+")
MIN_FUZZY_QUERY_LENGTH = 3
MIN_FUZZY_SCORE = 80
Expand All @@ -37,18 +39,31 @@ def __init__(self, parent: AsyncTapeStore | TapeStore) -> None:
def _current(self) -> TapeStore:
return current_store.get(_emtpy_store)

@property
def _fork_tape(self) -> str | None:
return current_fork_tape.get()

@property
def _current_was_reset(self) -> bool:
return current_tape_was_reset.get()

async def list_tapes(self) -> list[str]:
return cast(list[str], await self._parent.list_tapes())

async def reset(self, tape: str) -> None:
self._current.reset(tape)
await self._parent.reset(tape)
if self._current is _emtpy_store or self._fork_tape != tape:
await self._parent.reset(tape)
return
current_tape_was_reset.set(True)

async def fetch_all(self, query: TapeQuery[AsyncTapeStore]) -> Iterable[TapeEntry]:
try:
parent_entries = await self._parent.fetch_all(query)
except Exception:
parent_entries = []
parent_entries: Iterable[TapeEntry] = []
if not (query.tape == self._fork_tape and self._current_was_reset):
try:
parent_entries = await self._parent.fetch_all(query)
except Exception:
parent_entries = []
this_entries: list[TapeEntry] = []
if hasattr(self._current, "read"):
for entry in cast(list[TapeEntry], self._current.read(query.tape) or []):
Expand Down Expand Up @@ -87,11 +102,18 @@ async def append(self, tape: str, entry: TapeEntry) -> None:
async def fork(self, tape: str, merge_back: bool = True) -> AsyncGenerator[None, None]:
store = InMemoryTapeStore()
token = current_store.set(store)
tape_token = current_fork_tape.set(tape)
reset_token = current_tape_was_reset.set(False)
try:
yield
finally:
was_reset = current_tape_was_reset.get()
current_store.reset(token)
current_fork_tape.reset(tape_token)
current_tape_was_reset.reset(reset_token)
if merge_back:
if was_reset:
await self._parent.reset(tape)
entries = store.read(tape)
if entries:
count = len(entries)
Expand Down
60 changes: 59 additions & 1 deletion tests/test_fork_store_merge_back.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import pytest
from republic import TapeEntry
from republic import TapeEntry, TapeQuery
from republic.tape import InMemoryTapeStore

from bub.builtin.store import ForkTapeStore
Expand Down Expand Up @@ -48,3 +48,61 @@ async def test_fork_default_merge_back_is_true() -> None:
entries = parent.read("test-tape")
assert entries is not None
assert len(entries) == 1


@pytest.mark.asyncio
async def test_fork_reset_with_merge_back_false_preserves_parent_entries() -> None:
parent = InMemoryTapeStore()
store = ForkTapeStore(parent)
parent.append("test-tape", TapeEntry.event(name="before", data={"x": 1}))

async with store.fork("test-tape", merge_back=False):
await store.reset("test-tape")
await store.append("test-tape", TapeEntry.event(name="inside", data={"x": 2}))

entries = parent.read("test-tape")
assert entries is not None
assert [entry.payload["name"] for entry in entries] == ["before"]


@pytest.mark.asyncio
async def test_fork_reset_with_merge_back_true_replaces_parent_entries() -> None:
parent = InMemoryTapeStore()
store = ForkTapeStore(parent)
parent.append("test-tape", TapeEntry.event(name="before", data={"x": 1}))

async with store.fork("test-tape", merge_back=True):
await store.reset("test-tape")
await store.append("test-tape", TapeEntry.event(name="inside", data={"x": 2}))

entries = parent.read("test-tape")
assert entries is not None
assert [entry.payload["name"] for entry in entries] == ["inside"]


@pytest.mark.asyncio
async def test_fork_reset_hides_parent_entries_during_fetch() -> None:
parent = InMemoryTapeStore()
store = ForkTapeStore(parent)
parent.append("test-tape", TapeEntry.event(name="before", data={"x": 1}))

async with store.fork("test-tape", merge_back=False):
await store.reset("test-tape")
await store.append("test-tape", TapeEntry.event(name="inside", data={"x": 2}))

query = TapeQuery(tape="test-tape", store=store)
entries = list(await store.fetch_all(query))

assert [entry.payload["name"] for entry in entries] == ["inside"]


@pytest.mark.asyncio
async def test_reset_outside_fork_resets_parent_immediately() -> None:
parent = InMemoryTapeStore()
store = ForkTapeStore(parent)
parent.append("test-tape", TapeEntry.event(name="before", data={"x": 1}))

await store.reset("test-tape")

entries = parent.read("test-tape")
assert entries is None
Loading