diff --git a/src/bub/builtin/store.py b/src/bub/builtin/store.py index 0a2a1ea3..3403fd18 100644 --- a/src/bub/builtin/store.py +++ b/src/bub/builtin/store.py @@ -20,6 +20,8 @@ from bub.utils import get_entry_text current_store: contextvars.ContextVar[TapeStore] = contextvars.ContextVar("current_store") +current_fork_tape: contextvars.ContextVar[str | None] = contextvars.ContextVar("current_fork_tape", default=None) +current_tape_was_reset: contextvars.ContextVar[bool] = contextvars.ContextVar("current_tape_was_reset", default=False) WORD_PATTERN = re.compile(r"[a-z0-9_/-]+") MIN_FUZZY_QUERY_LENGTH = 3 MIN_FUZZY_SCORE = 80 @@ -37,18 +39,31 @@ def __init__(self, parent: AsyncTapeStore | TapeStore) -> None: def _current(self) -> TapeStore: return current_store.get(_emtpy_store) + @property + def _fork_tape(self) -> str | None: + return current_fork_tape.get() + + @property + def _current_was_reset(self) -> bool: + return current_tape_was_reset.get() + async def list_tapes(self) -> list[str]: return cast(list[str], await self._parent.list_tapes()) async def reset(self, tape: str) -> None: self._current.reset(tape) - await self._parent.reset(tape) + if self._current is _emtpy_store or self._fork_tape != tape: + await self._parent.reset(tape) + return + current_tape_was_reset.set(True) async def fetch_all(self, query: TapeQuery[AsyncTapeStore]) -> Iterable[TapeEntry]: - try: - parent_entries = await self._parent.fetch_all(query) - except Exception: - parent_entries = [] + parent_entries: Iterable[TapeEntry] = [] + if not (query.tape == self._fork_tape and self._current_was_reset): + try: + parent_entries = await self._parent.fetch_all(query) + except Exception: + parent_entries = [] this_entries: list[TapeEntry] = [] if hasattr(self._current, "read"): for entry in cast(list[TapeEntry], self._current.read(query.tape) or []): @@ -87,11 +102,18 @@ async def append(self, tape: str, entry: TapeEntry) -> None: async def fork(self, tape: str, merge_back: bool = True) -> AsyncGenerator[None, None]: store = InMemoryTapeStore() token = current_store.set(store) + tape_token = current_fork_tape.set(tape) + reset_token = current_tape_was_reset.set(False) try: yield finally: + was_reset = current_tape_was_reset.get() current_store.reset(token) + current_fork_tape.reset(tape_token) + current_tape_was_reset.reset(reset_token) if merge_back: + if was_reset: + await self._parent.reset(tape) entries = store.read(tape) if entries: count = len(entries) diff --git a/tests/test_fork_store_merge_back.py b/tests/test_fork_store_merge_back.py index 280a08d7..418c83f1 100644 --- a/tests/test_fork_store_merge_back.py +++ b/tests/test_fork_store_merge_back.py @@ -1,7 +1,7 @@ from __future__ import annotations import pytest -from republic import TapeEntry +from republic import TapeEntry, TapeQuery from republic.tape import InMemoryTapeStore from bub.builtin.store import ForkTapeStore @@ -48,3 +48,61 @@ async def test_fork_default_merge_back_is_true() -> None: entries = parent.read("test-tape") assert entries is not None assert len(entries) == 1 + + +@pytest.mark.asyncio +async def test_fork_reset_with_merge_back_false_preserves_parent_entries() -> None: + parent = InMemoryTapeStore() + store = ForkTapeStore(parent) + parent.append("test-tape", TapeEntry.event(name="before", data={"x": 1})) + + async with store.fork("test-tape", merge_back=False): + await store.reset("test-tape") + await store.append("test-tape", TapeEntry.event(name="inside", data={"x": 2})) + + entries = parent.read("test-tape") + assert entries is not None + assert [entry.payload["name"] for entry in entries] == ["before"] + + +@pytest.mark.asyncio +async def test_fork_reset_with_merge_back_true_replaces_parent_entries() -> None: + parent = InMemoryTapeStore() + store = ForkTapeStore(parent) + parent.append("test-tape", TapeEntry.event(name="before", data={"x": 1})) + + async with store.fork("test-tape", merge_back=True): + await store.reset("test-tape") + await store.append("test-tape", TapeEntry.event(name="inside", data={"x": 2})) + + entries = parent.read("test-tape") + assert entries is not None + assert [entry.payload["name"] for entry in entries] == ["inside"] + + +@pytest.mark.asyncio +async def test_fork_reset_hides_parent_entries_during_fetch() -> None: + parent = InMemoryTapeStore() + store = ForkTapeStore(parent) + parent.append("test-tape", TapeEntry.event(name="before", data={"x": 1})) + + async with store.fork("test-tape", merge_back=False): + await store.reset("test-tape") + await store.append("test-tape", TapeEntry.event(name="inside", data={"x": 2})) + + query = TapeQuery(tape="test-tape", store=store) + entries = list(await store.fetch_all(query)) + + assert [entry.payload["name"] for entry in entries] == ["inside"] + + +@pytest.mark.asyncio +async def test_reset_outside_fork_resets_parent_immediately() -> None: + parent = InMemoryTapeStore() + store = ForkTapeStore(parent) + parent.append("test-tape", TapeEntry.event(name="before", data={"x": 1})) + + await store.reset("test-tape") + + entries = parent.read("test-tape") + assert entries is None