diff --git a/index.ts b/index.ts index 86d541e7..62540744 100644 --- a/index.ts +++ b/index.ts @@ -14,6 +14,7 @@ import { createCommandExecuteHandler, createSystemPromptHandler, createTextCompleteHandler, + createToolExecuteAfterHandler, } from "./lib/hooks" import { configureClientAuth, isSecureMode } from "./lib/auth" @@ -71,6 +72,7 @@ const plugin: Plugin = (async (ctx) => { logger.debug("Cached variant from chat.message hook", { variant: input.variant }) }, "experimental.text.complete": createTextCompleteHandler(), + "tool.execute.after": createToolExecuteAfterHandler(), "command.execute.before": createCommandExecuteHandler( ctx.client, state, diff --git a/lib/hooks.ts b/lib/hooks.ts index 3fcc5e63..97b4420c 100644 --- a/lib/hooks.ts +++ b/lib/hooks.ts @@ -11,7 +11,12 @@ import { injectExtendedSubAgentResults, stripStaleMetadata, } from "./messages" -import { buildToolIdList, isIgnoredUserMessage, stripHallucinations } from "./messages/utils" +import { + buildToolIdList, + isIgnoredUserMessage, + sanitizeVisibleOutput, + stripHallucinations, +} from "./messages/utils" import { checkSession } from "./state" import { renderSystemPrompt } from "./prompts" import { handleStatsCommand } from "./commands/stats" @@ -33,9 +38,6 @@ const INTERNAL_AGENT_SIGNATURES = [ "Summarize what was done in this conversation", ] -const DCP_MESSAGE_ID_TAG_REGEX = /(?:m\d+|b\d+)<\/dcp-message-id>/g -const DCP_SYSTEM_REMINDER_REGEX = /]*>[\s\S]*?<\/dcp-system-reminder>/g - function applyManualPrompt(state: SessionState, messages: WithParts[], logger: Logger): void { const pending = state.pendingManualTrigger if (!pending) { @@ -125,7 +127,7 @@ export function createChatMessageTransformHandler( prompts: PromptStore, hostPermissions: HostPermissionSnapshot, ) { - return async (input: {}, output: { messages: WithParts[] }) => { + return async (_input: {}, output: { messages: WithParts[] }) => { await checkSession(client, state, logger, output.messages, config.manualMode.enabled) syncCompressPermissionState(state, config, hostPermissions, output.messages) @@ -280,8 +282,16 @@ export function createTextCompleteHandler() { _input: { sessionID: string; messageID: string; partID: string }, output: { text: string }, ) => { - output.text = output.text - .replace(DCP_SYSTEM_REMINDER_REGEX, "") - .replace(DCP_MESSAGE_ID_TAG_REGEX, "") + output.text = sanitizeVisibleOutput(output.text) + } +} + +export function createToolExecuteAfterHandler() { + return async ( + _input: { tool: string; sessionID: string; callID: string }, + output: { title: string; output: string; metadata: unknown }, + ) => { + output.title = sanitizeVisibleOutput(output.title) + output.output = sanitizeVisibleOutput(output.output) } } diff --git a/lib/messages/inject/inject.ts b/lib/messages/inject/inject.ts index 17a4286a..3437a1fc 100644 --- a/lib/messages/inject/inject.ts +++ b/lib/messages/inject/inject.ts @@ -6,9 +6,7 @@ import { formatMessageIdTag } from "../../message-ids" import { compressPermission, getLastUserMessage } from "../../shared-utils" import { saveSessionState } from "../../state/persistence" import { - appendIdToTool, createSyntheticTextPart, - findLastToolPart, isIgnoredUserMessage, } from "../utils" import { @@ -164,11 +162,6 @@ export const injectMessageIds = ( continue } - const lastToolPart = findLastToolPart(message) - if (lastToolPart && appendIdToTool(lastToolPart, tag)) { - continue - } - const syntheticPart = createSyntheticTextPart(message, tag) const firstToolIndex = message.parts.findIndex((p) => p.type === "tool") if (firstToolIndex === -1) { diff --git a/lib/messages/utils.ts b/lib/messages/utils.ts index 03ce5690..c46e9a8d 100644 --- a/lib/messages/utils.ts +++ b/lib/messages/utils.ts @@ -61,38 +61,10 @@ export const createSyntheticTextPart = ( messageID: userInfo.id, type: "text" as const, text: content, + synthetic: true, } } -type MessagePart = WithParts["parts"][number] -type ToolPart = Extract - -export const appendIdToTool = (part: ToolPart, tag: string): boolean => { - if (part.type !== "tool") { - return false - } - if (part.state?.status !== "completed" || typeof part.state.output !== "string") { - return false - } - if (part.state.output.includes(tag)) { - return true - } - - part.state.output = `${part.state.output}${tag}` - return true -} - -export const findLastToolPart = (message: WithParts): ToolPart | null => { - for (let i = message.parts.length - 1; i >= 0; i--) { - const part = message.parts[i] - if (part.type === "tool") { - return part - } - } - - return null -} - export function buildToolIdList(state: SessionState, messages: WithParts[]): string[] { const toolIds: string[] = [] for (const msg of messages) { @@ -131,6 +103,10 @@ export const stripHallucinationsFromString = (text: string): string => { return text.replace(DCP_SYSTEM_REMINDER_REGEX, "").replace(DCP_MESSAGE_ID_TAG_REGEX, "") } +export const sanitizeVisibleOutput = (text: string): string => { + return stripHallucinationsFromString(text).replace(/\n{3,}/g, "\n\n").trimEnd() +} + export const stripHallucinations = (messages: WithParts[]): void => { for (const message of messages) { for (const part of message.parts) { diff --git a/tests/output-sanitization.test.ts b/tests/output-sanitization.test.ts new file mode 100644 index 00000000..228be0e2 --- /dev/null +++ b/tests/output-sanitization.test.ts @@ -0,0 +1,143 @@ +import assert from "node:assert/strict" +import test from "node:test" +import type { PluginConfig } from "../lib/config" +import { createTextCompleteHandler, createToolExecuteAfterHandler } from "../lib/hooks" +import { injectMessageIds } from "../lib/messages/inject/inject" +import { sanitizeVisibleOutput } from "../lib/messages/utils" +import { createSessionState, type WithParts } from "../lib/state" + +function buildConfig(): PluginConfig { + return { + enabled: true, + debug: false, + pruneNotification: "off", + pruneNotificationType: "chat", + commands: { + enabled: true, + protectedTools: [], + }, + manualMode: { + enabled: false, + automaticStrategies: true, + }, + turnProtection: { + enabled: false, + turns: 4, + }, + experimental: { + allowSubAgents: true, + customPrompts: false, + }, + protectedFilePatterns: [], + compress: { + permission: "allow", + showCompression: false, + maxContextLimit: 150000, + minContextLimit: 50000, + nudgeFrequency: 5, + iterationNudgeThreshold: 15, + nudgeForce: "soft", + flatSchema: false, + protectedTools: [], + protectUserMessages: false, + }, + strategies: { + deduplication: { + enabled: true, + protectedTools: [], + }, + supersedeWrites: { + enabled: true, + }, + purgeErrors: { + enabled: true, + turns: 4, + protectedTools: [], + }, + }, + } +} + +test("sanitizeVisibleOutput strips DCP metadata and trailing blank lines", () => { + const result = sanitizeVisibleOutput(`bun install +m0045 + + +hidden +`) + + assert.equal(result, "bun install") +}) + +test("tool.execute.after strips DCP metadata from visible tool output", async () => { + const handler = createToolExecuteAfterHandler() + const output = { + title: `bash +m0045`, + output: `bun install v1.3.10 +m0045`, + metadata: {}, + } + + await handler({ tool: "bash", sessionID: "ses_1", callID: "call_1" }, output) + + assert.equal(output.title, "bash") + assert.equal(output.output, "bun install v1.3.10") +}) + +test("experimental.text.complete strips DCP metadata from visible assistant text", async () => { + const handler = createTextCompleteHandler() + const output = { + text: `done +m0045 +hidden`, + } + + await handler({ sessionID: "ses_1", messageID: "msg_1", partID: "part_1" }, output) + + assert.equal(output.text, "done") +}) + +test("injectMessageIds keeps assistant tool output clean and inserts a synthetic text part", () => { + const state = createSessionState() + state.messageIds.byRawId.set("assistant-1", "m0045") + + const messages: WithParts[] = [ + { + info: { + id: "assistant-1", + role: "assistant", + sessionID: "ses_1", + agent: "assistant", + time: { created: 1 }, + } as WithParts["info"], + parts: [ + { + id: "tool-part-1", + sessionID: "ses_1", + messageID: "assistant-1", + type: "tool", + callID: "call_1", + tool: "bash", + state: { + status: "completed", + input: {}, + title: "bash", + output: "bun install v1.3.10", + metadata: {}, + time: { start: 1, end: 2 }, + }, + }, + ], + }, + ] + + injectMessageIds(state, buildConfig(), messages) + + assert.equal(messages[0].parts[0].type, "text") + assert.equal(messages[0].parts[0].synthetic, true) + assert.match(messages[0].parts[0].text, /m0045<\/dcp-message-id>/) + assert.equal(messages[0].parts[1].type, "tool") + assert.equal(messages[0].parts[1].state.status, "completed") + assert.equal(messages[0].parts[1].state.output, "bun install v1.3.10") +})