diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 01f1fb5f413..41db99fb0ed 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -758,6 +758,9 @@ importers: '@ai-sdk/groq': specifier: ^3.0.19 version: 3.0.19(zod@3.25.76) + '@ai-sdk/mistral': + specifier: ^3.0.0 + version: 3.0.16(zod@3.25.76) '@anthropic-ai/bedrock-sdk': specifier: ^0.10.2 version: 0.10.4 @@ -1432,6 +1435,12 @@ packages: peerDependencies: zod: 3.25.76 + '@ai-sdk/mistral@3.0.16': + resolution: {integrity: sha512-8I/gxXJwghaDLbQQHMBwd61WxYz/PaFUFlG8I38daNYj5qRTMmQ5V10Idi6GJJC0wWEqQkal31lidm9+Y+u6TQ==} + engines: {node: '>=18'} + peerDependencies: + zod: 3.25.76 + '@ai-sdk/openai-compatible@1.0.31': resolution: {integrity: sha512-znBvaVHM0M6yWNerIEy3hR+O8ZK2sPcE7e2cxfb6kYLEX3k//JH5VDnRnajseVofg7LXtTCFFdjsB7WLf1BdeQ==} engines: {node: '>=18'} @@ -11077,6 +11086,12 @@ snapshots: '@ai-sdk/provider-utils': 4.0.11(zod@3.25.76) zod: 3.25.76 + '@ai-sdk/mistral@3.0.16(zod@3.25.76)': + dependencies: + '@ai-sdk/provider': 3.0.6 + '@ai-sdk/provider-utils': 4.0.11(zod@3.25.76) + zod: 3.25.76 + '@ai-sdk/openai-compatible@1.0.31(zod@3.25.76)': dependencies: '@ai-sdk/provider': 2.0.1 diff --git a/src/api/providers/__tests__/mistral.spec.ts b/src/api/providers/__tests__/mistral.spec.ts index 28aae09658e..0cac881dffe 100644 --- a/src/api/providers/__tests__/mistral.spec.ts +++ b/src/api/providers/__tests__/mistral.spec.ts @@ -1,59 +1,36 @@ -// Mock TelemetryService - must come before other imports -const mockCaptureException = vi.hoisted(() => vi.fn()) -vi.mock("@roo-code/telemetry", () => ({ - TelemetryService: { - instance: { - captureException: mockCaptureException, - }, - }, +// Use vi.hoisted to define mock functions that can be referenced in hoisted vi.mock() calls +const { mockStreamText, mockGenerateText, mockCreateMistral } = vi.hoisted(() => ({ + mockStreamText: vi.fn(), + mockGenerateText: vi.fn(), + mockCreateMistral: vi.fn(() => { + // Return a function that returns a mock language model + return vi.fn(() => ({ + modelId: "codestral-latest", + provider: "mistral", + })) + }), })) -// Mock Mistral client - must come before other imports -const mockCreate = vi.fn() -const mockComplete = vi.fn() -vi.mock("@mistralai/mistralai", () => { +vi.mock("ai", async (importOriginal) => { + const actual = await importOriginal() return { - Mistral: vi.fn().mockImplementation(() => ({ - chat: { - stream: mockCreate.mockImplementation(async (_options) => { - const stream = { - [Symbol.asyncIterator]: async function* () { - yield { - data: { - choices: [ - { - delta: { content: "Test response" }, - index: 0, - }, - ], - }, - } - }, - } - return stream - }), - complete: mockComplete.mockImplementation(async (_options) => { - return { - choices: [ - { - message: { - content: "Test response", - }, - }, - ], - } - }), - }, - })), + ...actual, + streamText: mockStreamText, + generateText: mockGenerateText, } }) +vi.mock("@ai-sdk/mistral", () => ({ + createMistral: mockCreateMistral, +})) + import type { Anthropic } from "@anthropic-ai/sdk" -import type OpenAI from "openai" -import { MistralHandler } from "../mistral" + +import { mistralDefaultModelId, mistralModels, type MistralModelId } from "@roo-code/types" + import type { ApiHandlerOptions } from "../../../shared/api" -import type { ApiHandlerCreateMessageMetadata } from "../../index" -import type { ApiStreamTextChunk, ApiStreamReasoningChunk, ApiStreamToolCallPartialChunk } from "../../transform/stream" + +import { MistralHandler } from "../mistral" describe("MistralHandler", () => { let handler: 
MistralHandler @@ -61,15 +38,11 @@ describe("MistralHandler", () => { beforeEach(() => { mockOptions = { - apiModelId: "codestral-latest", // Update to match the actual model ID mistralApiKey: "test-api-key", - includeMaxTokens: true, - modelTemperature: 0, + apiModelId: "codestral-latest" as MistralModelId, } handler = new MistralHandler(mockOptions) - mockCreate.mockClear() - mockComplete.mockClear() - mockCaptureException.mockClear() + vi.clearAllMocks() }) describe("constructor", () => { @@ -78,32 +51,53 @@ describe("MistralHandler", () => { expect(handler.getModel().id).toBe(mockOptions.apiModelId) }) - it("should throw error if API key is missing", () => { - expect(() => { - new MistralHandler({ - ...mockOptions, - mistralApiKey: undefined, - }) - }).toThrow("Mistral API key is required") - }) - - it("should use custom base URL if provided", () => { - const customBaseUrl = "https://custom.mistral.ai/v1" - const handlerWithCustomUrl = new MistralHandler({ + it("should use default model ID if not provided", () => { + const handlerWithoutModel = new MistralHandler({ ...mockOptions, - mistralCodestralUrl: customBaseUrl, + apiModelId: undefined, }) - expect(handlerWithCustomUrl).toBeInstanceOf(MistralHandler) + expect(handlerWithoutModel.getModel().id).toBe(mistralDefaultModelId) }) }) describe("getModel", () => { - it("should return correct model info", () => { + it("should return model info for valid model ID", () => { const model = handler.getModel() expect(model.id).toBe(mockOptions.apiModelId) expect(model.info).toBeDefined() + expect(model.info.maxTokens).toBe(8192) + expect(model.info.contextWindow).toBe(256_000) + expect(model.info.supportsImages).toBe(false) expect(model.info.supportsPromptCache).toBe(false) }) + + it("should return provided model ID with default model info if model does not exist", () => { + const handlerWithInvalidModel = new MistralHandler({ + ...mockOptions, + apiModelId: "invalid-model", + }) + const model = handlerWithInvalidModel.getModel() + expect(model.id).toBe("invalid-model") // Returns provided ID + expect(model.info).toBeDefined() + // Should have the same base properties as default model + expect(model.info.contextWindow).toBe(mistralModels[mistralDefaultModelId].contextWindow) + }) + + it("should return default model if no model ID is provided", () => { + const handlerWithoutModel = new MistralHandler({ + ...mockOptions, + apiModelId: undefined, + }) + const model = handlerWithoutModel.getModel() + expect(model.id).toBe(mistralDefaultModelId) + expect(model.info).toBeDefined() + }) + + it("should include model parameters from getModelParams", () => { + const model = handler.getModel() + expect(model).toHaveProperty("temperature") + expect(model).toHaveProperty("maxTokens") + }) }) describe("createMessage", () => { @@ -111,389 +105,446 @@ describe("MistralHandler", () => { const messages: Anthropic.Messages.MessageParam[] = [ { role: "user", - content: [{ type: "text", text: "Hello!" 
}], + content: [ + { + type: "text" as const, + text: "Hello!", + }, + ], }, ] - it("should create message successfully", async () => { - const iterator = handler.createMessage(systemPrompt, messages) - const result = await iterator.next() + it("should handle streaming responses", async () => { + // Mock the fullStream async generator + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response" } + } + + // Mock usage promise + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 5, + }) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + const textChunks = chunks.filter((chunk) => chunk.type === "text") + expect(textChunks).toHaveLength(1) + expect(textChunks[0].text).toBe("Test response") + }) + + it("should include usage information", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response" } + } + + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 5, + }) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const usageChunks = chunks.filter((chunk) => chunk.type === "usage") + expect(usageChunks.length).toBeGreaterThan(0) + expect(usageChunks[0].inputTokens).toBe(10) + expect(usageChunks[0].outputTokens).toBe(5) + }) + + it("should handle reasoning content in streaming responses", async () => { + // Mock the fullStream async generator with reasoning content + async function* mockFullStream() { + yield { type: "reasoning", text: "Let me think about this..." } + yield { type: "reasoning", text: " I'll analyze step by step." 
} + yield { type: "text-delta", text: "Test response" } + } + + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 5, + details: { + reasoningTokens: 15, + }, + }) - expect(mockCreate).toHaveBeenCalledWith( + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + // Should have reasoning chunks + const reasoningChunks = chunks.filter((chunk) => chunk.type === "reasoning") + expect(reasoningChunks.length).toBe(2) + expect(reasoningChunks[0].text).toBe("Let me think about this...") + expect(reasoningChunks[1].text).toBe(" I'll analyze step by step.") + + // Should also have text chunks + const textChunks = chunks.filter((chunk) => chunk.type === "text") + expect(textChunks.length).toBe(1) + expect(textChunks[0].text).toBe("Test response") + }) + }) + + describe("completePrompt", () => { + it("should complete a prompt using generateText", async () => { + mockGenerateText.mockResolvedValue({ + text: "Test completion", + }) + + const result = await handler.completePrompt("Test prompt") + + expect(result).toBe("Test completion") + expect(mockGenerateText).toHaveBeenCalledWith( expect.objectContaining({ - model: mockOptions.apiModelId, - messages: expect.any(Array), - maxTokens: expect.any(Number), - temperature: 0, - // Tools are now always present (minimum 6 from ALWAYS_AVAILABLE_TOOLS) - tools: expect.any(Array), - toolChoice: "any", + prompt: "Test prompt", }), ) - - expect(result.value).toBeDefined() - expect(result.done).toBe(false) }) + }) - it("should handle streaming response correctly", async () => { - const iterator = handler.createMessage(systemPrompt, messages) - const results: ApiStreamTextChunk[] = [] - - for await (const chunk of iterator) { - if ("text" in chunk) { - results.push(chunk as ApiStreamTextChunk) + describe("processUsageMetrics", () => { + it("should correctly process usage metrics", () => { + // We need to access the protected method, so we'll create a test subclass + class TestMistralHandler extends MistralHandler { + public testProcessUsageMetrics(usage: any) { + return this.processUsageMetrics(usage) } } - expect(results.length).toBeGreaterThan(0) - expect(results[0].text).toBe("Test response") - }) + const testHandler = new TestMistralHandler(mockOptions) + + const usage = { + inputTokens: 100, + outputTokens: 50, + details: { + cachedInputTokens: 20, + reasoningTokens: 30, + }, + } + + const result = testHandler.testProcessUsageMetrics(usage) - it("should handle errors gracefully", async () => { - mockCreate.mockRejectedValueOnce(new Error("API Error")) - await expect(handler.createMessage(systemPrompt, messages).next()).rejects.toThrow("API Error") + expect(result.type).toBe("usage") + expect(result.inputTokens).toBe(100) + expect(result.outputTokens).toBe(50) + expect(result.cacheReadTokens).toBe(20) + expect(result.reasoningTokens).toBe(30) }) - it("should handle thinking content as reasoning chunks", async () => { - // Mock stream with thinking content matching new SDK structure - mockCreate.mockImplementationOnce(async (_options) => { - const stream = { - [Symbol.asyncIterator]: async function* () { - yield { - data: { - choices: [ - { - delta: { - content: [ - { - type: "thinking", - thinking: [{ type: "text", text: "Let me think about this..." 
}], - }, - { type: "text", text: "Here's the answer" }, - ], - }, - index: 0, - }, - ], - }, - } - }, + it("should handle missing cache metrics gracefully", () => { + class TestMistralHandler extends MistralHandler { + public testProcessUsageMetrics(usage: any) { + return this.processUsageMetrics(usage) } - return stream - }) + } + + const testHandler = new TestMistralHandler(mockOptions) - const iterator = handler.createMessage(systemPrompt, messages) - const results: (ApiStreamTextChunk | ApiStreamReasoningChunk)[] = [] + const usage = { + inputTokens: 100, + outputTokens: 50, + } + + const result = testHandler.testProcessUsageMetrics(usage) + + expect(result.type).toBe("usage") + expect(result.inputTokens).toBe(100) + expect(result.outputTokens).toBe(50) + expect(result.cacheReadTokens).toBeUndefined() + expect(result.reasoningTokens).toBeUndefined() + }) + }) - for await (const chunk of iterator) { - if ("text" in chunk) { - results.push(chunk as ApiStreamTextChunk | ApiStreamReasoningChunk) + describe("getMaxOutputTokens", () => { + it("should return maxTokens from model info", () => { + class TestMistralHandler extends MistralHandler { + public testGetMaxOutputTokens() { + return this.getMaxOutputTokens() } } - expect(results).toHaveLength(2) - expect(results[0]).toEqual({ type: "reasoning", text: "Let me think about this..." }) - expect(results[1]).toEqual({ type: "text", text: "Here's the answer" }) + const testHandler = new TestMistralHandler(mockOptions) + const result = testHandler.testGetMaxOutputTokens() + + // codestral-latest maxTokens is 8192 + expect(result).toBe(8192) }) - it("should handle mixed content arrays correctly", async () => { - // Mock stream with mixed content matching new SDK structure - mockCreate.mockImplementationOnce(async (_options) => { - const stream = { - [Symbol.asyncIterator]: async function* () { - yield { - data: { - choices: [ - { - delta: { - content: [ - { type: "text", text: "First text" }, - { - type: "thinking", - thinking: [{ type: "text", text: "Some reasoning" }], - }, - { type: "text", text: "Second text" }, - ], - }, - index: 0, - }, - ], - }, - } - }, + it("should use modelMaxTokens when provided", () => { + class TestMistralHandler extends MistralHandler { + public testGetMaxOutputTokens() { + return this.getMaxOutputTokens() } - return stream + } + + const customMaxTokens = 5000 + const testHandler = new TestMistralHandler({ + ...mockOptions, + modelMaxTokens: customMaxTokens, }) - const iterator = handler.createMessage(systemPrompt, messages) - const results: (ApiStreamTextChunk | ApiStreamReasoningChunk)[] = [] + const result = testHandler.testGetMaxOutputTokens() + expect(result).toBe(customMaxTokens) + }) - for await (const chunk of iterator) { - if ("text" in chunk) { - results.push(chunk as ApiStreamTextChunk | ApiStreamReasoningChunk) + it("should fall back to modelInfo.maxTokens when modelMaxTokens is not provided", () => { + class TestMistralHandler extends MistralHandler { + public testGetMaxOutputTokens() { + return this.getMaxOutputTokens() } } - expect(results).toHaveLength(3) - expect(results[0]).toEqual({ type: "text", text: "First text" }) - expect(results[1]).toEqual({ type: "reasoning", text: "Some reasoning" }) - expect(results[2]).toEqual({ type: "text", text: "Second text" }) + const testHandler = new TestMistralHandler(mockOptions) + const result = testHandler.testGetMaxOutputTokens() + + // codestral-latest has maxTokens of 8192 + expect(result).toBe(8192) }) }) - describe("native tool calling", () => { + 
describe("tool handling", () => { const systemPrompt = "You are a helpful assistant." const messages: Anthropic.Messages.MessageParam[] = [ { role: "user", - content: [{ type: "text", text: "What's the weather?" }], + content: [{ type: "text" as const, text: "Hello!" }], }, ] - const mockTools: OpenAI.Chat.ChatCompletionTool[] = [ - { - type: "function", - function: { - name: "get_weather", - description: "Get the current weather", - parameters: { - type: "object", - properties: { - location: { type: "string" }, + it("should handle tool calls in streaming", async () => { + async function* mockFullStream() { + yield { + type: "tool-input-start", + id: "tool-call-1", + toolName: "read_file", + } + yield { + type: "tool-input-delta", + id: "tool-call-1", + delta: '{"path":"test.ts"}', + } + yield { + type: "tool-input-end", + id: "tool-call-1", + } + } + + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 5, + }) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + }) + + const stream = handler.createMessage(systemPrompt, messages, { + taskId: "test-task", + tools: [ + { + type: "function", + function: { + name: "read_file", + description: "Read a file", + parameters: { + type: "object", + properties: { path: { type: "string" } }, + required: ["path"], + }, }, - required: ["location"], }, - }, - }, - ] + ], + }) - it("should include tools in request by default (native is default)", async () => { - const metadata: ApiHandlerCreateMessageMetadata = { - taskId: "test-task", - tools: mockTools, + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) } - const iterator = handler.createMessage(systemPrompt, messages, metadata) - await iterator.next() + const toolCallStartChunks = chunks.filter((c) => c.type === "tool_call_start") + const toolCallDeltaChunks = chunks.filter((c) => c.type === "tool_call_delta") + const toolCallEndChunks = chunks.filter((c) => c.type === "tool_call_end") - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ - tools: expect.arrayContaining([ - expect.objectContaining({ - type: "function", - function: expect.objectContaining({ - name: "get_weather", - description: "Get the current weather", - parameters: expect.any(Object), - }), - }), - ]), - toolChoice: "any", - }), - ) + expect(toolCallStartChunks.length).toBe(1) + expect(toolCallStartChunks[0].id).toBe("tool-call-1") + expect(toolCallStartChunks[0].name).toBe("read_file") + + expect(toolCallDeltaChunks.length).toBe(1) + expect(toolCallDeltaChunks[0].delta).toBe('{"path":"test.ts"}') + + expect(toolCallEndChunks.length).toBe(1) + expect(toolCallEndChunks[0].id).toBe("tool-call-1") }) - it("should always include tools in request (tools are always present after PR #10841)", async () => { - const metadata: ApiHandlerCreateMessageMetadata = { - taskId: "test-task", + it("should ignore tool-call events to prevent duplicate tools in UI", async () => { + // tool-call events are intentionally ignored because tool-input-start/delta/end + // already provide complete tool call information. Emitting tool-call would cause + // duplicate tools in the UI for AI SDK providers. 
+ async function* mockFullStream() { + yield { + type: "tool-call", + toolCallId: "tool-call-1", + toolName: "read_file", + input: { path: "test.ts" }, + } } - const iterator = handler.createMessage(systemPrompt, messages, metadata) - await iterator.next() + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 5, + }) - // Tools are now always present (minimum 6 from ALWAYS_AVAILABLE_TOOLS) - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ - tools: expect.any(Array), - toolChoice: "any", - }), - ) - }) + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + }) - it("should handle tool calls in streaming response", async () => { - // Mock stream with tool calls - mockCreate.mockImplementationOnce(async (_options) => { - const stream = { - [Symbol.asyncIterator]: async function* () { - yield { - data: { - choices: [ - { - delta: { - toolCalls: [ - { - id: "call_123", - type: "function", - function: { - name: "get_weather", - arguments: '{"location":"New York"}', - }, - }, - ], - }, - index: 0, - }, - ], + const stream = handler.createMessage(systemPrompt, messages, { + taskId: "test-task", + tools: [ + { + type: "function", + function: { + name: "read_file", + description: "Read a file", + parameters: { + type: "object", + properties: { path: { type: "string" } }, + required: ["path"], }, - } + }, }, - } - return stream + ], }) - const metadata: ApiHandlerCreateMessageMetadata = { - taskId: "test-task", - tools: mockTools, + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) } - const iterator = handler.createMessage(systemPrompt, messages, metadata) - const results: ApiStreamToolCallPartialChunk[] = [] + // tool-call events are ignored, so no tool_call chunks should be emitted + const toolCallChunks = chunks.filter((c) => c.type === "tool_call") + expect(toolCallChunks.length).toBe(0) + }) + }) - for await (const chunk of iterator) { - if (chunk.type === "tool_call_partial") { - results.push(chunk) + describe("mapToolChoice", () => { + it("should handle string tool choices", () => { + class TestMistralHandler extends MistralHandler { + public testMapToolChoice(toolChoice: any) { + return this.mapToolChoice(toolChoice) } } - expect(results).toHaveLength(1) - expect(results[0]).toEqual({ - type: "tool_call_partial", - index: 0, - id: "call_123", - name: "get_weather", - arguments: '{"location":"New York"}', - }) + const testHandler = new TestMistralHandler(mockOptions) + + expect(testHandler.testMapToolChoice("auto")).toBe("auto") + expect(testHandler.testMapToolChoice("none")).toBe("none") + expect(testHandler.testMapToolChoice("required")).toBe("required") + expect(testHandler.testMapToolChoice("any")).toBe("required") + expect(testHandler.testMapToolChoice("unknown")).toBe("auto") }) - it("should handle multiple tool calls in a single response", async () => { - // Mock stream with multiple tool calls - mockCreate.mockImplementationOnce(async (_options) => { - const stream = { - [Symbol.asyncIterator]: async function* () { - yield { - data: { - choices: [ - { - delta: { - toolCalls: [ - { - id: "call_1", - type: "function", - function: { - name: "get_weather", - arguments: '{"location":"NYC"}', - }, - }, - { - id: "call_2", - type: "function", - function: { - name: "get_weather", - arguments: '{"location":"LA"}', - }, - }, - ], - }, - index: 0, - }, - ], - }, - } - }, + it("should handle object tool choice with function name", () => { + class TestMistralHandler extends MistralHandler { + 
public testMapToolChoice(toolChoice: any) { + return this.mapToolChoice(toolChoice) } - return stream - }) - - const metadata: ApiHandlerCreateMessageMetadata = { - taskId: "test-task", - tools: mockTools, } - const iterator = handler.createMessage(systemPrompt, messages, metadata) - const results: ApiStreamToolCallPartialChunk[] = [] + const testHandler = new TestMistralHandler(mockOptions) - for await (const chunk of iterator) { - if (chunk.type === "tool_call_partial") { - results.push(chunk) + const result = testHandler.testMapToolChoice({ + type: "function", + function: { name: "my_tool" }, + }) + + expect(result).toEqual({ type: "tool", toolName: "my_tool" }) + }) + + it("should return undefined for null or undefined", () => { + class TestMistralHandler extends MistralHandler { + public testMapToolChoice(toolChoice: any) { + return this.mapToolChoice(toolChoice) } } - expect(results).toHaveLength(2) - expect(results[0]).toEqual({ - type: "tool_call_partial", - index: 0, - id: "call_1", - name: "get_weather", - arguments: '{"location":"NYC"}', - }) - expect(results[1]).toEqual({ - type: "tool_call_partial", - index: 1, - id: "call_2", - name: "get_weather", - arguments: '{"location":"LA"}', - }) + const testHandler = new TestMistralHandler(mockOptions) + + expect(testHandler.testMapToolChoice(null)).toBeUndefined() + expect(testHandler.testMapToolChoice(undefined)).toBeUndefined() }) + }) - it("should always set toolChoice to 'any' when tools are provided", async () => { - // Even if tool_choice is provided in metadata, we override it to "any" - const metadata: ApiHandlerCreateMessageMetadata = { - taskId: "test-task", - tools: mockTools, - tool_choice: "auto", // This should be ignored - } + describe("Codestral URL handling", () => { + beforeEach(() => { + mockCreateMistral.mockClear() + }) - const iterator = handler.createMessage(systemPrompt, messages, metadata) - await iterator.next() + it("should use default Codestral URL for codestral models", () => { + new MistralHandler({ + ...mockOptions, + apiModelId: "codestral-latest", + }) - expect(mockCreate).toHaveBeenCalledWith( + expect(mockCreateMistral).toHaveBeenCalledWith( expect.objectContaining({ - toolChoice: "any", + baseURL: "https://codestral.mistral.ai/v1", }), ) }) - }) - describe("completePrompt", () => { - it("should complete prompt successfully", async () => { - const prompt = "Test prompt" - const result = await handler.completePrompt(prompt) - - expect(mockComplete).toHaveBeenCalledWith({ - model: mockOptions.apiModelId, - messages: [{ role: "user", content: prompt }], - temperature: 0, + it("should use custom Codestral URL when provided", () => { + new MistralHandler({ + ...mockOptions, + apiModelId: "codestral-latest", + mistralCodestralUrl: "https://custom.codestral.url/v1", }) - expect(result).toBe("Test response") + expect(mockCreateMistral).toHaveBeenCalledWith( + expect.objectContaining({ + baseURL: "https://custom.codestral.url/v1", + }), + ) }) - it("should filter out thinking content in completePrompt", async () => { - mockComplete.mockImplementationOnce(async (_options) => { - return { - choices: [ - { - message: { - content: [ - { type: "thinking", text: "Let me think..." 
}, - { type: "text", text: "Answer part 1" }, - { type: "text", text: "Answer part 2" }, - ], - }, - }, - ], - } + it("should use default Mistral URL for non-codestral models", () => { + new MistralHandler({ + ...mockOptions, + apiModelId: "mistral-large-latest", }) - const prompt = "Test prompt" - const result = await handler.completePrompt(prompt) - - expect(result).toBe("Answer part 1Answer part 2") - }) - - it("should handle errors in completePrompt", async () => { - mockComplete.mockRejectedValueOnce(new Error("API Error")) - await expect(handler.completePrompt("Test prompt")).rejects.toThrow("Mistral completion error: API Error") + expect(mockCreateMistral).toHaveBeenCalledWith( + expect.objectContaining({ + baseURL: "https://api.mistral.ai/v1", + }), + ) }) }) }) diff --git a/src/api/providers/mistral.ts b/src/api/providers/mistral.ts index e0e19298f42..7ce2fc4586d 100644 --- a/src/api/providers/mistral.ts +++ b/src/api/providers/mistral.ts @@ -1,224 +1,201 @@ import { Anthropic } from "@anthropic-ai/sdk" -import { Mistral } from "@mistralai/mistralai" -import OpenAI from "openai" +import { createMistral } from "@ai-sdk/mistral" +import { streamText, generateText, ToolSet, LanguageModel } from "ai" import { - type MistralModelId, - mistralDefaultModelId, mistralModels, + mistralDefaultModelId, + type MistralModelId, + type ModelInfo, MISTRAL_DEFAULT_TEMPERATURE, - ApiProviderError, } from "@roo-code/types" -import { TelemetryService } from "@roo-code/telemetry" -import { ApiHandlerOptions } from "../../shared/api" - -import { convertToMistralMessages } from "../transform/mistral-format" -import { ApiStream } from "../transform/stream" -import { handleProviderError } from "./utils/error-handler" +import type { ApiHandlerOptions } from "../../shared/api" +import { + convertToAiSdkMessages, + convertToolsForAiSdk, + processAiSdkStreamPart, + handleAiSdkError, +} from "../transform/ai-sdk" +import { ApiStream, ApiStreamUsageChunk } from "../transform/stream" +import { getModelParams } from "../transform/model-params" + +import { DEFAULT_HEADERS } from "./constants" import { BaseProvider } from "./base-provider" import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" -// Type helper to handle thinking chunks from Mistral API -// The SDK includes ThinkChunk but TypeScript has trouble with the discriminated union -type ContentChunkWithThinking = { - type: string - text?: string - thinking?: Array<{ type: string; text?: string }> -} - -// Type for Mistral tool calls in stream delta -type MistralToolCall = { - id?: string - type?: string - function?: { - name?: string - arguments?: string - } -} - -// Type for Mistral tool definition - matches Mistral SDK Tool type -type MistralTool = { - type: "function" - function: { - name: string - description?: string - parameters: Record - } -} - +/** + * Mistral provider using the dedicated @ai-sdk/mistral package. + * Provides access to Mistral AI models including Codestral, Mistral Large, and more. + */ export class MistralHandler extends BaseProvider implements SingleCompletionHandler { protected options: ApiHandlerOptions - private client: Mistral - private readonly providerName = "Mistral" + protected provider: ReturnType constructor(options: ApiHandlerOptions) { super() + this.options = options - if (!options.mistralApiKey) { - throw new Error("Mistral API key is required") - } + const modelId = options.apiModelId ?? mistralDefaultModelId - // Set default model ID if not provided. 
- const apiModelId = options.apiModelId || mistralDefaultModelId - this.options = { ...options, apiModelId } + // Determine the base URL based on the model (Codestral uses a different endpoint) + const baseURL = modelId.startsWith("codestral-") + ? options.mistralCodestralUrl || "https://codestral.mistral.ai/v1" + : "https://api.mistral.ai/v1" - this.client = new Mistral({ - serverURL: apiModelId.startsWith("codestral-") - ? this.options.mistralCodestralUrl || "https://codestral.mistral.ai" - : "https://api.mistral.ai", - apiKey: this.options.mistralApiKey, + // Create the Mistral provider using AI SDK + this.provider = createMistral({ + apiKey: options.mistralApiKey ?? "not-provided", + baseURL, + headers: DEFAULT_HEADERS, }) } - override async *createMessage( - systemPrompt: string, - messages: Anthropic.Messages.MessageParam[], - metadata?: ApiHandlerCreateMessageMetadata, - ): ApiStream { - const { id: model, info, maxTokens, temperature } = this.getModel() - - // Build request options - const requestOptions: { - model: string - messages: ReturnType - maxTokens: number - temperature: number - tools?: MistralTool[] - toolChoice?: "auto" | "none" | "any" | "required" | { type: "function"; function: { name: string } } - } = { - model, - messages: [{ role: "system", content: systemPrompt }, ...convertToMistralMessages(messages)], - maxTokens: maxTokens ?? info.maxTokens, - temperature, - } - - requestOptions.tools = this.convertToolsForMistral(metadata?.tools ?? []) - // Always use "any" to require tool use - requestOptions.toolChoice = "any" + override getModel(): { id: string; info: ModelInfo; maxTokens?: number; temperature?: number } { + const id = (this.options.apiModelId ?? mistralDefaultModelId) as MistralModelId + const info = mistralModels[id as keyof typeof mistralModels] || mistralModels[mistralDefaultModelId] + const params = getModelParams({ format: "openai", modelId: id, model: info, settings: this.options }) + return { id, info, ...params } + } - // Temporary debug log for QA - // console.log("[MISTRAL DEBUG] Raw API request body:", requestOptions) + /** + * Get the language model for the configured model ID. + */ + protected getLanguageModel(): LanguageModel { + const { id } = this.getModel() + // Type assertion needed due to version mismatch between @ai-sdk/mistral and ai packages + return this.provider(id) as unknown as LanguageModel + } - let response - try { - response = await this.client.chat.stream(requestOptions) - } catch (error) { - const errorMessage = error instanceof Error ? error.message : String(error) - const apiError = new ApiProviderError(errorMessage, this.providerName, model, "createMessage") - TelemetryService.instance.captureException(apiError) - throw new Error(`Mistral completion error: ${errorMessage}`) + /** + * Process usage metrics from the AI SDK response. 
+ */ + protected processUsageMetrics(usage: { + inputTokens?: number + outputTokens?: number + details?: { + cachedInputTokens?: number + reasoningTokens?: number + } + }): ApiStreamUsageChunk { + return { + type: "usage", + inputTokens: usage.inputTokens || 0, + outputTokens: usage.outputTokens || 0, + cacheReadTokens: usage.details?.cachedInputTokens, + reasoningTokens: usage.details?.reasoningTokens, } + } - for await (const event of response) { - const delta = event.data.choices[0]?.delta - - if (delta?.content) { - if (typeof delta.content === "string") { - // Handle string content as text - yield { type: "text", text: delta.content } - } else if (Array.isArray(delta.content)) { - // Handle array of content chunks - // The SDK v1.9.18 supports ThinkChunk with type "thinking" - for (const chunk of delta.content as ContentChunkWithThinking[]) { - if (chunk.type === "thinking" && chunk.thinking) { - // Handle thinking content as reasoning chunks - // ThinkChunk has a 'thinking' property that contains an array of text/reference chunks - for (const thinkingPart of chunk.thinking) { - if (thinkingPart.type === "text" && thinkingPart.text) { - yield { type: "reasoning", text: thinkingPart.text } - } - } - } else if (chunk.type === "text" && chunk.text) { - // Handle text content normally - yield { type: "text", text: chunk.text } - } - } - } - } + /** + * Map OpenAI tool_choice to AI SDK toolChoice format. + */ + protected mapToolChoice( + toolChoice: any, + ): "auto" | "none" | "required" | { type: "tool"; toolName: string } | undefined { + if (!toolChoice) { + return undefined + } - // Handle tool calls in stream - // Mistral SDK provides tool_calls in delta similar to OpenAI format - const toolCalls = (delta as { toolCalls?: MistralToolCall[] })?.toolCalls - if (toolCalls) { - for (let i = 0; i < toolCalls.length; i++) { - const toolCall = toolCalls[i] - yield { - type: "tool_call_partial", - index: i, - id: toolCall.id, - name: toolCall.function?.name, - arguments: toolCall.function?.arguments, - } - } + // Handle string values + if (typeof toolChoice === "string") { + switch (toolChoice) { + case "auto": + return "auto" + case "none": + return "none" + case "required": + case "any": + return "required" + default: + return "auto" } + } - if (event.data.usage) { - yield { - type: "usage", - inputTokens: event.data.usage.promptTokens || 0, - outputTokens: event.data.usage.completionTokens || 0, - } + // Handle object values (OpenAI ChatCompletionNamedToolChoice format) + if (typeof toolChoice === "object" && "type" in toolChoice) { + if (toolChoice.type === "function" && "function" in toolChoice && toolChoice.function?.name) { + return { type: "tool", toolName: toolChoice.function.name } } } + + return undefined } /** - * Convert OpenAI tool definitions to Mistral format. - * Mistral uses the same format as OpenAI for function tools. + * Get the max tokens parameter to include in the request. 
*/ - private convertToolsForMistral(tools: OpenAI.Chat.ChatCompletionTool[]): MistralTool[] { - return tools - .filter((tool) => tool.type === "function") - .map((tool) => ({ - type: "function" as const, - function: { - name: tool.function.name, - description: tool.function.description, - // Mistral SDK requires parameters to be defined, use empty object as fallback - parameters: (tool.function.parameters as Record) || {}, - }, - })) + protected getMaxOutputTokens(): number | undefined { + const { info } = this.getModel() + return this.options.modelMaxTokens || info.maxTokens || undefined } - override getModel() { - const id = this.options.apiModelId ?? mistralDefaultModelId - const info = mistralModels[id as MistralModelId] ?? mistralModels[mistralDefaultModelId] - - // @TODO: Move this to the `getModelParams` function. - const maxTokens = this.options.includeMaxTokens ? info.maxTokens : undefined - const temperature = this.options.modelTemperature ?? MISTRAL_DEFAULT_TEMPERATURE - - return { id, info, maxTokens, temperature } - } + /** + * Create a message stream using the AI SDK. + */ + override async *createMessage( + systemPrompt: string, + messages: Anthropic.Messages.MessageParam[], + metadata?: ApiHandlerCreateMessageMetadata, + ): ApiStream { + const languageModel = this.getLanguageModel() + + // Convert messages to AI SDK format + const aiSdkMessages = convertToAiSdkMessages(messages) + + // Convert tools to OpenAI format first, then to AI SDK format + const openAiTools = this.convertToolsForOpenAI(metadata?.tools) + const aiSdkTools = convertToolsForAiSdk(openAiTools) as ToolSet | undefined + + // Build the request options + // Use MISTRAL_DEFAULT_TEMPERATURE (1) as fallback to match original behavior + const requestOptions: Parameters[0] = { + model: languageModel, + system: systemPrompt, + messages: aiSdkMessages, + temperature: this.options.modelTemperature ?? MISTRAL_DEFAULT_TEMPERATURE, + maxOutputTokens: this.getMaxOutputTokens(), + tools: aiSdkTools, + toolChoice: this.mapToolChoice(metadata?.tool_choice), + } - async completePrompt(prompt: string): Promise { - const { id: model, temperature } = this.getModel() + // Use streamText for streaming responses + const result = streamText(requestOptions) try { - const response = await this.client.chat.complete({ - model, - messages: [{ role: "user", content: prompt }], - temperature, - }) - - const content = response.choices?.[0]?.message.content - - if (Array.isArray(content)) { - // Only return text content, filter out thinking content for non-streaming - return (content as ContentChunkWithThinking[]) - .filter((c) => c.type === "text" && c.text) - .map((c) => c.text || "") - .join("") + // Process the full stream to get all events including reasoning + for await (const part of result.fullStream) { + for (const chunk of processAiSdkStreamPart(part)) { + yield chunk + } } - return content || "" + // Yield usage metrics at the end + const usage = await result.usage + if (usage) { + yield this.processUsageMetrics(usage) + } } catch (error) { - const errorMessage = error instanceof Error ? error.message : String(error) - const apiError = new ApiProviderError(errorMessage, this.providerName, model, "completePrompt") - TelemetryService.instance.captureException(apiError) - throw new Error(`Mistral completion error: ${errorMessage}`) + // Handle AI SDK errors (AI_RetryError, AI_APICallError, etc.) + throw handleAiSdkError(error, "Mistral") } } + + /** + * Complete a prompt using the AI SDK generateText. 
+ */ + async completePrompt(prompt: string): Promise { + const languageModel = this.getLanguageModel() + + // Use MISTRAL_DEFAULT_TEMPERATURE (1) as fallback to match original behavior + const { text } = await generateText({ + model: languageModel, + prompt, + maxOutputTokens: this.getMaxOutputTokens(), + temperature: this.options.modelTemperature ?? MISTRAL_DEFAULT_TEMPERATURE, + }) + + return text + } } diff --git a/src/core/tools/__tests__/useMcpToolTool.spec.ts b/src/core/tools/__tests__/useMcpToolTool.spec.ts index 27a991456ae..5ee826774f4 100644 --- a/src/core/tools/__tests__/useMcpToolTool.spec.ts +++ b/src/core/tools/__tests__/useMcpToolTool.spec.ts @@ -676,14 +676,12 @@ describe("useMcpToolTool", () => { mockProviderRef.deref.mockReturnValue({ getMcpHub: () => ({ callTool: vi.fn().mockResolvedValue(mockToolResult), - getAllServers: vi - .fn() - .mockReturnValue([ - { - name: "figma-server", - tools: [{ name: "get_screenshot", description: "Get screenshot" }], - }, - ]), + getAllServers: vi.fn().mockReturnValue([ + { + name: "figma-server", + tools: [{ name: "get_screenshot", description: "Get screenshot" }], + }, + ]), }), postMessageToWebview: vi.fn(), }) @@ -790,14 +788,12 @@ describe("useMcpToolTool", () => { mockProviderRef.deref.mockReturnValue({ getMcpHub: () => ({ callTool: vi.fn().mockResolvedValue(mockToolResult), - getAllServers: vi - .fn() - .mockReturnValue([ - { - name: "figma-server", - tools: [{ name: "get_screenshot", description: "Get screenshot" }], - }, - ]), + getAllServers: vi.fn().mockReturnValue([ + { + name: "figma-server", + tools: [{ name: "get_screenshot", description: "Get screenshot" }], + }, + ]), }), postMessageToWebview: vi.fn(), }) @@ -852,14 +848,12 @@ describe("useMcpToolTool", () => { mockProviderRef.deref.mockReturnValue({ getMcpHub: () => ({ callTool: vi.fn().mockResolvedValue(mockToolResult), - getAllServers: vi - .fn() - .mockReturnValue([ - { - name: "figma-server", - tools: [{ name: "get_screenshots", description: "Get screenshots" }], - }, - ]), + getAllServers: vi.fn().mockReturnValue([ + { + name: "figma-server", + tools: [{ name: "get_screenshots", description: "Get screenshots" }], + }, + ]), }), postMessageToWebview: vi.fn(), }) diff --git a/src/package.json b/src/package.json index 624b5b5b16e..98bd1d1b3ec 100644 --- a/src/package.json +++ b/src/package.json @@ -454,6 +454,7 @@ "@ai-sdk/deepseek": "^2.0.14", "@ai-sdk/fireworks": "^2.0.26", "@ai-sdk/groq": "^3.0.19", + "@ai-sdk/mistral": "^3.0.0", "@anthropic-ai/bedrock-sdk": "^0.10.2", "@anthropic-ai/sdk": "^0.37.0", "@anthropic-ai/vertex-sdk": "^0.7.0",
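For reference (outside the patch itself): a minimal, standalone sketch of the AI SDK streaming pattern the rewritten MistralHandler relies on — createMistral to build the provider, streamText to start a request, fullStream to consume interleaved events, and the usage promise for token counts. The model ID, base URL, prompt, and console output below are illustrative placeholders, and the exact stream-part shapes depend on the installed "ai" / "@ai-sdk/mistral" versions; this mirrors the event names used in the spec above but is not the handler's code.

// sketch.ts — illustrative only; values and event shapes are assumptions, not part of the PR
import { createMistral } from "@ai-sdk/mistral"
import { streamText } from "ai"

const mistral = createMistral({
	apiKey: process.env.MISTRAL_API_KEY ?? "not-provided",
	baseURL: "https://api.mistral.ai/v1", // Codestral models would use the codestral endpoint instead
})

async function demo() {
	const result = streamText({
		model: mistral("codestral-latest"),
		system: "You are a helpful assistant.",
		messages: [{ role: "user", content: "Hello!" }],
		temperature: 0,
		maxOutputTokens: 8192,
	})

	// fullStream interleaves text, reasoning, and tool events; the handler's
	// createMessage() translates these into ApiStream chunks via processAiSdkStreamPart.
	for await (const part of result.fullStream) {
		if (part.type === "text-delta") {
			process.stdout.write(part.text)
		}
	}

	// Usage resolves once the stream has finished, matching the final "usage" chunk
	// the handler yields from processUsageMetrics().
	const usage = await result.usage
	console.log("\ninputTokens:", usage.inputTokens, "outputTokens:", usage.outputTokens)
}

demo()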