diff --git a/bun.lock b/bun.lock index 6140c3497a5..61f1a895f12 100644 --- a/bun.lock +++ b/bun.lock @@ -4876,7 +4876,7 @@ "@aws-sdk/client-sts/@aws-sdk/util-user-agent-node": ["@aws-sdk/util-user-agent-node@3.782.0", "", { "dependencies": { "@aws-sdk/middleware-user-agent": "3.782.0", "@aws-sdk/types": "3.775.0", "@smithy/node-config-provider": "^4.0.2", "@smithy/types": "^4.2.0", "tslib": "^2.6.2" }, "peerDependencies": { "aws-crt": ">=1.0.0" }, "optionalPeers": ["aws-crt"] }, "sha512-dMFkUBgh2Bxuw8fYZQoH/u3H4afQ12VSkzEi//qFiDTwbKYq+u+RYjc8GLDM6JSK1BShMu5AVR7HD4ap1TYUnA=="], - "@aws-sdk/credential-provider-cognito-identity/@aws-sdk/nested-clients": ["@aws-sdk/nested-clients@3.996.8", "", { "dependencies": { "@aws-crypto/sha256-browser": "5.2.0", "@aws-crypto/sha256-js": "5.2.0", "@aws-sdk/core": "^3.973.19", "@aws-sdk/middleware-host-header": "^3.972.7", "@aws-sdk/middleware-logger": "^3.972.7", "@aws-sdk/middleware-recursion-detection": "^3.972.7", "@aws-sdk/middleware-user-agent": "^3.972.20", "@aws-sdk/region-config-resolver": "^3.972.7", "@aws-sdk/types": "^3.973.5", "@aws-sdk/util-endpoints": "^3.996.4", "@aws-sdk/util-user-agent-browser": "^3.972.7", "@aws-sdk/util-user-agent-node": "^3.973.5", "@smithy/config-resolver": "^4.4.10", "@smithy/core": "^3.23.9", "@smithy/fetch-http-handler": "^5.3.13", "@smithy/hash-node": "^4.2.11", "@smithy/invalid-dependency": "^4.2.11", "@smithy/middleware-content-length": "^4.2.11", "@smithy/middleware-endpoint": "^4.4.23", "@smithy/middleware-retry": "^4.4.40", "@smithy/middleware-serde": "^4.2.12", "@smithy/middleware-stack": "^4.2.11", "@smithy/node-config-provider": "^4.3.11", "@smithy/node-http-handler": "^4.4.14", "@smithy/protocol-http": "^5.3.11", "@smithy/smithy-client": "^4.12.3", "@smithy/types": "^4.13.0", "@smithy/url-parser": "^4.2.11", "@smithy/util-base64": "^4.3.2", "@smithy/util-body-length-browser": "^4.2.2", "@smithy/util-body-length-node": "^4.2.3", "@smithy/util-defaults-mode-browser": "^4.3.39", "@smithy/util-defaults-mode-node": "^4.2.42", "@smithy/util-endpoints": "^3.3.2", "@smithy/util-middleware": "^4.2.11", "@smithy/util-retry": "^4.2.11", "@smithy/util-utf8": "^4.2.2", "tslib": "^2.6.2" } }, "sha512-6HlLm8ciMW8VzfB80kfIx16PBA9lOa9Dl+dmCBi78JDhvGlx3I7Rorwi5PpVRkL31RprXnYna3yBf6UKkD/PqA=="], + "@aws-sdk/credential-provider-cognito-identity/@aws-sdk/types": ["@aws-sdk/types@3.973.5", "", { "dependencies": { "@smithy/types": "^4.13.0", "tslib": "^2.6.2" } }, "sha512-hl7BGwDCWsjH8NkZfx+HgS7H2LyM2lTMAI7ba9c8O0KqdBLTdNJivsHpqjg9rNlAlPyREb6DeDRXUl0s8uFdmQ=="], diff --git a/packages/app/src/components/session/session-context-tab.tsx b/packages/app/src/components/session/session-context-tab.tsx index 9aa101bdb9a..970d548e49a 100644 --- a/packages/app/src/components/session/session-context-tab.tsx +++ b/packages/app/src/components/session/session-context-tab.tsx @@ -3,6 +3,7 @@ import type { JSX } from "solid-js" import { useSync } from "@/context/sync" import { checksum } from "@opencode-ai/util/encode" import { findLast } from "@opencode-ai/util/array" +import { sortMessages } from "@opencode-ai/util/message" import { same } from "@/utils/same" import { Icon } from "@opencode-ai/ui/icon" import { Accordion } from "@opencode-ai/ui/accordion" @@ -100,7 +101,7 @@ export function SessionContextTab() { () => { const id = params.id if (!id) return emptyMessages - return (sync.data.message[id] ?? []) as Message[] + return sortMessages((sync.data.message[id] ?? []) as Message[]) }, emptyMessages, { equals: same }, diff --git a/packages/app/src/pages/session.tsx b/packages/app/src/pages/session.tsx index c25463d756a..c7a0c6e1391 100644 --- a/packages/app/src/pages/session.tsx +++ b/packages/app/src/pages/session.tsx @@ -411,7 +411,8 @@ export default function Page() { () => { const revert = revertMessageID() if (!revert) return userMessages() - return userMessages().filter((m) => m.id < revert) + const idx = userMessages().findIndex((m) => m.id === revert) + return idx >= 0 ? userMessages().slice(0, idx) : userMessages() }, emptyUserMessages, { @@ -569,7 +570,7 @@ export default function Page() { ) return } - const at = list.findIndex((item) => item.id > next.id) + const at = list.findIndex((item) => item.id.localeCompare(next.id) > 0) if (at >= 0) { globalSync.set("project", [...list.slice(0, at), next, ...list.slice(at)]) return @@ -1245,7 +1246,8 @@ export default function Page() { const sessionID = params.id if (!sessionID || ui.restoring) return - const next = userMessages().find((item) => item.id > id) + const idx = userMessages().findIndex((m) => m.id === id) + const next = idx >= 0 ? userMessages()[idx + 1] : undefined setUi("restoring", id) const task = !next @@ -1273,9 +1275,8 @@ export default function Page() { const rolled = createMemo(() => { const id = revertMessageID() if (!id) return [] - return userMessages() - .filter((item) => item.id >= id) - .map((item) => ({ id: item.id, text: line(item.id) })) + const idx = userMessages().findIndex((m) => m.id === id) + return (idx >= 0 ? userMessages().slice(idx) : []).map((item) => ({ id: item.id, text: line(item.id) })) }) const actions = { fork, revert } diff --git a/packages/app/src/pages/session/message-timeline.tsx b/packages/app/src/pages/session/message-timeline.tsx index 50f9b452a13..366d73d1224 100644 --- a/packages/app/src/pages/session/message-timeline.tsx +++ b/packages/app/src/pages/session/message-timeline.tsx @@ -13,7 +13,7 @@ import { SessionTurn } from "@opencode-ai/ui/session-turn" import { ScrollView } from "@opencode-ai/ui/scroll-view" import type { AssistantMessage, Message as MessageType, Part, TextPart, UserMessage } from "@opencode-ai/sdk/v2" import { showToast } from "@opencode-ai/ui/toast" -import { Binary } from "@opencode-ai/util/binary" +import { sortMessages } from "@opencode-ai/util/message" import { getFilename } from "@opencode-ai/util/path" import { shouldMarkBoundaryGesture, normalizeWheelDelta } from "@/pages/session/message-gesture" import { SessionContextUsage } from "@/components/session-context-usage" @@ -227,7 +227,7 @@ export function MessageTimeline(props: { const sessionMessages = createMemo(() => { const id = sessionID() if (!id) return emptyMessages - return sync.data.message[id] ?? emptyMessages + return sortMessages(sync.data.message[id] ?? emptyMessages) }) const pending = createMemo(() => sessionMessages().findLast( @@ -277,8 +277,7 @@ export function MessageTimeline(props: { const parentID = pending()?.parentID if (parentID) { const messages = sessionMessages() - const result = Binary.search(messages, parentID, (message) => message.id) - const message = result.found ? messages[result.index] : messages.find((item) => item.id === parentID) + const message = messages.find((item) => item.id === parentID) if (message && message.role === "user") return message.id } @@ -755,8 +754,11 @@ export function MessageTimeline(props: { const queued = createMemo(() => { if (active()) return false const activeID = activeMessageID() - if (activeID) return messageID > activeID - return false + if (!activeID) return false + const ids = rendered() + const activeIdx = ids.indexOf(activeID) + if (activeIdx === -1) return false + return ids.indexOf(messageID) > activeIdx }) const comments = createMemo(() => messageComments(sync.data.part[messageID] ?? []), [], { equals: (a, b) => JSON.stringify(a) === JSON.stringify(b), diff --git a/packages/app/src/pages/session/use-session-commands.tsx b/packages/app/src/pages/session/use-session-commands.tsx index 6799504ca6a..35fee53537a 100644 --- a/packages/app/src/pages/session/use-session-commands.tsx +++ b/packages/app/src/pages/session/use-session-commands.tsx @@ -18,6 +18,7 @@ import { DialogSelectMcp } from "@/components/dialog-select-mcp" import { DialogFork } from "@/components/dialog-fork" import { showToast } from "@opencode-ai/ui/toast" import { findLast } from "@opencode-ai/util/array" +import { sortMessages } from "@opencode-ai/util/message" import { extractPromptFromParts } from "@/utils/prompt" import { UserMessage } from "@opencode-ai/sdk/v2" import { useSessionLayout } from "@/pages/session/session-layout" @@ -54,7 +55,7 @@ export const useSessionCommands = (actions: SessionCommandContext) => { const idle = { type: "idle" as const } const status = createMemo(() => sync.data.session_status[params.id ?? ""] ?? idle) - const messages = createMemo(() => (params.id ? (sync.data.message[params.id] ?? []) : [])) + const messages = createMemo(() => sortMessages(params.id ? (sync.data.message[params.id] ?? []) : [])) const userMessages = createMemo(() => messages().filter((m) => m.role === "user") as UserMessage[]) const visibleUserMessages = createMemo(() => { const revert = info()?.revert?.messageID diff --git a/packages/opencode/src/session/index.ts b/packages/opencode/src/session/index.ts index 0879fe87fd3..f5b7a215401 100644 --- a/packages/opencode/src/session/index.ts +++ b/packages/opencode/src/session/index.ts @@ -252,9 +252,11 @@ export namespace Session { }) const msgs = await messages({ sessionID: input.sessionID }) const idMap = new Map() + const cutoff = input.messageID ? msgs.findIndex((msg) => msg.info.id === input.messageID) : -1 - for (const msg of msgs) { - if (input.messageID && msg.info.id >= input.messageID) break + for (let i = 0; i < msgs.length; i++) { + if (cutoff >= 0 && i >= cutoff) break + const msg = msgs[i] const newID = MessageID.ascending() idMap.set(msg.info.id, newID) diff --git a/packages/opencode/src/session/message-v2.ts b/packages/opencode/src/session/message-v2.ts index 03ccb44c1ad..53579453c90 100644 --- a/packages/opencode/src/session/message-v2.ts +++ b/packages/opencode/src/session/message-v2.ts @@ -616,6 +616,14 @@ export namespace MessageV2 { ) { continue } + // Skip incomplete assistant messages (no finish, no error, and no meaningful parts) + if ( + !msg.info.finish && + !msg.info.error && + !msg.parts.some((part) => part.type !== "step-start") + ) { + continue + } const assistantMessage: UIMessage = { id: msg.info.id, role: "assistant", diff --git a/packages/opencode/src/session/prompt.ts b/packages/opencode/src/session/prompt.ts index 171c4b448fd..9c7ffb28bff 100644 --- a/packages/opencode/src/session/prompt.ts +++ b/packages/opencode/src/session/prompt.ts @@ -300,15 +300,16 @@ export namespace SessionPrompt { let msgs = await MessageV2.filterCompacted(MessageV2.stream(sessionID)) let lastUser: MessageV2.User | undefined - let lastAssistant: MessageV2.Assistant | undefined let lastFinished: MessageV2.Assistant | undefined + let finishedIdx = -1 let tasks: (MessageV2.CompactionPart | MessageV2.SubtaskPart)[] = [] for (let i = msgs.length - 1; i >= 0; i--) { const msg = msgs[i] if (!lastUser && msg.info.role === "user") lastUser = msg.info as MessageV2.User - if (!lastAssistant && msg.info.role === "assistant") lastAssistant = msg.info as MessageV2.Assistant - if (!lastFinished && msg.info.role === "assistant" && msg.info.finish) + if (!lastFinished && msg.info.role === "assistant" && msg.info.finish) { lastFinished = msg.info as MessageV2.Assistant + finishedIdx = i + } if (lastUser && lastFinished) break const task = msg.parts.filter((part) => part.type === "compaction" || part.type === "subtask") if (task && !lastFinished) { @@ -317,11 +318,7 @@ export namespace SessionPrompt { } if (!lastUser) throw new Error("No user message found in stream. This should never happen.") - if ( - lastAssistant?.finish && - !["tool-calls", "unknown"].includes(lastAssistant.finish) && - lastUser.id < lastAssistant.id - ) { + if (shouldExitLoop(lastUser, lastFinished)) { log.info("exiting loop", { sessionID }) break } @@ -631,8 +628,9 @@ export namespace SessionPrompt { // Ephemerally wrap queued user messages with a reminder to stay on track if (step > 1 && lastFinished) { - for (const msg of msgs) { - if (msg.info.role !== "user" || msg.info.id <= lastFinished.id) continue + for (let i = 0; i < msgs.length; i++) { + const msg = msgs[i] + if (!shouldWrapSystemReminder(msg.info, i, lastFinished, finishedIdx)) continue for (const part of msg.parts) { if (part.type !== "text" || part.ignored || part.synthetic) continue if (!part.text.trim()) continue @@ -1966,4 +1964,26 @@ NOTE: At any point in time through this workflow you should feel free to ask the return Session.setTitle({ sessionID: input.session.id, title }) } } + + export function shouldExitLoop( + lastUser: MessageV2.User | undefined, + lastAssistant: MessageV2.Assistant | undefined, + ): boolean { + if (!lastUser) return false + if (!lastAssistant?.finish) return false + if (["tool-calls", "unknown"].includes(lastAssistant.finish)) return false + if (!lastAssistant.parentID) return true + return lastAssistant.parentID === lastUser.id + } + + export function shouldWrapSystemReminder( + msg: MessageV2.User | MessageV2.Assistant, + idx: number, + lastFinished: MessageV2.Assistant | undefined, + finishedIdx: number, + ): boolean { + if (msg.role !== "user") return false + if (!lastFinished) return false + return idx > finishedIdx + } } diff --git a/packages/opencode/src/session/revert.ts b/packages/opencode/src/session/revert.ts index c5c9edbbdfa..6086f75eec7 100644 --- a/packages/opencode/src/session/revert.ts +++ b/packages/opencode/src/session/revert.ts @@ -59,7 +59,8 @@ export namespace SessionRevert { revert.snapshot = session.revert?.snapshot ?? (await Snapshot.track()) await Snapshot.revert(patches) if (revert.snapshot) revert.diff = await Snapshot.diff(revert.snapshot) - const rangeMessages = all.filter((msg) => msg.info.id >= revert!.messageID) + const idx = all.findIndex((msg) => msg.info.id === revert!.messageID) + const rangeMessages = idx >= 0 ? all.slice(idx) : all const diffs = await SessionSummary.computeDiff({ messages: rangeMessages }) await Storage.write(["session_diff", input.sessionID], diffs) Bus.publish(Session.Event.Diff, { @@ -96,21 +97,27 @@ export namespace SessionRevert { const preserve = [] as MessageV2.WithParts[] const remove = [] as MessageV2.WithParts[] let target: MessageV2.WithParts | undefined - for (const msg of msgs) { - if (msg.info.id < messageID) { - preserve.push(msg) - continue - } - if (msg.info.id > messageID) { + const idx = msgs.findIndex((msg) => msg.info.id === messageID) + if (idx < 0) { + preserve.push(...msgs) + } else { + for (let i = 0; i < msgs.length; i++) { + const msg = msgs[i]! + if (i < idx) { + preserve.push(msg) + continue + } + if (i > idx) { + remove.push(msg) + continue + } + if (session.revert.partID) { + preserve.push(msg) + target = msg + continue + } remove.push(msg) - continue - } - if (session.revert.partID) { - preserve.push(msg) - target = msg - continue } - remove.push(msg) } for (const msg of remove) { Database.use((db) => db.delete(MessageTable).where(eq(MessageTable.id, msg.info.id)).run()) diff --git a/packages/opencode/test/session/fixtures/skewed-messages.ts b/packages/opencode/test/session/fixtures/skewed-messages.ts new file mode 100644 index 00000000000..e30b4f95ec3 --- /dev/null +++ b/packages/opencode/test/session/fixtures/skewed-messages.ts @@ -0,0 +1,84 @@ +import { Identifier } from "../../../src/id/id" +import { MessageV2 } from "../../../src/session/message-v2" +import { MessageID, SessionID } from "../../../src/session/schema" +import { ModelID, ProviderID } from "../../../src/provider/schema" + +export function makeUser( + id: MessageID, + opts?: Partial, +): MessageV2.User { + return { + id, + sessionID: SessionID.make("test-session"), + role: "user", + time: { created: Date.now() }, + agent: "default", + model: { + providerID: ProviderID.openai, + modelID: ModelID.make("gpt-4"), + }, + ...opts, + } +} + +export function makeAssistant( + id: MessageID, + parentID: MessageID, + opts?: Partial, +): MessageV2.Assistant { + return { + id, + sessionID: SessionID.make("test-session"), + role: "assistant", + parentID, + time: { created: Date.now() }, + modelID: ModelID.make("gpt-4"), + providerID: ProviderID.openai, + mode: "default", + agent: "default", + path: { + cwd: "/tmp", + root: "/tmp", + }, + cost: 0, + tokens: { + input: 0, + output: 0, + reasoning: 0, + cache: { + read: 0, + write: 0, + }, + }, + finish: "stop", + ...opts, + } +} + +export function aheadPair(): { + user: MessageV2.User + assistant: MessageV2.Assistant +} { + const now = Date.now() + const userID = MessageID.make(Identifier.create("message", false, now + 60_000)) + const assistantID = MessageID.make(Identifier.create("message", false, now)) + + return { + user: makeUser(userID), + assistant: makeAssistant(assistantID, userID), + } +} + +export function behindPair(): { + user: MessageV2.User + assistant: MessageV2.Assistant +} { + const now = Date.now() + const userID = MessageID.make(Identifier.create("message", false, now - 60_000)) + const assistantID = MessageID.make(Identifier.create("message", false, now)) + + return { + user: makeUser(userID), + assistant: makeAssistant(assistantID, userID), + } +} diff --git a/packages/opencode/test/session/message-v2.test.ts b/packages/opencode/test/session/message-v2.test.ts index e9c6cb729bb..5fa371c7a99 100644 --- a/packages/opencode/test/session/message-v2.test.ts +++ b/packages/opencode/test/session/message-v2.test.ts @@ -643,6 +643,54 @@ describe("session.message-v2.toModelMessage", () => { ]) }) + test("toModelMessages skips assistant messages with no finish and no error", () => { + const userID = "m-user" + const assistantID = "m-assistant" + + const input: MessageV2.WithParts[] = [ + { + info: userInfo(userID), + parts: [ + { + ...basePart(userID, "u1"), + type: "text", + text: "hello", + }, + ] as MessageV2.Part[], + }, + { + info: assistantInfo(assistantID, userID), + parts: [ + { + ...basePart(assistantID, "a1"), + type: "step-start", + }, + ] as MessageV2.Part[], + }, + { + info: userInfo("m-user-2"), + parts: [ + { + ...basePart("m-user-2", "u2"), + type: "text", + text: "follow up", + }, + ] as MessageV2.Part[], + }, + ] + + expect(MessageV2.toModelMessages(input, model)).toStrictEqual([ + { + role: "user", + content: [{ type: "text", text: "hello" }], + }, + { + role: "user", + content: [{ type: "text", text: "follow up" }], + }, + ]) + }) + test("splits assistant messages on step-start boundaries", () => { const assistantID = "m-assistant" diff --git a/packages/opencode/test/session/prompt.test.ts b/packages/opencode/test/session/prompt.test.ts index 3986271dab9..6dcc34451fe 100644 --- a/packages/opencode/test/session/prompt.test.ts +++ b/packages/opencode/test/session/prompt.test.ts @@ -8,6 +8,8 @@ import { MessageV2 } from "../../src/session/message-v2" import { SessionPrompt } from "../../src/session/prompt" import { Log } from "../../src/util/log" import { tmpdir } from "../fixture/fixture" +import { MessageID } from "../../src/session/schema" +import { behindPair, makeUser, makeAssistant } from "./fixtures/skewed-messages" Log.init({ print: false }) @@ -210,3 +212,57 @@ describe("session.prompt agent variant", () => { } }) }) + +describe("shouldExitLoop", () => { + const user = makeUser(MessageID.make("msg-user-1")) + const assistant = (parentID: string | undefined, finish?: string) => + makeAssistant(MessageID.make("msg-asst-1"), MessageID.make(parentID ?? "msg-user-1"), { + finish, + parentID: parentID ? MessageID.make(parentID) : undefined, + }) + + test("normal exit: parentID matches, finish=end_turn → true", () => { + expect(SessionPrompt.shouldExitLoop(user, assistant("msg-user-1", "end_turn"))).toBe(true) + }) + + test("clock-skew exit: parentID matches, finish=stop → true", () => { + expect(SessionPrompt.shouldExitLoop(user, assistant("msg-user-1", "stop"))).toBe(true) + }) + + test("tool-calls: finish=tool-calls → false", () => { + expect(SessionPrompt.shouldExitLoop(user, assistant("msg-user-1", "tool-calls"))).toBe(false) + }) + + test("unknown: finish=unknown → false", () => { + expect(SessionPrompt.shouldExitLoop(user, assistant("msg-user-1", "unknown"))).toBe(false) + }) + + test("no assistant: lastAssistant=undefined → false", () => { + expect(SessionPrompt.shouldExitLoop(user, undefined)).toBe(false) + }) + + test("no finish: finish=undefined → false", () => { + expect(SessionPrompt.shouldExitLoop(user, assistant("msg-user-1", undefined))).toBe(false) + }) + + test("parentID mismatch: assistant.parentID !== user.id → false", () => { + expect(SessionPrompt.shouldExitLoop(user, assistant("msg-other-user", "end_turn"))).toBe(false) + }) + + test("no user: lastUser=undefined → false", () => { + expect(SessionPrompt.shouldExitLoop(undefined, assistant("msg-user-1", "end_turn"))).toBe(false) + }) + + test("should return true when assistant has missing parentID (fail-safe)", () => { + expect(SessionPrompt.shouldExitLoop(user, assistant(undefined, "end_turn"))).toBe(true) + }) +}) + +describe("system-reminder wrapping", () => { + test("wraps queued user messages based on position, not ID ordering", () => { + const { user, assistant } = behindPair() + expect(user.id < assistant.id).toBe(true) + expect(SessionPrompt.shouldWrapSystemReminder(user, 2, assistant, 1)).toBe(true) + expect(SessionPrompt.shouldWrapSystemReminder(user, 1, assistant, 1)).toBe(false) + }) +}) diff --git a/packages/opencode/test/session/revert-compact.test.ts b/packages/opencode/test/session/revert-compact.test.ts index fb37a3a8dca..8744bd7d73b 100644 --- a/packages/opencode/test/session/revert-compact.test.ts +++ b/packages/opencode/test/session/revert-compact.test.ts @@ -1,4 +1,4 @@ -import { describe, expect, test, beforeEach, afterEach } from "bun:test" +import { describe, expect, test, beforeEach, afterEach, spyOn } from "bun:test" import path from "path" import { Session } from "../../src/session" import { ModelID, ProviderID } from "../../src/provider/schema" @@ -8,12 +8,289 @@ import { MessageV2 } from "../../src/session/message-v2" import { Log } from "../../src/util/log" import { Instance } from "../../src/project/instance" import { MessageID, PartID } from "../../src/session/schema" +import { Identifier } from "../../src/id/id" +import { SessionSummary } from "../../src/session/summary" import { tmpdir } from "../fixture/fixture" const projectRoot = path.join(__dirname, "../..") Log.init({ print: false }) describe("revert + compact workflow", () => { + test("cleanup preserves messages before target even when their IDs sort after target", async () => { + await using tmp = await tmpdir({ git: true }) + await Instance.provide({ + directory: tmp.path, + fn: async () => { + const session = await Session.create({}) + const sessionID = session.id + const now = Date.now() + + const userA = await Session.updateMessage({ + id: MessageID.make(Identifier.create("message", false, now + 60_000)), + role: "user", + sessionID, + agent: "default", + model: { + providerID: ProviderID.make("openai"), + modelID: ModelID.make("gpt-4"), + }, + time: { + created: now + 1, + }, + }) + await Session.updatePart({ + id: PartID.ascending(), + messageID: userA.id, + sessionID, + type: "text", + text: "A", + }) + + const assistantB: MessageV2.Assistant = { + id: MessageID.ascending(), + role: "assistant", + sessionID, + mode: "default", + agent: "default", + path: { + cwd: tmp.path, + root: tmp.path, + }, + cost: 0, + tokens: { + output: 0, + input: 0, + reasoning: 0, + cache: { read: 0, write: 0 }, + }, + modelID: ModelID.make("gpt-4"), + providerID: ProviderID.make("openai"), + parentID: userA.id, + time: { + created: now + 2, + }, + finish: "end_turn", + } + await Session.updateMessage(assistantB) + await Session.updatePart({ + id: PartID.ascending(), + messageID: assistantB.id, + sessionID, + type: "text", + text: "B", + }) + + const userC = await Session.updateMessage({ + id: MessageID.ascending(), + role: "user", + sessionID, + agent: "default", + model: { + providerID: ProviderID.make("openai"), + modelID: ModelID.make("gpt-4"), + }, + time: { + created: now + 3, + }, + }) + await Session.updatePart({ + id: PartID.ascending(), + messageID: userC.id, + sessionID, + type: "text", + text: "C", + }) + + const assistantD: MessageV2.Assistant = { + id: MessageID.ascending(), + role: "assistant", + sessionID, + mode: "default", + agent: "default", + path: { + cwd: tmp.path, + root: tmp.path, + }, + cost: 0, + tokens: { + output: 0, + input: 0, + reasoning: 0, + cache: { read: 0, write: 0 }, + }, + modelID: ModelID.make("gpt-4"), + providerID: ProviderID.make("openai"), + parentID: userC.id, + time: { + created: now + 4, + }, + finish: "end_turn", + } + await Session.updateMessage(assistantD) + await Session.updatePart({ + id: PartID.ascending(), + messageID: assistantD.id, + sessionID, + type: "text", + text: "D", + }) + + await SessionRevert.revert({ + sessionID, + messageID: userC.id, + }) + + const info = await Session.get(sessionID) + await SessionRevert.cleanup(info) + + const ids = (await Session.messages({ sessionID })).map((item) => item.info.id) + expect(ids).toEqual([userA.id, assistantB.id]) + + await Session.remove(sessionID) + }, + }) + }) + + test("revert range includes target and trailing messages regardless of ID ordering", async () => { + await using tmp = await tmpdir({ git: true }) + await Instance.provide({ + directory: tmp.path, + fn: async () => { + const session = await Session.create({}) + const sessionID = session.id + const now = Date.now() + + const userA = await Session.updateMessage({ + id: MessageID.make(Identifier.create("message", false, now + 60_000)), + role: "user", + sessionID, + agent: "default", + model: { + providerID: ProviderID.make("openai"), + modelID: ModelID.make("gpt-4"), + }, + time: { + created: now + 1, + }, + }) + await Session.updatePart({ + id: PartID.ascending(), + messageID: userA.id, + sessionID, + type: "text", + text: "A", + }) + + const assistantB: MessageV2.Assistant = { + id: MessageID.ascending(), + role: "assistant", + sessionID, + mode: "default", + agent: "default", + path: { + cwd: tmp.path, + root: tmp.path, + }, + cost: 0, + tokens: { + output: 0, + input: 0, + reasoning: 0, + cache: { read: 0, write: 0 }, + }, + modelID: ModelID.make("gpt-4"), + providerID: ProviderID.make("openai"), + parentID: userA.id, + time: { + created: now + 2, + }, + finish: "end_turn", + } + await Session.updateMessage(assistantB) + await Session.updatePart({ + id: PartID.ascending(), + messageID: assistantB.id, + sessionID, + type: "text", + text: "B", + }) + + const userC = await Session.updateMessage({ + id: MessageID.ascending(), + role: "user", + sessionID, + agent: "default", + model: { + providerID: ProviderID.make("openai"), + modelID: ModelID.make("gpt-4"), + }, + time: { + created: now + 3, + }, + }) + await Session.updatePart({ + id: PartID.ascending(), + messageID: userC.id, + sessionID, + type: "text", + text: "C", + }) + + const assistantD: MessageV2.Assistant = { + id: MessageID.ascending(), + role: "assistant", + sessionID, + mode: "default", + agent: "default", + path: { + cwd: tmp.path, + root: tmp.path, + }, + cost: 0, + tokens: { + output: 0, + input: 0, + reasoning: 0, + cache: { read: 0, write: 0 }, + }, + modelID: ModelID.make("gpt-4"), + providerID: ProviderID.make("openai"), + parentID: userC.id, + time: { + created: now + 4, + }, + finish: "end_turn", + } + await Session.updateMessage(assistantD) + await Session.updatePart({ + id: PartID.ascending(), + messageID: assistantD.id, + sessionID, + type: "text", + text: "D", + }) + + const spy = spyOn(SessionSummary, "computeDiff").mockResolvedValue([]) + + try { + await SessionRevert.revert({ + sessionID, + messageID: userC.id, + }) + + expect(spy).toHaveBeenCalledTimes(1) + const input = spy.mock.calls[0]?.[0] + const ids = input?.messages.map((item) => item.info.id) + expect(ids).toEqual([userC.id, assistantD.id]) + } finally { + spy.mockRestore() + } + + await Session.remove(sessionID) + }, + }) + }) + test("should properly handle compact command after revert", async () => { await using tmp = await tmpdir({ git: true }) await Instance.provide({ @@ -22,6 +299,7 @@ describe("revert + compact workflow", () => { // Create a session const session = await Session.create({}) const sessionID = session.id + const now = Date.now() // Create a user message const userMsg1 = await Session.updateMessage({ @@ -34,7 +312,7 @@ describe("revert + compact workflow", () => { modelID: ModelID.make("gpt-4"), }, time: { - created: Date.now(), + created: now + 1, }, }) @@ -69,7 +347,7 @@ describe("revert + compact workflow", () => { providerID: ProviderID.make("openai"), parentID: userMsg1.id, time: { - created: Date.now(), + created: now + 2, }, finish: "end_turn", } @@ -95,7 +373,7 @@ describe("revert + compact workflow", () => { modelID: ModelID.make("gpt-4"), }, time: { - created: Date.now(), + created: now + 3, }, }) @@ -129,7 +407,7 @@ describe("revert + compact workflow", () => { providerID: ProviderID.make("openai"), parentID: userMsg2.id, time: { - created: Date.now(), + created: now + 4, }, finish: "end_turn", } @@ -198,6 +476,7 @@ describe("revert + compact workflow", () => { // Create a session const session = await Session.create({}) const sessionID = session.id + const now = Date.now() // Create initial messages const userMsg = await Session.updateMessage({ @@ -210,7 +489,7 @@ describe("revert + compact workflow", () => { modelID: ModelID.make("gpt-4"), }, time: { - created: Date.now(), + created: now, }, }) @@ -243,7 +522,7 @@ describe("revert + compact workflow", () => { providerID: ProviderID.make("openai"), parentID: userMsg.id, time: { - created: Date.now(), + created: now + 1, }, finish: "end_turn", } diff --git a/packages/opencode/test/session/session.test.ts b/packages/opencode/test/session/session.test.ts index 23325862233..b8191d44d5c 100644 --- a/packages/opencode/test/session/session.test.ts +++ b/packages/opencode/test/session/session.test.ts @@ -4,6 +4,7 @@ import { Session } from "../../src/session" import { Bus } from "../../src/bus" import { Log } from "../../src/util/log" import { Instance } from "../../src/project/instance" +import { Identifier } from "../../src/id/id" import { MessageV2 } from "../../src/session/message-v2" import { MessageID, PartID } from "../../src/session/schema" diff --git a/packages/ui/src/components/session-turn.tsx b/packages/ui/src/components/session-turn.tsx index fda02cab45b..fa22f534dae 100644 --- a/packages/ui/src/components/session-turn.tsx +++ b/packages/ui/src/components/session-turn.tsx @@ -3,7 +3,7 @@ import type { SessionStatus } from "@opencode-ai/sdk/v2" import { useData } from "../context" import { useFileComponent } from "../context/file" -import { Binary } from "@opencode-ai/util/binary" +import { sortMessages, selectAssistants } from "@opencode-ai/util/message" import { getDirectory, getFilename } from "@opencode-ai/util/path" import { createEffect, createMemo, createSignal, For, on, ParentProps, Show } from "solid-js" import { Dynamic } from "solid-js/web" @@ -166,13 +166,11 @@ export function SessionTurn( const emptyDiffs: FileDiff[] = [] const idle = { type: "idle" as const } - const allMessages = createMemo(() => list(data.store.message?.[props.sessionID], emptyMessages)) + const allMessages = createMemo(() => sortMessages(list(data.store.message?.[props.sessionID], emptyMessages))) const messageIndex = createMemo(() => { - const messages = allMessages() ?? emptyMessages - const result = Binary.search(messages, props.messageID, (m) => m.id) - - const index = result.found ? result.index : messages.findIndex((m) => m.id === props.messageID) + const messages = allMessages() + const index = messages.findIndex((m) => m.id === props.messageID) if (index < 0) return -1 const msg = messages[index] @@ -203,9 +201,8 @@ export function SessionTurn( const pendingUser = createMemo(() => { const item = pending() if (!item?.parentID) return - const messages = allMessages() ?? emptyMessages - const result = Binary.search(messages, item.parentID, (m) => m.id) - const msg = result.found ? messages[result.index] : messages.find((m) => m.id === item.parentID) + const messages = allMessages() + const msg = messages.find((m) => m.id === item.parentID) if (!msg || msg.role !== "user") return return msg }) @@ -220,12 +217,14 @@ export function SessionTurn( const queued = createMemo(() => { if (typeof props.queued === "boolean") return props.queued - const id = message()?.id - if (!id) return false + const index = messageIndex() + if (index < 0) return false if (!pendingUser()) return false const item = pending() if (!item) return false - return id > item.id + const messages = allMessages() + const active = messages.findIndex((m) => m.id === item.parentID) + return active >= 0 && index > active }) const parts = createMemo(() => { @@ -268,19 +267,7 @@ export function SessionTurn( () => { const msg = message() if (!msg) return emptyAssistant - - const messages = allMessages() ?? emptyMessages - const index = messageIndex() - if (index < 0) return emptyAssistant - - const result: AssistantMessage[] = [] - for (let i = index + 1; i < messages.length; i++) { - const item = messages[i] - if (!item) continue - if (item.role === "user") break - if (item.role === "assistant" && item.parentID === msg.id) result.push(item as AssistantMessage) - } - return result + return selectAssistants(allMessages(), msg.id) as AssistantMessage[] }, emptyAssistant, { equals: same }, diff --git a/packages/util/src/message.test.ts b/packages/util/src/message.test.ts new file mode 100644 index 00000000000..7ab17c4b6b7 --- /dev/null +++ b/packages/util/src/message.test.ts @@ -0,0 +1,36 @@ +import { describe, expect, test } from "bun:test" +import { selectAssistants, sortMessages, splitMessages } from "./message" + +describe("message", () => { + test("sortMessages uses created time before id", () => { + const result = sortMessages([ + { id: "msg_z", role: "assistant", time: { created: 20 } }, + { id: "msg_a", role: "user", time: { created: 10 } }, + ]) + expect(result.map((item) => item.id)).toEqual(["msg_a", "msg_z"]) + }) + + test("selectAssistants finds replies even when assistant id sorts before user id", () => { + const result = selectAssistants( + [ + { id: "msg_user", role: "user", time: { created: 10 } }, + { id: "msg_assistant", role: "assistant", parentID: "msg_user", time: { created: 11 } }, + ], + "msg_user", + ) + expect(result.map((item) => item.id)).toEqual(["msg_assistant"]) + }) + + test("splitMessages uses chronological order instead of id order", () => { + const result = splitMessages( + [ + { id: "msg_3", role: "user", time: { created: 30 } }, + { id: "msg_1", role: "user", time: { created: 10 } }, + { id: "msg_2", role: "user", time: { created: 20 } }, + ], + "msg_2", + ) + expect(result.before.map((item) => item.id)).toEqual(["msg_1"]) + expect(result.after.map((item) => item.id)).toEqual(["msg_2", "msg_3"]) + }) +}) diff --git a/packages/util/src/message.ts b/packages/util/src/message.ts new file mode 100644 index 00000000000..1b013b9f99f --- /dev/null +++ b/packages/util/src/message.ts @@ -0,0 +1,44 @@ +type Message = { + id: string + role?: string + parentID?: string + time?: { + created?: number + } +} + +function rank(message: Message) { + if (message.role === "user") return 0 + if (message.role === "assistant") return 1 + return 2 +} + +export function compareMessages(a: Message, b: Message) { + const at = a.time?.created ?? 0 + const bt = b.time?.created ?? 0 + if (at !== bt) return at - bt + + const ar = rank(a) + const br = rank(b) + if (ar !== br) return ar - br + + if (a.id < b.id) return -1 + if (a.id > b.id) return 1 + return 0 +} + +export function sortMessages(messages: readonly T[]) { + return messages.slice().sort(compareMessages) +} + +export function selectAssistants(messages: readonly T[], parentID: string) { + return sortMessages(messages.filter((message) => message.role === "assistant" && message.parentID === parentID)) +} + +export function splitMessages(messages: readonly T[], markerID?: string) { + const sorted = sortMessages(messages) + if (!markerID) return { before: sorted, after: [] as T[] } + const index = sorted.findIndex((message) => message.id === markerID) + if (index === -1) return { before: sorted, after: [] as T[] } + return { before: sorted.slice(0, index), after: sorted.slice(index) } +}