From 29724d9b4de73afd55fee40b9bbf04455166197b Mon Sep 17 00:00:00 2001 From: Roo Code Date: Mon, 2 Feb 2026 15:18:23 +0000 Subject: [PATCH 1/2] feat: migrate SambaNova provider to AI SDK - Replace BaseOpenAiCompatibleProvider with BaseProvider + AI SDK - Use dedicated sambanova-ai-provider package - Implement createMessage() using streamText with streaming support - Implement completePrompt() using generateText - Add processUsageMetrics() for cache token handling from providerMetadata - Add getMaxOutputTokens() and getLanguageModel() helpers - Use shared AI SDK utilities (convertToAiSdkMessages, handleAiSdkError, etc.) - Set default temperature to 0.5 following Fireworks pattern - Update tests to match AI SDK streaming behavior with async generators - Add comprehensive tests for usage metrics, temperature handling, tool handling, and error handling --- pnpm-lock.yaml | 90 ++- src/api/providers/__tests__/sambanova.spec.ts | 690 +++++++++++++++--- src/api/providers/sambanova.ts | 176 ++++- src/package.json | 1 + 4 files changed, 833 insertions(+), 124 deletions(-) diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 41db99fb0ed..4cf3a3627e0 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -760,7 +760,7 @@ importers: version: 3.0.19(zod@3.25.76) '@ai-sdk/mistral': specifier: ^3.0.0 - version: 3.0.16(zod@3.25.76) + version: 3.0.18(zod@3.25.76) '@anthropic-ai/bedrock-sdk': specifier: ^0.10.2 version: 0.10.4 @@ -938,6 +938,9 @@ importers: safe-stable-stringify: specifier: ^2.5.0 version: 2.5.0 + sambanova-ai-provider: + specifier: ^1.2.2 + version: 1.2.2(zod@3.25.76) sanitize-filename: specifier: ^1.6.3 version: 1.6.3 @@ -1435,8 +1438,14 @@ packages: peerDependencies: zod: 3.25.76 - '@ai-sdk/mistral@3.0.16': - resolution: {integrity: sha512-8I/gxXJwghaDLbQQHMBwd61WxYz/PaFUFlG8I38daNYj5qRTMmQ5V10Idi6GJJC0wWEqQkal31lidm9+Y+u6TQ==} + '@ai-sdk/mistral@3.0.18': + resolution: {integrity: sha512-k8nCBBVGOzBigNwBO5kREzsP/e+C3npcL7jt19ZdicIbZ6rvmnSIRI90iENyS9T10vM7sjrXoCpgZSYgJB2pJQ==} + engines: {node: '>=18'} + peerDependencies: + zod: 3.25.76 + + '@ai-sdk/openai-compatible@1.0.11': + resolution: {integrity: sha512-eRD6dZviy31KYz4YvxAR/c6UEYx3p4pCiWZeDdYdAHj0rn8xZlGVxtQRs1qynhz6IYGOo4aLBf9zVW5w0tI/Uw==} engines: {node: '>=18'} peerDependencies: zod: 3.25.76 @@ -1459,6 +1468,12 @@ packages: peerDependencies: zod: 3.25.76 + '@ai-sdk/provider-utils@3.0.5': + resolution: {integrity: sha512-HliwB/yzufw3iwczbFVE2Fiwf1XqROB/I6ng8EKUsPM5+2wnIa8f4VbljZcDx+grhFrPV+PnRZH7zBqi8WZM7Q==} + engines: {node: '>=18'} + peerDependencies: + zod: 3.25.76 + '@ai-sdk/provider-utils@4.0.10': resolution: {integrity: sha512-VeDAiCH+ZK8Xs4hb9Cw7pHlujWNL52RKe8TExOkrw6Ir1AmfajBZTb9XUdKOZO08RwQElIKA8+Ltm+Gqfo8djQ==} engines: {node: '>=18'} @@ -1471,6 +1486,16 @@ packages: peerDependencies: zod: 3.25.76 + '@ai-sdk/provider-utils@4.0.13': + resolution: {integrity: sha512-HHG72BN4d+OWTcq2NwTxOm/2qvk1duYsnhCDtsbYwn/h/4zeqURu1S0+Cn0nY2Ysq9a9HGKvrYuMn9bgFhR2Og==} + engines: {node: '>=18'} + peerDependencies: + zod: 3.25.76 + + '@ai-sdk/provider@2.0.0': + resolution: {integrity: sha512-6o7Y2SeO9vFKB8lArHXehNuusnpddKPk7xqL7T2/b+OvXMRIXUO1rR4wcv1hAFUAT9avGZshty3Wlua/XA7TvA==} + engines: {node: '>=18'} + '@ai-sdk/provider@2.0.1': resolution: {integrity: sha512-KCUwswvsC5VsW2PWFqF8eJgSCu5Ysj7m1TxiHTVA6g7k360bk0RNQENT8KTMAYEs+8fWPD3Uu4dEmzGHc+jGng==} engines: {node: '>=18'} @@ -1483,6 +1508,10 @@ packages: resolution: {integrity: sha512-hSfoJtLtpMd7YxKM+iTqlJ0ZB+kJ83WESMiWuWrNVey3X8gg97x0OdAAaeAeclZByCX3UdPOTqhvJdK8qYA3ww==} engines: {node: '>=18'} + '@ai-sdk/provider@3.0.7': + resolution: {integrity: sha512-VkPLrutM6VdA924/mG8OS+5frbVTcu6e046D2bgDo00tehBANR1QBJ/mPcZ9tXMFOsVcm6SQArOregxePzTFPw==} + engines: {node: '>=18'} + '@alcalzone/ansi-tokenize@0.2.3': resolution: {integrity: sha512-jsElTJ0sQ4wHRz+C45tfect76BwbTbgkgKByOzpCN9xG61N5V6u/glvg1CsNJhq2xJIFpKHSwG3D2wPPuEYOrQ==} engines: {node: '>=18'} @@ -6066,6 +6095,10 @@ packages: resolution: {integrity: sha512-7GO6HghkA5fYG9TYnNxi14/7K9f5occMlp3zXAuSxn7CKCxt9xbNWG7yF8hTCSUchlfWSe3uLmlPfigevRItzQ==} engines: {node: '>=12'} + dotenv@16.4.5: + resolution: {integrity: sha512-ZmdL2rui+eB2YwhsWzjInR8LldtZHGDoQ1ugH85ppHKwpUHL7j7rN0Ti9NCnGiQbhaZ11FpR+7ao1dNsmduNUg==} + engines: {node: '>=12'} + dotenv@16.5.0: resolution: {integrity: sha512-m/C+AwOAr9/W1UOIZUo232ejMNnJAJtYQjUbHoNTBNTJSvqzzDh7vnrei3o3r3m9blf6ZoDkvcw0VmozNRFJxg==} engines: {node: '>=12'} @@ -9520,6 +9553,9 @@ packages: safer-buffer@2.1.2: resolution: {integrity: sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg==} + sambanova-ai-provider@1.2.2: + resolution: {integrity: sha512-MU/D+9GCg6me0guDRPw/x0N8cnpkOkv03FR7QXdrcinX0hprS7bsZXXTYEz81Svc+oVwXDZwh0v+Sd5pUxV3mg==} + sanitize-filename@1.6.3: resolution: {integrity: sha512-y/52Mcy7aw3gRm7IrcGDFx/bCk4AhRh2eI9luHOQM86nZsqwiRkkq2GekHXBBD+SmPidc8i2PqtYZl+pWJ8Oeg==} @@ -11086,10 +11122,16 @@ snapshots: '@ai-sdk/provider-utils': 4.0.11(zod@3.25.76) zod: 3.25.76 - '@ai-sdk/mistral@3.0.16(zod@3.25.76)': + '@ai-sdk/mistral@3.0.18(zod@3.25.76)': dependencies: - '@ai-sdk/provider': 3.0.6 - '@ai-sdk/provider-utils': 4.0.11(zod@3.25.76) + '@ai-sdk/provider': 3.0.7 + '@ai-sdk/provider-utils': 4.0.13(zod@3.25.76) + zod: 3.25.76 + + '@ai-sdk/openai-compatible@1.0.11(zod@3.25.76)': + dependencies: + '@ai-sdk/provider': 2.0.0 + '@ai-sdk/provider-utils': 3.0.5(zod@3.25.76) zod: 3.25.76 '@ai-sdk/openai-compatible@1.0.31(zod@3.25.76)': @@ -11111,6 +11153,14 @@ snapshots: eventsource-parser: 3.0.6 zod: 3.25.76 + '@ai-sdk/provider-utils@3.0.5(zod@3.25.76)': + dependencies: + '@ai-sdk/provider': 2.0.0 + '@standard-schema/spec': 1.1.0 + eventsource-parser: 3.0.6 + zod: 3.25.76 + zod-to-json-schema: 3.24.5(zod@3.25.76) + '@ai-sdk/provider-utils@4.0.10(zod@3.25.76)': dependencies: '@ai-sdk/provider': 3.0.5 @@ -11125,6 +11175,17 @@ snapshots: eventsource-parser: 3.0.6 zod: 3.25.76 + '@ai-sdk/provider-utils@4.0.13(zod@3.25.76)': + dependencies: + '@ai-sdk/provider': 3.0.7 + '@standard-schema/spec': 1.1.0 + eventsource-parser: 3.0.6 + zod: 3.25.76 + + '@ai-sdk/provider@2.0.0': + dependencies: + json-schema: 0.4.0 + '@ai-sdk/provider@2.0.1': dependencies: json-schema: 0.4.0 @@ -11137,6 +11198,10 @@ snapshots: dependencies: json-schema: 0.4.0 + '@ai-sdk/provider@3.0.7': + dependencies: + json-schema: 0.4.0 + '@alcalzone/ansi-tokenize@0.2.3': dependencies: ansi-styles: 6.2.3 @@ -15113,7 +15178,7 @@ snapshots: sirv: 3.0.1 tinyglobby: 0.2.14 tinyrainbow: 2.0.0 - vitest: 3.2.4(@types/debug@4.1.12)(@types/node@20.17.50)(@vitest/ui@3.2.4)(jiti@2.4.2)(jsdom@26.1.0)(lightningcss@1.30.1)(tsx@4.19.4)(yaml@2.8.0) + vitest: 3.2.4(@types/debug@4.1.12)(@types/node@24.2.1)(@vitest/ui@3.2.4)(jiti@2.4.2)(jsdom@26.1.0)(lightningcss@1.30.1)(tsx@4.19.4)(yaml@2.8.0) '@vitest/utils@3.2.4': dependencies: @@ -16469,6 +16534,8 @@ snapshots: dotenv@16.0.3: {} + dotenv@16.4.5: {} + dotenv@16.5.0: {} drizzle-kit@0.31.4: @@ -20622,6 +20689,15 @@ snapshots: safer-buffer@2.1.2: {} + sambanova-ai-provider@1.2.2(zod@3.25.76): + dependencies: + '@ai-sdk/openai-compatible': 1.0.11(zod@3.25.76) + '@ai-sdk/provider': 2.0.0 + '@ai-sdk/provider-utils': 3.0.5(zod@3.25.76) + dotenv: 16.4.5 + transitivePeerDependencies: + - zod + sanitize-filename@1.6.3: dependencies: truncate-utf8-bytes: 1.0.2 diff --git a/src/api/providers/__tests__/sambanova.spec.ts b/src/api/providers/__tests__/sambanova.spec.ts index 685cedf34c2..e11b6c20760 100644 --- a/src/api/providers/__tests__/sambanova.spec.ts +++ b/src/api/providers/__tests__/sambanova.spec.ts @@ -1,152 +1,628 @@ // npx vitest run src/api/providers/__tests__/sambanova.spec.ts -import OpenAI from "openai" -import { Anthropic } from "@anthropic-ai/sdk" +// Use vi.hoisted to define mock functions that can be referenced in hoisted vi.mock() calls +const { mockStreamText, mockGenerateText } = vi.hoisted(() => ({ + mockStreamText: vi.fn(), + mockGenerateText: vi.fn(), +})) -import { type SambaNovaModelId, sambaNovaDefaultModelId, sambaNovaModels } from "@roo-code/types" - -import { SambaNovaHandler } from "../sambanova" - -vitest.mock("openai", () => { - const createMock = vitest.fn() +vi.mock("ai", async (importOriginal) => { + const actual = await importOriginal() return { - default: vitest.fn(() => ({ chat: { completions: { create: createMock } } })), + ...actual, + streamText: mockStreamText, + generateText: mockGenerateText, } }) +vi.mock("sambanova-ai-provider", () => ({ + createSambaNova: vi.fn(() => { + // Return a function that returns a mock language model + return vi.fn(() => ({ + modelId: "Meta-Llama-3.3-70B-Instruct", + provider: "sambanova", + })) + }), +})) + +import type { Anthropic } from "@anthropic-ai/sdk" + +import { sambaNovaDefaultModelId, sambaNovaModels, type SambaNovaModelId } from "@roo-code/types" + +import type { ApiHandlerOptions } from "../../../shared/api" + +import { SambaNovaHandler } from "../sambanova" + describe("SambaNovaHandler", () => { let handler: SambaNovaHandler - let mockCreate: any + let mockOptions: ApiHandlerOptions beforeEach(() => { - vitest.clearAllMocks() - mockCreate = (OpenAI as unknown as any)().chat.completions.create - handler = new SambaNovaHandler({ sambaNovaApiKey: "test-sambanova-api-key" }) + mockOptions = { + sambaNovaApiKey: "test-sambanova-api-key", + apiModelId: "Meta-Llama-3.3-70B-Instruct", + } + handler = new SambaNovaHandler(mockOptions) + vi.clearAllMocks() }) - it("should use the correct SambaNova base URL", () => { - new SambaNovaHandler({ sambaNovaApiKey: "test-sambanova-api-key" }) - expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ baseURL: "https://api.sambanova.ai/v1" })) - }) + describe("constructor", () => { + it("should initialize with provided options", () => { + expect(handler).toBeInstanceOf(SambaNovaHandler) + expect(handler.getModel().id).toBe(mockOptions.apiModelId) + }) - it("should use the provided API key", () => { - const sambaNovaApiKey = "test-sambanova-api-key" - new SambaNovaHandler({ sambaNovaApiKey }) - expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ apiKey: sambaNovaApiKey })) + it("should use default model ID if not provided", () => { + const handlerWithoutModel = new SambaNovaHandler({ + ...mockOptions, + apiModelId: undefined, + }) + expect(handlerWithoutModel.getModel().id).toBe(sambaNovaDefaultModelId) + }) }) - it("should return default model when no model is specified", () => { - const model = handler.getModel() - expect(model.id).toBe(sambaNovaDefaultModelId) - expect(model.info).toEqual(sambaNovaModels[sambaNovaDefaultModelId]) - }) + describe("getModel", () => { + it("should return default model when no model is specified", () => { + const handlerWithoutModel = new SambaNovaHandler({ + sambaNovaApiKey: "test-sambanova-api-key", + }) + const model = handlerWithoutModel.getModel() + expect(model.id).toBe(sambaNovaDefaultModelId) + expect(model.info).toEqual(sambaNovaModels[sambaNovaDefaultModelId]) + }) - it("should return specified model when valid model is provided", () => { - const testModelId: SambaNovaModelId = "Meta-Llama-3.3-70B-Instruct" - const handlerWithModel = new SambaNovaHandler({ - apiModelId: testModelId, - sambaNovaApiKey: "test-sambanova-api-key", + it("should return specified model when valid model is provided", () => { + const testModelId: SambaNovaModelId = "Meta-Llama-3.3-70B-Instruct" + const handlerWithModel = new SambaNovaHandler({ + apiModelId: testModelId, + sambaNovaApiKey: "test-sambanova-api-key", + }) + const model = handlerWithModel.getModel() + expect(model.id).toBe(testModelId) + expect(model.info).toEqual(sambaNovaModels[testModelId]) }) - const model = handlerWithModel.getModel() - expect(model.id).toBe(testModelId) - expect(model.info).toEqual(sambaNovaModels[testModelId]) - }) - it("completePrompt method should return text from SambaNova API", async () => { - const expectedResponse = "This is a test response from SambaNova" - mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: expectedResponse } }] }) - const result = await handler.completePrompt("test prompt") - expect(result).toBe(expectedResponse) - }) + it("should return Meta-Llama-3.1-8B-Instruct model with correct configuration", () => { + const testModelId: SambaNovaModelId = "Meta-Llama-3.1-8B-Instruct" + const handlerWithModel = new SambaNovaHandler({ + apiModelId: testModelId, + sambaNovaApiKey: "test-sambanova-api-key", + }) + const model = handlerWithModel.getModel() + expect(model.id).toBe(testModelId) + expect(model.info).toBeDefined() + expect(model.info.maxTokens).toBeDefined() + expect(model.info.contextWindow).toBeDefined() + }) - it("should handle errors in completePrompt", async () => { - const errorMessage = "SambaNova API error" - mockCreate.mockRejectedValueOnce(new Error(errorMessage)) - await expect(handler.completePrompt("test prompt")).rejects.toThrow( - `SambaNova completion error: ${errorMessage}`, - ) + it("should return provided model ID with default model info if model does not exist", () => { + const handlerWithInvalidModel = new SambaNovaHandler({ + ...mockOptions, + apiModelId: "invalid-model", + }) + const model = handlerWithInvalidModel.getModel() + expect(model.id).toBe("invalid-model") + expect(model.info).toBeDefined() + // Should use default model info + expect(model.info).toBe(sambaNovaModels[sambaNovaDefaultModelId]) + }) + + it("should include model parameters from getModelParams", () => { + const model = handler.getModel() + expect(model).toHaveProperty("temperature") + expect(model).toHaveProperty("maxTokens") + }) }) - it("createMessage should yield text content from stream", async () => { - const testContent = "This is test content from SambaNova stream" - - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - next: vitest - .fn() - .mockResolvedValueOnce({ - done: false, - value: { choices: [{ delta: { content: testContent } }] }, - }) - .mockResolvedValueOnce({ done: true }), + describe("createMessage", () => { + const systemPrompt = "You are a helpful assistant." + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [ + { + type: "text" as const, + text: "Hello!", + }, + ], + }, + ] + + it("should handle streaming responses", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response from SambaNova" } + } + + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 5, + }) + + const mockProviderMetadata = Promise.resolve({}) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + const textChunks = chunks.filter((chunk) => chunk.type === "text") + expect(textChunks).toHaveLength(1) + expect(textChunks[0].text).toBe("Test response from SambaNova") + }) + + it("should include usage information", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response" } + } + + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 20, + }) + + const mockProviderMetadata = Promise.resolve({}) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const usageChunks = chunks.filter((chunk) => chunk.type === "usage") + expect(usageChunks.length).toBeGreaterThan(0) + expect(usageChunks[0].inputTokens).toBe(10) + expect(usageChunks[0].outputTokens).toBe(20) + }) + + it("should handle cached tokens in usage data from providerMetadata", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response" } + } + + const mockUsage = Promise.resolve({ + inputTokens: 100, + outputTokens: 50, + }) + + // SambaNova provides cache metrics via providerMetadata for supported models + const mockProviderMetadata = Promise.resolve({ + sambanova: { + promptCacheHitTokens: 30, + promptCacheMissTokens: 70, + }, + }) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const usageChunks = chunks.filter((chunk) => chunk.type === "usage") + expect(usageChunks.length).toBeGreaterThan(0) + expect(usageChunks[0].inputTokens).toBe(100) + expect(usageChunks[0].outputTokens).toBe(50) + expect(usageChunks[0].cacheReadTokens).toBe(30) + expect(usageChunks[0].cacheWriteTokens).toBe(70) + }) + + it("should handle usage with details.cachedInputTokens when providerMetadata is not available", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response" } + } + + const mockUsage = Promise.resolve({ + inputTokens: 100, + outputTokens: 50, + details: { + cachedInputTokens: 25, + }, + }) + + const mockProviderMetadata = Promise.resolve({}) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const usageChunks = chunks.filter((chunk) => chunk.type === "usage") + expect(usageChunks.length).toBeGreaterThan(0) + expect(usageChunks[0].cacheReadTokens).toBe(25) + expect(usageChunks[0].cacheWriteTokens).toBeUndefined() + }) + + it("should pass correct temperature (0.5 default) to streamText", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test" } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }), + providerMetadata: Promise.resolve({}), + }) + + const handlerWithDefaultTemp = new SambaNovaHandler({ + sambaNovaApiKey: "test-key", + apiModelId: "Meta-Llama-3.3-70B-Instruct", + }) + + const stream = handlerWithDefaultTemp.createMessage(systemPrompt, messages) + for await (const _ of stream) { + // consume stream + } + + expect(mockStreamText).toHaveBeenCalledWith( + expect.objectContaining({ + temperature: 0.5, }), + ) + }) + + it("should use user-specified temperature over model and provider defaults", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test" } } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }), + providerMetadata: Promise.resolve({}), + }) + + const handlerWithCustomTemp = new SambaNovaHandler({ + sambaNovaApiKey: "test-key", + apiModelId: "Meta-Llama-3.3-70B-Instruct", + modelTemperature: 0.7, + }) + + const stream = handlerWithCustomTemp.createMessage(systemPrompt, messages) + for await (const _ of stream) { + // consume stream + } + + // User-specified temperature should take precedence over everything + expect(mockStreamText).toHaveBeenCalledWith( + expect.objectContaining({ + temperature: 0.7, + }), + ) }) - const stream = handler.createMessage("system prompt", []) - const firstChunk = await stream.next() + it("should handle stream with multiple chunks", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Hello" } + yield { type: "text-delta", text: " world" } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 5, outputTokens: 10 }), + providerMetadata: Promise.resolve({}), + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const textChunks = chunks.filter((c) => c.type === "text") + expect(textChunks[0]).toEqual({ type: "text", text: "Hello" }) + expect(textChunks[1]).toEqual({ type: "text", text: " world" }) - expect(firstChunk.done).toBe(false) - expect(firstChunk.value).toEqual({ type: "text", text: testContent }) + const usageChunks = chunks.filter((c) => c.type === "usage") + expect(usageChunks[0]).toMatchObject({ type: "usage", inputTokens: 5, outputTokens: 10 }) + }) }) - it("createMessage should yield usage data from stream", async () => { - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - next: vitest - .fn() - .mockResolvedValueOnce({ - done: false, - value: { choices: [{ delta: {} }], usage: { prompt_tokens: 10, completion_tokens: 20 } }, - }) - .mockResolvedValueOnce({ done: true }), + describe("completePrompt", () => { + it("should complete a prompt using generateText", async () => { + mockGenerateText.mockResolvedValue({ + text: "Test completion from SambaNova", + }) + + const result = await handler.completePrompt("Test prompt") + + expect(result).toBe("Test completion from SambaNova") + expect(mockGenerateText).toHaveBeenCalledWith( + expect.objectContaining({ + prompt: "Test prompt", }), - } + ) }) - const stream = handler.createMessage("system prompt", []) - const firstChunk = await stream.next() + it("should use default temperature in completePrompt", async () => { + mockGenerateText.mockResolvedValue({ + text: "Test completion", + }) - expect(firstChunk.done).toBe(false) - expect(firstChunk.value).toMatchObject({ type: "usage", inputTokens: 10, outputTokens: 20 }) + await handler.completePrompt("Test prompt") + + expect(mockGenerateText).toHaveBeenCalledWith( + expect.objectContaining({ + temperature: 0.5, + }), + ) + }) }) - it("createMessage should pass correct parameters to SambaNova client", async () => { - const modelId: SambaNovaModelId = "Meta-Llama-3.3-70B-Instruct" - const modelInfo = sambaNovaModels[modelId] - const handlerWithModel = new SambaNovaHandler({ - apiModelId: modelId, - sambaNovaApiKey: "test-sambanova-api-key", + describe("processUsageMetrics", () => { + it("should correctly process usage metrics including cache information from providerMetadata", () => { + class TestSambaNovaHandler extends SambaNovaHandler { + public testProcessUsageMetrics(usage: any, providerMetadata?: any) { + return this.processUsageMetrics(usage, providerMetadata) + } + } + + const testHandler = new TestSambaNovaHandler(mockOptions) + + const usage = { + inputTokens: 100, + outputTokens: 50, + } + + const providerMetadata = { + sambanova: { + promptCacheHitTokens: 20, + promptCacheMissTokens: 80, + }, + } + + const result = testHandler.testProcessUsageMetrics(usage, providerMetadata) + + expect(result.type).toBe("usage") + expect(result.inputTokens).toBe(100) + expect(result.outputTokens).toBe(50) + expect(result.cacheWriteTokens).toBe(80) + expect(result.cacheReadTokens).toBe(20) + }) + + it("should handle missing cache metrics gracefully", () => { + class TestSambaNovaHandler extends SambaNovaHandler { + public testProcessUsageMetrics(usage: any, providerMetadata?: any) { + return this.processUsageMetrics(usage, providerMetadata) + } + } + + const testHandler = new TestSambaNovaHandler(mockOptions) + + const usage = { + inputTokens: 100, + outputTokens: 50, + } + + const result = testHandler.testProcessUsageMetrics(usage) + + expect(result.type).toBe("usage") + expect(result.inputTokens).toBe(100) + expect(result.outputTokens).toBe(50) + expect(result.cacheWriteTokens).toBeUndefined() + expect(result.cacheReadTokens).toBeUndefined() }) - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - async next() { - return { done: true } + it("should include reasoning tokens when provided", () => { + class TestSambaNovaHandler extends SambaNovaHandler { + public testProcessUsageMetrics(usage: any, providerMetadata?: any) { + return this.processUsageMetrics(usage, providerMetadata) + } + } + + const testHandler = new TestSambaNovaHandler(mockOptions) + + const usage = { + inputTokens: 100, + outputTokens: 50, + details: { + reasoningTokens: 30, + }, + } + + const result = testHandler.testProcessUsageMetrics(usage) + + expect(result.reasoningTokens).toBe(30) + }) + }) + + describe("tool handling", () => { + const systemPrompt = "You are a helpful assistant." + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [{ type: "text" as const, text: "Hello!" }], + }, + ] + + it("should handle tool calls in streaming", async () => { + async function* mockFullStream() { + yield { + type: "tool-input-start", + id: "tool-call-1", + toolName: "read_file", + } + yield { + type: "tool-input-delta", + id: "tool-call-1", + delta: '{"path":"test.ts"}', + } + yield { + type: "tool-input-end", + id: "tool-call-1", + } + } + + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 5, + }) + + const mockProviderMetadata = Promise.resolve({}) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, + }) + + const stream = handler.createMessage(systemPrompt, messages, { + taskId: "test-task", + tools: [ + { + type: "function", + function: { + name: "read_file", + description: "Read a file", + parameters: { + type: "object", + properties: { path: { type: "string" } }, + required: ["path"], + }, + }, }, - }), + ], + }) + + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const toolCallStartChunks = chunks.filter((c) => c.type === "tool_call_start") + const toolCallDeltaChunks = chunks.filter((c) => c.type === "tool_call_delta") + const toolCallEndChunks = chunks.filter((c) => c.type === "tool_call_end") + + expect(toolCallStartChunks.length).toBe(1) + expect(toolCallStartChunks[0].id).toBe("tool-call-1") + expect(toolCallStartChunks[0].name).toBe("read_file") + + expect(toolCallDeltaChunks.length).toBe(1) + expect(toolCallDeltaChunks[0].delta).toBe('{"path":"test.ts"}') + + expect(toolCallEndChunks.length).toBe(1) + expect(toolCallEndChunks[0].id).toBe("tool-call-1") + }) + + it("should ignore tool-call events to prevent duplicate tools in UI", async () => { + async function* mockFullStream() { + yield { + type: "tool-call", + toolCallId: "tool-call-1", + toolName: "read_file", + input: { path: "test.ts" }, + } + } + + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 5, + }) + + const mockProviderMetadata = Promise.resolve({}) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + // tool-call events should be ignored (only tool-input-start/delta/end are processed) + const toolCallChunks = chunks.filter( + (c) => c.type === "tool_call_start" || c.type === "tool_call_delta" || c.type === "tool_call_end", + ) + expect(toolCallChunks.length).toBe(0) + }) + }) + + describe("error handling", () => { + const systemPrompt = "You are a helpful assistant." + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [{ type: "text" as const, text: "Hello!" }], + }, + ] + + it("should handle AI SDK errors with handleAiSdkError", async () => { + // eslint-disable-next-line require-yield + async function* mockFullStream(): AsyncGenerator { + throw new Error("API Error") } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }), + providerMetadata: Promise.resolve({}), + }) + + const stream = handler.createMessage(systemPrompt, messages) + + await expect(async () => { + for await (const _ of stream) { + // consume stream + } + }).rejects.toThrow("SambaNova: API Error") }) - const systemPrompt = "Test system prompt for SambaNova" - const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message for SambaNova" }] + it("should preserve status codes in error handling", async () => { + const apiError = new Error("Rate limit exceeded") + ;(apiError as any).status = 429 - const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages) - await messageGenerator.next() + // eslint-disable-next-line require-yield + async function* mockFullStream(): AsyncGenerator { + throw apiError + } - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ - model: modelId, - max_tokens: modelInfo.maxTokens, - temperature: 0.7, - messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]), - stream: true, - stream_options: { include_usage: true }, - }), - undefined, - ) + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }), + providerMetadata: Promise.resolve({}), + }) + + const stream = handler.createMessage(systemPrompt, messages) + + try { + for await (const _ of stream) { + // consume stream + } + expect.fail("Should have thrown an error") + } catch (error: any) { + expect(error.message).toContain("SambaNova") + expect(error.status).toBe(429) + } + }) }) }) diff --git a/src/api/providers/sambanova.ts b/src/api/providers/sambanova.ts index a15bc125776..1cf57eb0245 100644 --- a/src/api/providers/sambanova.ts +++ b/src/api/providers/sambanova.ts @@ -1,19 +1,175 @@ -import { type SambaNovaModelId, sambaNovaDefaultModelId, sambaNovaModels } from "@roo-code/types" +import { Anthropic } from "@anthropic-ai/sdk" +import { createSambaNova } from "sambanova-ai-provider" +import { streamText, generateText, ToolSet } from "ai" + +import { sambaNovaModels, sambaNovaDefaultModelId, type ModelInfo } from "@roo-code/types" import type { ApiHandlerOptions } from "../../shared/api" -import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider" +import { + convertToAiSdkMessages, + convertToolsForAiSdk, + processAiSdkStreamPart, + mapToolChoice, + handleAiSdkError, +} from "../transform/ai-sdk" +import { ApiStream, ApiStreamUsageChunk } from "../transform/stream" +import { getModelParams } from "../transform/model-params" + +import { DEFAULT_HEADERS } from "./constants" +import { BaseProvider } from "./base-provider" +import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" + +const SAMBANOVA_DEFAULT_TEMPERATURE = 0.5 + +/** + * SambaNova provider using the dedicated sambanova-ai-provider package. + * Provides native support for various models including Llama models. + */ +export class SambaNovaHandler extends BaseProvider implements SingleCompletionHandler { + protected options: ApiHandlerOptions + protected provider: ReturnType -export class SambaNovaHandler extends BaseOpenAiCompatibleProvider { constructor(options: ApiHandlerOptions) { - super({ - ...options, - providerName: "SambaNova", + super() + this.options = options + + // Create the SambaNova provider using AI SDK + this.provider = createSambaNova({ baseURL: "https://api.sambanova.ai/v1", - apiKey: options.sambaNovaApiKey, - defaultProviderModelId: sambaNovaDefaultModelId, - providerModels: sambaNovaModels, - defaultTemperature: 0.7, + apiKey: options.sambaNovaApiKey ?? "not-provided", + headers: DEFAULT_HEADERS, + }) + } + + override getModel(): { id: string; info: ModelInfo; maxTokens?: number; temperature?: number } { + const id = this.options.apiModelId ?? sambaNovaDefaultModelId + const info = sambaNovaModels[id as keyof typeof sambaNovaModels] || sambaNovaModels[sambaNovaDefaultModelId] + const params = getModelParams({ + format: "openai", + modelId: id, + model: info, + settings: this.options, + defaultTemperature: SAMBANOVA_DEFAULT_TEMPERATURE, + }) + return { id, info, ...params } + } + + /** + * Get the language model for the configured model ID. + */ + protected getLanguageModel() { + const { id } = this.getModel() + return this.provider(id) + } + + /** + * Process usage metrics from the AI SDK response. + */ + protected processUsageMetrics( + usage: { + inputTokens?: number + outputTokens?: number + details?: { + cachedInputTokens?: number + reasoningTokens?: number + } + }, + providerMetadata?: { + sambanova?: { + promptCacheHitTokens?: number + promptCacheMissTokens?: number + } + }, + ): ApiStreamUsageChunk { + // Extract cache metrics from SambaNova's providerMetadata if available + const cacheReadTokens = providerMetadata?.sambanova?.promptCacheHitTokens ?? usage.details?.cachedInputTokens + const cacheWriteTokens = providerMetadata?.sambanova?.promptCacheMissTokens + + return { + type: "usage", + inputTokens: usage.inputTokens || 0, + outputTokens: usage.outputTokens || 0, + cacheReadTokens, + cacheWriteTokens, + reasoningTokens: usage.details?.reasoningTokens, + } + } + + /** + * Get the max tokens parameter to include in the request. + */ + protected getMaxOutputTokens(): number | undefined { + const { info } = this.getModel() + return this.options.modelMaxTokens || info.maxTokens || undefined + } + + /** + * Create a message stream using the AI SDK. + */ + override async *createMessage( + systemPrompt: string, + messages: Anthropic.Messages.MessageParam[], + metadata?: ApiHandlerCreateMessageMetadata, + ): ApiStream { + const { temperature } = this.getModel() + const languageModel = this.getLanguageModel() + + // Convert messages to AI SDK format + const aiSdkMessages = convertToAiSdkMessages(messages) + + // Convert tools to OpenAI format first, then to AI SDK format + const openAiTools = this.convertToolsForOpenAI(metadata?.tools) + const aiSdkTools = convertToolsForAiSdk(openAiTools) as ToolSet | undefined + + // Build the request options + const requestOptions: Parameters[0] = { + model: languageModel, + system: systemPrompt, + messages: aiSdkMessages, + temperature: this.options.modelTemperature ?? temperature ?? SAMBANOVA_DEFAULT_TEMPERATURE, + maxOutputTokens: this.getMaxOutputTokens(), + tools: aiSdkTools, + toolChoice: mapToolChoice(metadata?.tool_choice), + } + + // Use streamText for streaming responses + const result = streamText(requestOptions) + + try { + // Process the full stream to get all events including reasoning + for await (const part of result.fullStream) { + for (const chunk of processAiSdkStreamPart(part)) { + yield chunk + } + } + + // Yield usage metrics at the end, including cache metrics from providerMetadata + const usage = await result.usage + const providerMetadata = await result.providerMetadata + if (usage) { + yield this.processUsageMetrics(usage, providerMetadata as any) + } + } catch (error) { + // Handle AI SDK errors (AI_RetryError, AI_APICallError, etc.) + throw handleAiSdkError(error, "SambaNova") + } + } + + /** + * Complete a prompt using the AI SDK generateText. + */ + async completePrompt(prompt: string): Promise { + const { temperature } = this.getModel() + const languageModel = this.getLanguageModel() + + const { text } = await generateText({ + model: languageModel, + prompt, + maxOutputTokens: this.getMaxOutputTokens(), + temperature: this.options.modelTemperature ?? temperature ?? SAMBANOVA_DEFAULT_TEMPERATURE, }) + + return text } } diff --git a/src/package.json b/src/package.json index 98bd1d1b3ec..6292fd15940 100644 --- a/src/package.json +++ b/src/package.json @@ -455,6 +455,7 @@ "@ai-sdk/fireworks": "^2.0.26", "@ai-sdk/groq": "^3.0.19", "@ai-sdk/mistral": "^3.0.0", + "sambanova-ai-provider": "^1.2.2", "@anthropic-ai/bedrock-sdk": "^0.10.2", "@anthropic-ai/sdk": "^0.37.0", "@anthropic-ai/vertex-sdk": "^0.7.0", From aa1e3c92f1996ab8d3c88b97cd0961fe4ee1b87b Mon Sep 17 00:00:00 2001 From: daniel-lxs Date: Mon, 2 Feb 2026 10:45:09 -0500 Subject: [PATCH 2/2] fix: flatten message content to string for SambaNova DeepSeek models DeepSeek models on SambaNova expect string content in messages, not array content. This fix adds: 1. flattenAiSdkMessagesToStringContent utility that converts text-only messages to string format 2. transform option to convertToAiSdkMessages for flexible message transformation The SambaNova handler passes the flattening function to convertToAiSdkMessages for models that don't support images. Fixes: 400 'Invalid content type' error for DeepSeek-R1-0528 model --- src/api/providers/__tests__/sambanova.spec.ts | 6 +- src/api/providers/sambanova.ts | 11 +- src/api/transform/__tests__/ai-sdk.spec.ts | 174 ++++++++++++++++++ src/api/transform/ai-sdk.ts | 92 ++++++++- 4 files changed, 276 insertions(+), 7 deletions(-) diff --git a/src/api/providers/__tests__/sambanova.spec.ts b/src/api/providers/__tests__/sambanova.spec.ts index e11b6c20760..51bc256b769 100644 --- a/src/api/providers/__tests__/sambanova.spec.ts +++ b/src/api/providers/__tests__/sambanova.spec.ts @@ -259,7 +259,7 @@ describe("SambaNovaHandler", () => { expect(usageChunks[0].cacheWriteTokens).toBeUndefined() }) - it("should pass correct temperature (0.5 default) to streamText", async () => { + it("should pass correct temperature (0.7 default) to streamText", async () => { async function* mockFullStream() { yield { type: "text-delta", text: "Test" } } @@ -282,7 +282,7 @@ describe("SambaNovaHandler", () => { expect(mockStreamText).toHaveBeenCalledWith( expect.objectContaining({ - temperature: 0.5, + temperature: 0.7, }), ) }) @@ -369,7 +369,7 @@ describe("SambaNovaHandler", () => { expect(mockGenerateText).toHaveBeenCalledWith( expect.objectContaining({ - temperature: 0.5, + temperature: 0.7, }), ) }) diff --git a/src/api/providers/sambanova.ts b/src/api/providers/sambanova.ts index 1cf57eb0245..1e68dae33ff 100644 --- a/src/api/providers/sambanova.ts +++ b/src/api/providers/sambanova.ts @@ -12,6 +12,7 @@ import { processAiSdkStreamPart, mapToolChoice, handleAiSdkError, + flattenAiSdkMessagesToStringContent, } from "../transform/ai-sdk" import { ApiStream, ApiStreamUsageChunk } from "../transform/stream" import { getModelParams } from "../transform/model-params" @@ -20,7 +21,7 @@ import { DEFAULT_HEADERS } from "./constants" import { BaseProvider } from "./base-provider" import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" -const SAMBANOVA_DEFAULT_TEMPERATURE = 0.5 +const SAMBANOVA_DEFAULT_TEMPERATURE = 0.7 /** * SambaNova provider using the dedicated sambanova-ai-provider package. @@ -112,11 +113,15 @@ export class SambaNovaHandler extends BaseProvider implements SingleCompletionHa messages: Anthropic.Messages.MessageParam[], metadata?: ApiHandlerCreateMessageMetadata, ): ApiStream { - const { temperature } = this.getModel() + const { temperature, info } = this.getModel() const languageModel = this.getLanguageModel() // Convert messages to AI SDK format - const aiSdkMessages = convertToAiSdkMessages(messages) + // For models that don't support multi-part content (like DeepSeek), flatten messages to string content + // SambaNova's DeepSeek models expect string content, not array content + const aiSdkMessages = convertToAiSdkMessages(messages, { + transform: info.supportsImages ? undefined : flattenAiSdkMessagesToStringContent, + }) // Convert tools to OpenAI format first, then to AI SDK format const openAiTools = this.convertToolsForOpenAI(metadata?.tools) diff --git a/src/api/transform/__tests__/ai-sdk.spec.ts b/src/api/transform/__tests__/ai-sdk.spec.ts index bd87fd8eeb8..fb4e3b9e2f2 100644 --- a/src/api/transform/__tests__/ai-sdk.spec.ts +++ b/src/api/transform/__tests__/ai-sdk.spec.ts @@ -7,6 +7,7 @@ import { mapToolChoice, extractAiSdkErrorMessage, handleAiSdkError, + flattenAiSdkMessagesToStringContent, } from "../ai-sdk" vitest.mock("ai", () => ({ @@ -644,4 +645,177 @@ describe("AI SDK conversion utilities", () => { expect((result as any).cause).toBe(originalError) }) }) + + describe("flattenAiSdkMessagesToStringContent", () => { + it("should return messages unchanged if content is already a string", () => { + const messages = [ + { role: "user" as const, content: "Hello" }, + { role: "assistant" as const, content: "Hi there" }, + ] + + const result = flattenAiSdkMessagesToStringContent(messages) + + expect(result).toEqual(messages) + }) + + it("should flatten user messages with only text parts to string", () => { + const messages = [ + { + role: "user" as const, + content: [ + { type: "text" as const, text: "Hello" }, + { type: "text" as const, text: "World" }, + ], + }, + ] + + const result = flattenAiSdkMessagesToStringContent(messages) + + expect(result).toHaveLength(1) + expect(result[0].role).toBe("user") + expect(result[0].content).toBe("Hello\nWorld") + }) + + it("should flatten assistant messages with only text parts to string", () => { + const messages = [ + { + role: "assistant" as const, + content: [{ type: "text" as const, text: "I am an assistant" }], + }, + ] + + const result = flattenAiSdkMessagesToStringContent(messages) + + expect(result).toHaveLength(1) + expect(result[0].role).toBe("assistant") + expect(result[0].content).toBe("I am an assistant") + }) + + it("should not flatten user messages with image parts", () => { + const messages = [ + { + role: "user" as const, + content: [ + { type: "text" as const, text: "Look at this" }, + { type: "image" as const, image: "data:image/png;base64,abc123" }, + ], + }, + ] + + const result = flattenAiSdkMessagesToStringContent(messages) + + expect(result).toEqual(messages) + }) + + it("should not flatten assistant messages with tool calls", () => { + const messages = [ + { + role: "assistant" as const, + content: [ + { type: "text" as const, text: "Let me use a tool" }, + { + type: "tool-call" as const, + toolCallId: "123", + toolName: "read_file", + input: { path: "test.txt" }, + }, + ], + }, + ] + + const result = flattenAiSdkMessagesToStringContent(messages) + + expect(result).toEqual(messages) + }) + + it("should not flatten tool role messages", () => { + const messages = [ + { + role: "tool" as const, + content: [ + { + type: "tool-result" as const, + toolCallId: "123", + toolName: "test", + output: { type: "text" as const, value: "result" }, + }, + ], + }, + ] as any + + const result = flattenAiSdkMessagesToStringContent(messages) + + expect(result).toEqual(messages) + }) + + it("should respect flattenUserMessages option", () => { + const messages = [ + { + role: "user" as const, + content: [{ type: "text" as const, text: "Hello" }], + }, + ] + + const result = flattenAiSdkMessagesToStringContent(messages, { flattenUserMessages: false }) + + expect(result).toEqual(messages) + }) + + it("should respect flattenAssistantMessages option", () => { + const messages = [ + { + role: "assistant" as const, + content: [{ type: "text" as const, text: "Hi" }], + }, + ] + + const result = flattenAiSdkMessagesToStringContent(messages, { flattenAssistantMessages: false }) + + expect(result).toEqual(messages) + }) + + it("should handle mixed message types correctly", () => { + const messages = [ + { role: "user" as const, content: "Simple string" }, + { + role: "user" as const, + content: [{ type: "text" as const, text: "Text parts" }], + }, + { + role: "assistant" as const, + content: [{ type: "text" as const, text: "Assistant text" }], + }, + { + role: "assistant" as const, + content: [ + { type: "text" as const, text: "With tool" }, + { type: "tool-call" as const, toolCallId: "456", toolName: "test", input: {} }, + ], + }, + ] + + const result = flattenAiSdkMessagesToStringContent(messages) + + expect(result[0].content).toBe("Simple string") // unchanged + expect(result[1].content).toBe("Text parts") // flattened + expect(result[2].content).toBe("Assistant text") // flattened + expect(result[3]).toEqual(messages[3]) // unchanged (has tool call) + }) + + it("should handle empty text parts", () => { + const messages = [ + { + role: "user" as const, + content: [ + { type: "text" as const, text: "" }, + { type: "text" as const, text: "Hello" }, + ], + }, + ] + + const result = flattenAiSdkMessagesToStringContent(messages) + + expect(result[0].content).toBe("\nHello") + }) + }) }) diff --git a/src/api/transform/ai-sdk.ts b/src/api/transform/ai-sdk.ts index ebbf1a8661f..c6f37be694d 100644 --- a/src/api/transform/ai-sdk.ts +++ b/src/api/transform/ai-sdk.ts @@ -8,14 +8,29 @@ import OpenAI from "openai" import { tool as createTool, jsonSchema, type ModelMessage, type TextStreamPart } from "ai" import type { ApiStreamChunk } from "./stream" +/** + * Options for converting Anthropic messages to AI SDK format. + */ +export interface ConvertToAiSdkMessagesOptions { + /** + * Optional function to transform the converted messages. + * Useful for transformations like flattening message content for models that require string content. + */ + transform?: (messages: ModelMessage[]) => ModelMessage[] +} + /** * Convert Anthropic messages to AI SDK ModelMessage format. * Handles text, images, tool uses, and tool results. * * @param messages - Array of Anthropic message parameters + * @param options - Optional conversion options including post-processing function * @returns Array of AI SDK ModelMessage objects */ -export function convertToAiSdkMessages(messages: Anthropic.Messages.MessageParam[]): ModelMessage[] { +export function convertToAiSdkMessages( + messages: Anthropic.Messages.MessageParam[], + options?: ConvertToAiSdkMessagesOptions, +): ModelMessage[] { const modelMessages: ModelMessage[] = [] // First pass: build a map of tool call IDs to tool names from assistant messages @@ -149,9 +164,84 @@ export function convertToAiSdkMessages(messages: Anthropic.Messages.MessageParam } } + // Apply transform if provided + if (options?.transform) { + return options.transform(modelMessages) + } + return modelMessages } +/** + * Options for flattening AI SDK messages. + */ +export interface FlattenMessagesOptions { + /** + * If true, flattens user messages with only text parts to string content. + * Default: true + */ + flattenUserMessages?: boolean + /** + * If true, flattens assistant messages with only text (no tool calls) to string content. + * Default: true + */ + flattenAssistantMessages?: boolean +} + +/** + * Flatten AI SDK messages to use string content where possible. + * Some models (like DeepSeek on SambaNova) require string content instead of array content. + * This function converts messages that contain only text parts to use simple string content. + * + * @param messages - Array of AI SDK ModelMessage objects + * @param options - Options for controlling which message types to flatten + * @returns Array of AI SDK ModelMessage objects with flattened content where applicable + */ +export function flattenAiSdkMessagesToStringContent( + messages: ModelMessage[], + options: FlattenMessagesOptions = {}, +): ModelMessage[] { + const { flattenUserMessages = true, flattenAssistantMessages = true } = options + + return messages.map((message) => { + // Skip if content is already a string + if (typeof message.content === "string") { + return message + } + + // Handle user messages + if (message.role === "user" && flattenUserMessages && Array.isArray(message.content)) { + const parts = message.content as Array<{ type: string; text?: string }> + // Only flatten if all parts are text + const allText = parts.every((part) => part.type === "text") + if (allText && parts.length > 0) { + const textContent = parts.map((part) => part.text || "").join("\n") + return { + ...message, + content: textContent, + } + } + } + + // Handle assistant messages + if (message.role === "assistant" && flattenAssistantMessages && Array.isArray(message.content)) { + const parts = message.content as Array<{ type: string; text?: string }> + // Only flatten if all parts are text (no tool calls) + const allText = parts.every((part) => part.type === "text") + if (allText && parts.length > 0) { + const textContent = parts.map((part) => part.text || "").join("\n") + return { + ...message, + content: textContent, + } + } + } + + // Return unchanged for tool role and messages with non-text content + return message + }) +} + /** * Convert OpenAI-style function tool definitions to AI SDK tool format. *