From 9a5e70b6b5f0d4778d0382e576359e076e8db23e Mon Sep 17 00:00:00 2001 From: Hannes Rudolph Date: Mon, 22 Dec 2025 21:41:46 -0700 Subject: [PATCH 1/2] feat: standardize model selectors across all providers Replace various model selection UI patterns with consistent ModelPicker component: - Replace Select dropdowns for static providers (Anthropic, Bedrock, etc.) - Replace text input + radio buttons for Ollama and LM Studio - Replace Select for VSCodeLM with transform functions for vendor/family object - Add providerModelConfig.ts with service configuration helpers - Add unit tests for new utility functions --- .../src/components/settings/ApiOptions.tsx | 111 ++++--------- .../src/components/settings/ModelPicker.tsx | 42 ++++- .../ApiOptions.provider-filtering.spec.tsx | 11 ++ .../settings/providers/LMStudio.tsx | 142 +++++----------- .../components/settings/providers/Ollama.tsx | 62 +++---- .../settings/providers/VSCodeLM.tsx | 95 ++++++----- .../__tests__/providerModelConfig.spec.ts | 155 ++++++++++++++++++ .../settings/utils/providerModelConfig.ts | 146 +++++++++++++++++ 8 files changed, 499 insertions(+), 265 deletions(-) create mode 100644 webview-ui/src/components/settings/utils/__tests__/providerModelConfig.spec.ts create mode 100644 webview-ui/src/components/settings/utils/providerModelConfig.ts diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx index 00092fea9d1..a363558592f 100644 --- a/webview-ui/src/components/settings/ApiOptions.tsx +++ b/webview-ui/src/components/settings/ApiOptions.tsx @@ -40,6 +40,13 @@ import { minimaxDefaultModelId, } from "@roo-code/types" +import { + getProviderServiceConfig, + getDefaultModelIdForProvider, + getStaticModelsForProvider, + shouldUseGenericModelPicker, +} from "./utils/providerModelConfig" + import { vscode } from "@src/utils/vscode" import { validateApiConfigurationExcludingModelErrors, getModelValidationError } from "@src/utils/validate" import { useAppTranslation } from "@src/i18n/TranslationContext" @@ -102,7 +109,7 @@ import { import { MODELS_BY_PROVIDER, PROVIDERS } from "./constants" import { inputEventTransform, noTransform } from "./transforms" -import { ModelInfoView } from "./ModelInfoView" +import { ModelPicker } from "./ModelPicker" import { ApiErrorMessage } from "./ApiErrorMessage" import { ThinkingBudget } from "./ThinkingBudget" import { Verbosity } from "./Verbosity" @@ -171,7 +178,6 @@ const ApiOptions = ({ [customHeaders, apiConfiguration?.openAiHeaders, setApiConfigurationField], ) - const [isDescriptionExpanded, setIsDescriptionExpanded] = useState(false) const [isAdvancedSettingsOpen, setIsAdvancedSettingsOpen] = useState(false) const handleInputChange = useCallback( @@ -270,32 +276,6 @@ const ApiOptions = ({ setErrorMessage(apiValidationResult) }, [apiConfiguration, routerModels, organizationAllowList, setErrorMessage]) - const selectedProviderModels = useMemo(() => { - const models = MODELS_BY_PROVIDER[selectedProvider] - - if (!models) return [] - - const filteredModels = filterModels(models, selectedProvider, organizationAllowList) - - // Include the currently selected model even if deprecated (so users can see what they have selected) - // But filter out other deprecated models from being newly selectable - const availableModels = filteredModels - ? Object.entries(filteredModels) - .filter(([modelId, modelInfo]) => { - // Always include the currently selected model - if (modelId === selectedModelId) return true - // Filter out deprecated models that aren't currently selected - return !modelInfo.deprecated - }) - .map(([modelId]) => ({ - value: modelId, - label: modelId, - })) - : [] - - return availableModels - }, [selectedProvider, organizationAllowList, selectedModelId]) - const onProviderChange = useCallback( (value: ProviderName) => { setApiConfigurationField("apiProvider", value) @@ -767,47 +747,33 @@ const ApiOptions = ({ )} - {/* Skip generic model picker for claude-code since it has its own in ClaudeCode.tsx */} - {selectedProviderModels.length > 0 && selectedProvider !== "claude-code" && ( + {/* Generic model picker for providers with static models */} + {shouldUseGenericModelPicker(selectedProvider) && ( <> -
- - -
+ { + // Clear custom ARN if not using custom ARN option (Bedrock) + if (modelId !== "custom-arn" && selectedProvider === "bedrock") { + setApiConfigurationField("awsCustomArn", "") + } - {/* Show error if a deprecated model is selected */} - {selectedModelInfo?.deprecated && ( - - )} + // Clear reasoning effort when switching models to allow the new model's default to take effect + // This is especially important for GPT-5 models which default to "medium" + if (selectedProvider === "openai-native") { + setApiConfigurationField("reasoningEffort", undefined) + } + }} + /> {selectedProvider === "bedrock" && selectedModelId === "custom-arn" && ( )} - - {/* Only show model info if not deprecated */} - {!selectedModelInfo?.deprecated && ( - - )} )} diff --git a/webview-ui/src/components/settings/ModelPicker.tsx b/webview-ui/src/components/settings/ModelPicker.tsx index 4fe4c02dda5..c1bfa34beb8 100644 --- a/webview-ui/src/components/settings/ModelPicker.tsx +++ b/webview-ui/src/components/settings/ModelPicker.tsx @@ -37,6 +37,10 @@ type ModelIdKey = keyof Pick< | "ioIntelligenceModelId" | "vercelAiGatewayModelId" | "apiModelId" + | "ollamaModelId" + | "lmStudioModelId" + | "lmStudioDraftModelId" + | "vsCodeLmModelSelector" > interface ModelPickerProps { @@ -55,6 +59,14 @@ interface ModelPickerProps { errorMessage?: string simplifySettings?: boolean hidePricing?: boolean + /** Label for the model picker field - defaults to "Model" */ + label?: string + /** Transform model ID string to the value stored in configuration (for compound types like VSCodeLM selector) */ + valueTransform?: (modelId: string) => unknown + /** Transform stored configuration value back to display string */ + displayTransform?: (value: unknown) => string + /** Callback when model changes - useful for side effects like clearing related fields */ + onModelChange?: (modelId: string) => void } export const ModelPicker = ({ @@ -69,6 +81,10 @@ export const ModelPicker = ({ errorMessage, simplifySettings, hidePricing, + label, + valueTransform, + displayTransform, + onModelChange, }: ModelPickerProps) => { const { t } = useAppTranslation() @@ -81,6 +97,16 @@ export const ModelPicker = ({ const { id: selectedModelId, info: selectedModelInfo } = useSelectedModel(apiConfiguration) + // Get the display value for the current selection + // If displayTransform is provided, use it to convert the stored value to a display string + const displayValue = useMemo(() => { + if (displayTransform) { + const storedValue = apiConfiguration[modelIdKey] + return storedValue ? displayTransform(storedValue) : undefined + } + return selectedModelId + }, [displayTransform, apiConfiguration, modelIdKey, selectedModelId]) + const modelIds = useMemo(() => { const filteredModels = filterModels(models, apiConfiguration.apiProvider, organizationAllowList) @@ -113,7 +139,13 @@ export const ModelPicker = ({ } setOpen(false) - setApiConfigurationField(modelIdKey, modelId) + + // Apply value transform if provided (e.g., for VSCodeLM selector) + const valueToStore = valueTransform ? valueTransform(modelId) : modelId + setApiConfigurationField(modelIdKey, valueToStore as ProviderSettings[ModelIdKey]) + + // Call the optional change callback + onModelChange?.(modelId) // Clear any existing timeout if (selectTimeoutRef.current) { @@ -123,7 +155,7 @@ export const ModelPicker = ({ // Delay to ensure the popover is closed before setting the search value. selectTimeoutRef.current = setTimeout(() => setSearchValue(""), 100) }, - [modelIdKey, setApiConfigurationField], + [modelIdKey, setApiConfigurationField, valueTransform, onModelChange], ) const onOpenChange = useCallback((open: boolean) => { @@ -173,7 +205,7 @@ export const ModelPicker = ({ return ( <>
- + @@ -227,7 +259,7 @@ export const ModelPicker = ({ diff --git a/webview-ui/src/components/settings/__tests__/ApiOptions.provider-filtering.spec.tsx b/webview-ui/src/components/settings/__tests__/ApiOptions.provider-filtering.spec.tsx index 62b77bb733b..f2cdad0a834 100644 --- a/webview-ui/src/components/settings/__tests__/ApiOptions.provider-filtering.spec.tsx +++ b/webview-ui/src/components/settings/__tests__/ApiOptions.provider-filtering.spec.tsx @@ -80,6 +80,17 @@ vi.mock("@src/components/ui", () => ({ CollapsibleContent: ({ children }: any) =>
{children}
, Slider: ({ children, ...props }: any) =>
{children}
, Button: ({ children, ...props }: any) => , + // Add Popover components for ModelPicker + Popover: ({ children }: any) =>
{children}
, + PopoverTrigger: ({ children }: any) =>
{children}
, + PopoverContent: ({ children }: any) =>
{children}
, + // Add Command components for ModelPicker + Command: ({ children }: any) =>
{children}
, + CommandInput: ({ ...props }: any) => , + CommandList: ({ children }: any) =>
{children}
, + CommandEmpty: ({ children }: any) =>
{children}
, + CommandGroup: ({ children }: any) =>
{children}
, + CommandItem: ({ children, ...props }: any) =>
{children}
, })) describe("ApiOptions Provider Filtering", () => { diff --git a/webview-ui/src/components/settings/providers/LMStudio.tsx b/webview-ui/src/components/settings/providers/LMStudio.tsx index 04fb53aa272..055819b4722 100644 --- a/webview-ui/src/components/settings/providers/LMStudio.tsx +++ b/webview-ui/src/components/settings/providers/LMStudio.tsx @@ -2,7 +2,7 @@ import { useCallback, useState, useMemo, useEffect } from "react" import { useEvent } from "react-use" import { Trans } from "react-i18next" import { Checkbox } from "vscrui" -import { VSCodeLink, VSCodeRadio, VSCodeRadioGroup, VSCodeTextField } from "@vscode/webview-ui-toolkit/react" +import { VSCodeLink, VSCodeTextField } from "@vscode/webview-ui-toolkit/react" import type { ProviderSettings } from "@roo-code/types" @@ -12,6 +12,7 @@ import { useRouterModels } from "@src/components/ui/hooks/useRouterModels" import { vscode } from "@src/utils/vscode" import { inputEventTransform } from "../transforms" +import { ModelPicker } from "../ModelPicker" import { ModelRecord } from "@roo/api" type LMStudioProps = { @@ -59,46 +60,50 @@ export const LMStudio = ({ apiConfiguration, setApiConfigurationField }: LMStudi }, []) // Check if the selected model exists in the fetched models - const modelNotAvailable = useMemo(() => { + const modelNotAvailableError = useMemo(() => { const selectedModel = apiConfiguration?.lmStudioModelId - if (!selectedModel) return false + if (!selectedModel) return undefined // Check if model exists in local LM Studio models if (Object.keys(lmStudioModels).length > 0 && selectedModel in lmStudioModels) { - return false // Model is available locally + return undefined // Model is available locally } // If we have router models data for LM Studio if (routerModels.data?.lmstudio) { const availableModels = Object.keys(routerModels.data.lmstudio) // Show warning if model is not in the list (regardless of how many models there are) - return !availableModels.includes(selectedModel) + if (!availableModels.includes(selectedModel)) { + return t("settings:validation.modelAvailability", { modelId: selectedModel }) + } } // If neither source has loaded yet, don't show warning - return false - }, [apiConfiguration?.lmStudioModelId, routerModels.data, lmStudioModels]) + return undefined + }, [apiConfiguration?.lmStudioModelId, routerModels.data, lmStudioModels, t]) // Check if the draft model exists - const draftModelNotAvailable = useMemo(() => { + const draftModelNotAvailableError = useMemo(() => { const draftModel = apiConfiguration?.lmStudioDraftModelId - if (!draftModel) return false + if (!draftModel) return undefined // Check if model exists in local LM Studio models if (Object.keys(lmStudioModels).length > 0 && draftModel in lmStudioModels) { - return false // Model is available locally + return undefined // Model is available locally } // If we have router models data for LM Studio if (routerModels.data?.lmstudio) { const availableModels = Object.keys(routerModels.data.lmstudio) // Show warning if model is not in the list (regardless of how many models there are) - return !availableModels.includes(draftModel) + if (!availableModels.includes(draftModel)) { + return t("settings:validation.modelAvailability", { modelId: draftModel }) + } } // If neither source has loaded yet, don't show warning - return false - }, [apiConfiguration?.lmStudioDraftModelId, routerModels.data, lmStudioModels]) + return undefined + }, [apiConfiguration?.lmStudioDraftModelId, routerModels.data, lmStudioModels, t]) return ( <> @@ -110,38 +115,17 @@ export const LMStudio = ({ apiConfiguration, setApiConfigurationField }: LMStudi className="w-full"> - - - - {modelNotAvailable && ( -
-
-
-
- {t("settings:validation.modelAvailability", { modelId: apiConfiguration?.lmStudioModelId })} -
-
-
- )} - {Object.keys(lmStudioModels).length > 0 && ( - - {Object.keys(lmStudioModels).map((model) => ( - - {model} - - ))} - - )} + { @@ -151,61 +135,21 @@ export const LMStudio = ({ apiConfiguration, setApiConfigurationField }: LMStudi {apiConfiguration?.lmStudioSpeculativeDecodingEnabled && ( <> -
- - - -
- {t("settings:providers.lmStudio.draftModelDesc")} -
- {draftModelNotAvailable && ( -
-
-
-
- {t("settings:validation.modelAvailability", { - modelId: apiConfiguration?.lmStudioDraftModelId, - })} -
-
-
- )} + +
+ {t("settings:providers.lmStudio.draftModelDesc")}
- {Object.keys(lmStudioModels).length > 0 && ( - <> -
{t("settings:providers.lmStudio.selectDraftModel")}
- - {Object.keys(lmStudioModels).map((model) => ( - - {model} - - ))} - - {Object.keys(lmStudioModels).length === 0 && ( -
- {t("settings:providers.lmStudio.noModelsFound")} -
- )} - - )} )}
diff --git a/webview-ui/src/components/settings/providers/Ollama.tsx b/webview-ui/src/components/settings/providers/Ollama.tsx index 615d3be4098..fa64fd2b010 100644 --- a/webview-ui/src/components/settings/providers/Ollama.tsx +++ b/webview-ui/src/components/settings/providers/Ollama.tsx @@ -1,6 +1,6 @@ import { useState, useCallback, useMemo, useEffect } from "react" import { useEvent } from "react-use" -import { VSCodeTextField, VSCodeRadioGroup, VSCodeRadio } from "@vscode/webview-ui-toolkit/react" +import { VSCodeTextField } from "@vscode/webview-ui-toolkit/react" import type { ProviderSettings } from "@roo-code/types" @@ -11,6 +11,7 @@ import { useRouterModels } from "@src/components/ui/hooks/useRouterModels" import { vscode } from "@src/utils/vscode" import { inputEventTransform } from "../transforms" +import { ModelPicker } from "../ModelPicker" import { ModelRecord } from "@roo/api" type OllamaProps = { @@ -57,25 +58,27 @@ export const Ollama = ({ apiConfiguration, setApiConfigurationField }: OllamaPro }, []) // Check if the selected model exists in the fetched models - const modelNotAvailable = useMemo(() => { + const modelNotAvailableError = useMemo(() => { const selectedModel = apiConfiguration?.ollamaModelId - if (!selectedModel) return false + if (!selectedModel) return undefined // Check if model exists in local ollama models if (Object.keys(ollamaModels).length > 0 && selectedModel in ollamaModels) { - return false // Model is available locally + return undefined // Model is available locally } // If we have router models data for Ollama if (routerModels.data?.ollama) { const availableModels = Object.keys(routerModels.data.ollama) // Show warning if model is not in the list (regardless of how many models there are) - return !availableModels.includes(selectedModel) + if (!availableModels.includes(selectedModel)) { + return t("settings:validation.modelAvailability", { modelId: selectedModel }) + } } // If neither source has loaded yet, don't show warning - return false - }, [apiConfiguration?.ollamaModelId, routerModels.data, ollamaModels]) + return undefined + }, [apiConfiguration?.ollamaModelId, routerModels.data, ollamaModels, t]) return ( <> @@ -100,40 +103,21 @@ export const Ollama = ({ apiConfiguration, setApiConfigurationField }: OllamaPro
)} - - - - {modelNotAvailable && ( -
-
-
-
- {t("settings:validation.modelAvailability", { modelId: apiConfiguration?.ollamaModelId })} -
-
-
- )} - {Object.keys(ollamaModels).length > 0 && ( - - {Object.keys(ollamaModels).map((model) => ( - - {model} - - ))} - - )} + { - const value = e.target?.value + onInput={(e) => { + const value = (e.target as HTMLInputElement)?.value if (value === "") { setApiConfigurationField("ollamaNumCtx", undefined) } else { diff --git a/webview-ui/src/components/settings/providers/VSCodeLM.tsx b/webview-ui/src/components/settings/providers/VSCodeLM.tsx index a2097badf61..73d2bdd22fc 100644 --- a/webview-ui/src/components/settings/providers/VSCodeLM.tsx +++ b/webview-ui/src/components/settings/providers/VSCodeLM.tsx @@ -1,15 +1,14 @@ -import { useState, useCallback } from "react" +import { useState, useCallback, useMemo } from "react" import { useEvent } from "react-use" import { LanguageModelChatSelector } from "vscode" -import type { ProviderSettings } from "@roo-code/types" +import type { ProviderSettings, ModelInfo } from "@roo-code/types" import { ExtensionMessage } from "@roo/ExtensionMessage" import { useAppTranslation } from "@src/i18n/TranslationContext" -import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@src/components/ui" -import { inputEventTransform } from "../transforms" +import { ModelPicker } from "../ModelPicker" type VSCodeLMProps = { apiConfiguration: ProviderSettings @@ -21,17 +20,6 @@ export const VSCodeLM = ({ apiConfiguration, setApiConfigurationField }: VSCodeL const [vsCodeLmModels, setVsCodeLmModels] = useState([]) - const handleInputChange = useCallback( - ( - field: K, - transform: (event: E) => ProviderSettings[K] = inputEventTransform, - ) => - (event: E | Event) => { - setApiConfigurationField(field, transform(event as E)) - }, - [setApiConfigurationField], - ) - const onMessage = useCallback((event: MessageEvent) => { const message: ExtensionMessage = event.data @@ -47,40 +35,59 @@ export const VSCodeLM = ({ apiConfiguration, setApiConfigurationField }: VSCodeL useEvent("message", onMessage) + // Convert VSCode LM models array to Record format for ModelPicker + const modelsRecord = useMemo((): Record => { + return vsCodeLmModels.reduce( + (acc, model) => { + const modelId = `${model.vendor}/${model.family}` + acc[modelId] = { + maxTokens: 0, + contextWindow: 0, + supportsPromptCache: false, + description: `${model.vendor} - ${model.family}`, + } + return acc + }, + {} as Record, + ) + }, [vsCodeLmModels]) + + // Transform string model ID to { vendor, family } object for storage + const valueTransform = useCallback((modelId: string) => { + const [vendor, family] = modelId.split("/") + return { vendor, family } + }, []) + + // Transform stored { vendor, family } object back to display string + const displayTransform = useCallback((value: unknown) => { + if (!value) return "" + const selector = value as { vendor?: string; family?: string } + return selector.vendor && selector.family ? `${selector.vendor}/${selector.family}` : "" + }, []) + return ( <> -
- - {vsCodeLmModels.length > 0 ? ( - - ) : ( + {vsCodeLmModels.length > 0 ? ( + + ) : ( +
+
{t("settings:providers.vscodeLmDescription")}
- )} -
+
+ )}
{t("settings:providers.vscodeLmWarning")}
) diff --git a/webview-ui/src/components/settings/utils/__tests__/providerModelConfig.spec.ts b/webview-ui/src/components/settings/utils/__tests__/providerModelConfig.spec.ts new file mode 100644 index 00000000000..931796d2b02 --- /dev/null +++ b/webview-ui/src/components/settings/utils/__tests__/providerModelConfig.spec.ts @@ -0,0 +1,155 @@ +import { + PROVIDER_SERVICE_CONFIG, + PROVIDER_DEFAULT_MODEL_IDS, + getProviderServiceConfig, + getDefaultModelIdForProvider, + getStaticModelsForProvider, + isStaticModelProvider, + PROVIDERS_WITH_CUSTOM_MODEL_UI, + shouldUseGenericModelPicker, +} from "../providerModelConfig" + +describe("providerModelConfig", () => { + describe("PROVIDER_SERVICE_CONFIG", () => { + it("contains service config for anthropic", () => { + expect(PROVIDER_SERVICE_CONFIG.anthropic).toEqual({ + serviceName: "Anthropic", + serviceUrl: "https://console.anthropic.com", + }) + }) + + it("contains service config for bedrock", () => { + expect(PROVIDER_SERVICE_CONFIG.bedrock).toEqual({ + serviceName: "Amazon Bedrock", + serviceUrl: "https://aws.amazon.com/bedrock", + }) + }) + + it("contains service config for ollama", () => { + expect(PROVIDER_SERVICE_CONFIG.ollama).toEqual({ + serviceName: "Ollama", + serviceUrl: "https://ollama.ai", + }) + }) + + it("contains service config for lmstudio", () => { + expect(PROVIDER_SERVICE_CONFIG.lmstudio).toEqual({ + serviceName: "LM Studio", + serviceUrl: "https://lmstudio.ai/docs", + }) + }) + + it("contains service config for vscode-lm", () => { + expect(PROVIDER_SERVICE_CONFIG["vscode-lm"]).toEqual({ + serviceName: "VS Code LM", + serviceUrl: "https://code.visualstudio.com/api/extension-guides/language-model", + }) + }) + }) + + describe("getProviderServiceConfig", () => { + it("returns correct config for known provider", () => { + const config = getProviderServiceConfig("gemini") + expect(config.serviceName).toBe("Google Gemini") + expect(config.serviceUrl).toBe("https://ai.google.dev") + }) + + it("returns fallback config for unknown provider", () => { + const config = getProviderServiceConfig("unknown-provider" as any) + expect(config.serviceName).toBe("unknown-provider") + expect(config.serviceUrl).toBe("") + }) + }) + + describe("PROVIDER_DEFAULT_MODEL_IDS", () => { + it("contains default model IDs for static providers", () => { + expect(PROVIDER_DEFAULT_MODEL_IDS.anthropic).toBeDefined() + expect(PROVIDER_DEFAULT_MODEL_IDS.bedrock).toBeDefined() + expect(PROVIDER_DEFAULT_MODEL_IDS.gemini).toBeDefined() + expect(PROVIDER_DEFAULT_MODEL_IDS["openai-native"]).toBeDefined() + }) + }) + + describe("getDefaultModelIdForProvider", () => { + it("returns default model ID for known provider", () => { + const defaultId = getDefaultModelIdForProvider("anthropic") + expect(defaultId).toBeDefined() + expect(typeof defaultId).toBe("string") + expect(defaultId.length).toBeGreaterThan(0) + }) + + it("returns empty string for unknown provider", () => { + const defaultId = getDefaultModelIdForProvider("unknown" as any) + expect(defaultId).toBe("") + }) + }) + + describe("getStaticModelsForProvider", () => { + it("returns models for anthropic provider", () => { + const models = getStaticModelsForProvider("anthropic") + expect(Object.keys(models).length).toBeGreaterThan(0) + }) + + it("adds custom-arn option for bedrock provider", () => { + const models = getStaticModelsForProvider("bedrock", "Use Custom ARN") + expect(models["custom-arn"]).toBeDefined() + expect(models["custom-arn"].description).toBe("Use Custom ARN") + }) + + it("returns empty object for providers without static models", () => { + const models = getStaticModelsForProvider("openrouter") + expect(Object.keys(models).length).toBe(0) + }) + }) + + describe("isStaticModelProvider", () => { + it("returns true for providers with static models", () => { + expect(isStaticModelProvider("anthropic")).toBe(true) + expect(isStaticModelProvider("bedrock")).toBe(true) + expect(isStaticModelProvider("gemini")).toBe(true) + expect(isStaticModelProvider("openai-native")).toBe(true) + }) + + it("returns false for providers without static models", () => { + expect(isStaticModelProvider("openrouter")).toBe(false) + expect(isStaticModelProvider("ollama")).toBe(false) + expect(isStaticModelProvider("lmstudio")).toBe(false) + }) + }) + + describe("PROVIDERS_WITH_CUSTOM_MODEL_UI", () => { + it("includes providers that have their own model selection UI", () => { + expect(PROVIDERS_WITH_CUSTOM_MODEL_UI).toContain("openrouter") + expect(PROVIDERS_WITH_CUSTOM_MODEL_UI).toContain("ollama") + expect(PROVIDERS_WITH_CUSTOM_MODEL_UI).toContain("lmstudio") + expect(PROVIDERS_WITH_CUSTOM_MODEL_UI).toContain("vscode-lm") + expect(PROVIDERS_WITH_CUSTOM_MODEL_UI).toContain("claude-code") + }) + + it("does not include static providers using generic picker", () => { + expect(PROVIDERS_WITH_CUSTOM_MODEL_UI).not.toContain("anthropic") + expect(PROVIDERS_WITH_CUSTOM_MODEL_UI).not.toContain("gemini") + expect(PROVIDERS_WITH_CUSTOM_MODEL_UI).not.toContain("bedrock") + }) + }) + + describe("shouldUseGenericModelPicker", () => { + it("returns true for static providers without custom UI", () => { + expect(shouldUseGenericModelPicker("anthropic")).toBe(true) + expect(shouldUseGenericModelPicker("bedrock")).toBe(true) + expect(shouldUseGenericModelPicker("gemini")).toBe(true) + expect(shouldUseGenericModelPicker("deepseek")).toBe(true) + }) + + it("returns false for providers with custom model UI", () => { + expect(shouldUseGenericModelPicker("openrouter")).toBe(false) + expect(shouldUseGenericModelPicker("ollama")).toBe(false) + expect(shouldUseGenericModelPicker("lmstudio")).toBe(false) + expect(shouldUseGenericModelPicker("vscode-lm")).toBe(false) + }) + + it("returns false for providers without static models", () => { + expect(shouldUseGenericModelPicker("openai")).toBe(false) + }) + }) +}) diff --git a/webview-ui/src/components/settings/utils/providerModelConfig.ts b/webview-ui/src/components/settings/utils/providerModelConfig.ts new file mode 100644 index 00000000000..e71081a7a1a --- /dev/null +++ b/webview-ui/src/components/settings/utils/providerModelConfig.ts @@ -0,0 +1,146 @@ +import type { ProviderName, ModelInfo } from "@roo-code/types" +import { + anthropicDefaultModelId, + bedrockDefaultModelId, + cerebrasDefaultModelId, + deepSeekDefaultModelId, + doubaoDefaultModelId, + moonshotDefaultModelId, + geminiDefaultModelId, + mistralDefaultModelId, + openAiNativeDefaultModelId, + qwenCodeDefaultModelId, + vertexDefaultModelId, + xaiDefaultModelId, + groqDefaultModelId, + sambaNovaDefaultModelId, + internationalZAiDefaultModelId, + fireworksDefaultModelId, + featherlessDefaultModelId, + minimaxDefaultModelId, + basetenDefaultModelId, +} from "@roo-code/types" + +import { MODELS_BY_PROVIDER } from "../constants" + +export interface ProviderServiceConfig { + serviceName: string + serviceUrl: string +} + +export const PROVIDER_SERVICE_CONFIG: Partial> = { + anthropic: { serviceName: "Anthropic", serviceUrl: "https://console.anthropic.com" }, + bedrock: { serviceName: "Amazon Bedrock", serviceUrl: "https://aws.amazon.com/bedrock" }, + cerebras: { serviceName: "Cerebras", serviceUrl: "https://cerebras.ai" }, + deepseek: { serviceName: "DeepSeek", serviceUrl: "https://platform.deepseek.com" }, + doubao: { serviceName: "Doubao", serviceUrl: "https://www.volcengine.com/product/doubao" }, + moonshot: { serviceName: "Moonshot", serviceUrl: "https://platform.moonshot.cn" }, + gemini: { serviceName: "Google Gemini", serviceUrl: "https://ai.google.dev" }, + mistral: { serviceName: "Mistral", serviceUrl: "https://console.mistral.ai" }, + "openai-native": { serviceName: "OpenAI", serviceUrl: "https://platform.openai.com" }, + "qwen-code": { serviceName: "Qwen Code", serviceUrl: "https://dashscope.console.aliyun.com" }, + vertex: { serviceName: "GCP Vertex AI", serviceUrl: "https://console.cloud.google.com/vertex-ai" }, + xai: { serviceName: "xAI", serviceUrl: "https://x.ai" }, + groq: { serviceName: "Groq", serviceUrl: "https://console.groq.com" }, + sambanova: { serviceName: "SambaNova", serviceUrl: "https://sambanova.ai" }, + zai: { serviceName: "Z.ai", serviceUrl: "https://z.ai" }, + fireworks: { serviceName: "Fireworks AI", serviceUrl: "https://fireworks.ai" }, + featherless: { serviceName: "Featherless AI", serviceUrl: "https://featherless.ai" }, + minimax: { serviceName: "MiniMax", serviceUrl: "https://minimax.chat" }, + baseten: { serviceName: "Baseten", serviceUrl: "https://baseten.co" }, + ollama: { serviceName: "Ollama", serviceUrl: "https://ollama.ai" }, + lmstudio: { serviceName: "LM Studio", serviceUrl: "https://lmstudio.ai/docs" }, + "vscode-lm": { + serviceName: "VS Code LM", + serviceUrl: "https://code.visualstudio.com/api/extension-guides/language-model", + }, +} + +export const PROVIDER_DEFAULT_MODEL_IDS: Partial> = { + anthropic: anthropicDefaultModelId, + bedrock: bedrockDefaultModelId, + cerebras: cerebrasDefaultModelId, + deepseek: deepSeekDefaultModelId, + doubao: doubaoDefaultModelId, + moonshot: moonshotDefaultModelId, + gemini: geminiDefaultModelId, + mistral: mistralDefaultModelId, + "openai-native": openAiNativeDefaultModelId, + "qwen-code": qwenCodeDefaultModelId, + vertex: vertexDefaultModelId, + xai: xaiDefaultModelId, + groq: groqDefaultModelId, + sambanova: sambaNovaDefaultModelId, + zai: internationalZAiDefaultModelId, + fireworks: fireworksDefaultModelId, + featherless: featherlessDefaultModelId, + minimax: minimaxDefaultModelId, + baseten: basetenDefaultModelId, +} + +export const getProviderServiceConfig = (provider: ProviderName): ProviderServiceConfig => { + return PROVIDER_SERVICE_CONFIG[provider] ?? { serviceName: provider, serviceUrl: "" } +} + +export const getDefaultModelIdForProvider = (provider: ProviderName): string => { + return PROVIDER_DEFAULT_MODEL_IDS[provider] ?? "" +} + +export const getStaticModelsForProvider = ( + provider: ProviderName, + customArnLabel?: string, +): Record => { + const models = MODELS_BY_PROVIDER[provider] ?? {} + + // Add custom-arn option for Bedrock + if (provider === "bedrock") { + return { + ...models, + "custom-arn": { + maxTokens: 0, + contextWindow: 0, + supportsPromptCache: false, + description: customArnLabel ?? "Use Custom ARN", + }, + } + } + + return models +} + +/** + * Checks if a provider uses static models from MODELS_BY_PROVIDER + */ +export const isStaticModelProvider = (provider: ProviderName): boolean => { + return provider in MODELS_BY_PROVIDER +} + +/** + * List of providers that have their own custom model selection UI + * and should not use the generic ModelPicker in ApiOptions + */ +export const PROVIDERS_WITH_CUSTOM_MODEL_UI: ProviderName[] = [ + "openrouter", + "requesty", + "unbound", + "deepinfra", + "claude-code", + "openai", // OpenAI Compatible + "litellm", + "io-intelligence", + "vercel-ai-gateway", + "roo", + "chutes", + "ollama", + "lmstudio", + "vscode-lm", + "huggingface", + "human-relay", +] + +/** + * Checks if a provider should use the generic ModelPicker + */ +export const shouldUseGenericModelPicker = (provider: ProviderName): boolean => { + return isStaticModelProvider(provider) && !PROVIDERS_WITH_CUSTOM_MODEL_UI.includes(provider) +} From b368a1e0afebd4911f5a2d20a68a3586be5a2881 Mon Sep 17 00:00:00 2001 From: Hannes Rudolph Date: Mon, 22 Dec 2025 22:03:43 -0700 Subject: [PATCH 2/2] Update webview-ui/src/components/settings/ApiOptions.tsx Co-authored-by: roomote[bot] <219738659+roomote[bot]@users.noreply.github.com> --- webview-ui/src/components/settings/ApiOptions.tsx | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx index a363558592f..747cd9b1493 100644 --- a/webview-ui/src/components/settings/ApiOptions.tsx +++ b/webview-ui/src/components/settings/ApiOptions.tsx @@ -753,7 +753,13 @@ const ApiOptions = ({