diff --git a/core/tools/systemMessageTools/detectToolCallStart.ts b/core/tools/systemMessageTools/detectToolCallStart.ts index e143411dda6..d96ed1d1fac 100644 --- a/core/tools/systemMessageTools/detectToolCallStart.ts +++ b/core/tools/systemMessageTools/detectToolCallStart.ts @@ -1,15 +1,25 @@ import { SystemMessageToolsFramework } from "./types"; +interface DetectToolCallStartOptions { + allowAlternateStarts?: boolean; +} + export function detectToolCallStart( buffer: string, toolCallFramework: SystemMessageToolsFramework, + options: DetectToolCallStartOptions = {}, ) { + const allowAlternateStarts = options.allowAlternateStarts ?? true; const starts = toolCallFramework.acceptedToolCallStarts; let modifiedBuffer = buffer; let isInToolCall = false; let isInPartialStart = false; const lowerCaseBuffer = buffer.toLowerCase(); for (let i = 0; i < starts.length; i++) { + if (i !== 0 && !allowAlternateStarts) { + continue; + } + const [start, replacement] = starts[i]; if (lowerCaseBuffer.startsWith(start)) { // for non-standard cases like no ```tool codeblock, etc, replace before adding to buffer, case insensitive diff --git a/core/tools/systemMessageTools/interceptSystemToolCalls.ts b/core/tools/systemMessageTools/interceptSystemToolCalls.ts index a89fbdcfbb5..bf8fe4266d3 100644 --- a/core/tools/systemMessageTools/interceptSystemToolCalls.ts +++ b/core/tools/systemMessageTools/interceptSystemToolCalls.ts @@ -28,6 +28,7 @@ export async function* interceptSystemToolCalls( ): AsyncGenerator { let buffer = ""; let parseState: ToolCallParseState | undefined; + let sawAssistantNonWhitespaceText = false; while (true) { const result = await messageGenerator.next(); @@ -71,7 +72,10 @@ export async function* interceptSystemToolCalls( buffer += chunk; if (!parseState) { const { isInPartialStart, isInToolCall, modifiedBuffer } = - detectToolCallStart(buffer, systemToolFramework); + detectToolCallStart(buffer, systemToolFramework, { + // Only allow loose "TOOL_NAME:" starts at the beginning of assistant output. + allowAlternateStarts: !sawAssistantNonWhitespaceText, + }); if (isInPartialStart) { continue; @@ -109,6 +113,9 @@ export async function* interceptSystemToolCalls( content: [{ type: "text", text: buffer }], }, ]; + if (/\S/.test(buffer)) { + sawAssistantNonWhitespaceText = true; + } } buffer = ""; } diff --git a/core/tools/systemMessageTools/toolCodeblocks/detectToolCallStart.vitest.ts b/core/tools/systemMessageTools/toolCodeblocks/detectToolCallStart.vitest.ts index 297c58ec456..0adf6d0c321 100644 --- a/core/tools/systemMessageTools/toolCodeblocks/detectToolCallStart.vitest.ts +++ b/core/tools/systemMessageTools/toolCodeblocks/detectToolCallStart.vitest.ts @@ -75,4 +75,15 @@ describe("detectToolCallStart", () => { expect(result.isInPartialStart).toBe(false); expect(result.modifiedBuffer).toBe(buffer); }); + + it("skips non-standard starts when alternate starts are disabled", () => { + const buffer = "TOOL_NAME: example_tool"; + const result = detectToolCallStart(buffer, framework, { + allowAlternateStarts: false, + }); + + expect(result.isInToolCall).toBe(false); + expect(result.isInPartialStart).toBe(false); + expect(result.modifiedBuffer).toBe(buffer); + }); }); diff --git a/core/tools/systemMessageTools/toolCodeblocks/interceptSystemToolCalls.vitest.ts b/core/tools/systemMessageTools/toolCodeblocks/interceptSystemToolCalls.vitest.ts index a636cc6a9da..b03e0309ce4 100644 --- a/core/tools/systemMessageTools/toolCodeblocks/interceptSystemToolCalls.vitest.ts +++ b/core/tools/systemMessageTools/toolCodeblocks/interceptSystemToolCalls.vitest.ts @@ -179,9 +179,8 @@ describe("interceptSystemToolCalls", () => { ).toBe("}"); }); - it("processes tool_name without codeblock format", async () => { + it("processes tool_name without codeblock format at assistant output start", async () => { const messages: ChatMessage[][] = [ - [{ role: "assistant", content: "I'll help you with that.\n" }], [{ role: "assistant", content: "TOOL_NAME: test_tool\n" }], [{ role: "assistant", content: "BEGIN_ARG: arg1\n" }], [{ role: "assistant", content: "value1\n" }], @@ -194,30 +193,7 @@ describe("interceptSystemToolCalls", () => { framework, ); - // First chunk should be normal text let result = await generator.next(); - expect(result.value).toEqual([ - { - role: "assistant", - content: [{ type: "text", text: "I'll help you with that." }], - }, - ]); - - result = await generator.next(); - expect(result.value).toEqual([ - { - role: "assistant", - content: [ - { - type: "text", - text: "\n", - }, - ], - }, - ]); - - // The system should detect the tool_name format and convert it - result = await generator.next(); expect( (result.value as AssistantChatMessage[])[0].toolCalls?.[0].function?.name, ).toBe("test_tool"); @@ -242,6 +218,43 @@ describe("interceptSystemToolCalls", () => { ).toBe("}"); }); + it("does not intercept quoted tool syntax in explanatory text", async () => { + const messages: ChatMessage[][] = [ + [{ role: "assistant", content: "Here is the syntax:\n" }], + [{ role: "assistant", content: "TOOL_NAME: read_file\n" }], + [{ role: "assistant", content: "BEGIN_ARG: filepath\n" }], + [{ role: "assistant", content: "path/to/the_file.txt\n" }], + [{ role: "assistant", content: "END_ARG\n" }], + ]; + + const generator = interceptSystemToolCalls( + createAsyncGenerator(messages), + abortController, + framework, + ); + + const outputChunks: string[] = []; + while (true) { + const result = await generator.next(); + if (result.done || !result.value) { + break; + } + + const chunkText = ( + (result.value as AssistantChatMessage[])[0].content as { + type: "text"; + text: string; + }[] + )[0].text; + outputChunks.push(chunkText); + expect((result.value as AssistantChatMessage[])[0].toolCalls).toBeFalsy(); + } + + expect(outputChunks.join("")).toBe( + "Here is the syntax:\nTOOL_NAME: read_file\nBEGIN_ARG: filepath\npath/to/the_file.txt\nEND_ARG\n", + ); + }); + it("ignores content after a tool call", async () => { const messages: ChatMessage[][] = [ [{ role: "assistant", content: "```tool\n" }],