diff --git a/internal/services/toolkit/client/client_v2.go b/internal/services/toolkit/client/client_v2.go
index 8279b6bc..87a1e26a 100644
--- a/internal/services/toolkit/client/client_v2.go
+++ b/internal/services/toolkit/client/client_v2.go
@@ -35,8 +35,9 @@ func (a *AIClientV2) GetOpenAIClient(llmConfig *models.LLMProviderConfig) *opena
 	if Endpoint == "" {
 		if APIKey != "" {
 			// User provided their own API key, use the OpenAI-compatible endpoint
-			Endpoint = a.cfg.InferenceBaseURL + "/openai"
+			Endpoint = a.cfg.OpenAIBaseURL // standard OpenAI base URL
 		} else {
+			// suffix needed for the Cloudflare gateway
 			Endpoint = a.cfg.InferenceBaseURL + "/openrouter"
 		}
 	}
@@ -63,11 +64,37 @@ func NewAIClientV2(
 	logger *logger.Logger,
 ) *AIClientV2 {
 	database := db.Database("paperdebugger")
-	oaiClient := openai.NewClient(
-		option.WithBaseURL(cfg.InferenceBaseURL+"/openrouter"),
-		option.WithAPIKey(cfg.InferenceAPIKey),
+
+	llmProvider := &models.LLMProviderConfig{
+		APIKey: cfg.OpenAIAPIKey,
+	}
+
+	var baseUrl string
+	var apiKey string
+	var modelSlug string
+
+	// User specified their own API key, use the OpenAI-compatible endpoint
+	if llmProvider != nil && llmProvider.IsCustom() {
+		baseUrl = cfg.OpenAIBaseURL
+		apiKey = cfg.OpenAIAPIKey
+		modelSlug = "gpt-5-nano"
+	} else {
+		// Use the default inference endpoint
+		// suffix needed for the Cloudflare gateway
+		baseUrl = cfg.InferenceBaseURL + "/openrouter"
+		apiKey = cfg.InferenceAPIKey
+		modelSlug = "openai/gpt-5-nano"
+	}
+
+	CheckOpenAIWorksV2(
+		openai.NewClient(
+			option.WithBaseURL(baseUrl),
+			option.WithAPIKey(apiKey),
+		),
+		baseUrl,
+		modelSlug,
+		logger,
 	)
-	CheckOpenAIWorksV2(oaiClient, logger)
 
 	toolRegistry := initializeToolkitV2(db, projectService, cfg, logger)
 	toolCallHandler := handler.NewToolCallHandlerV2(toolRegistry)
diff --git a/internal/services/toolkit/client/utils_v2.go b/internal/services/toolkit/client/utils_v2.go
index e502cb21..c3ab40dd 100644
--- a/internal/services/toolkit/client/utils_v2.go
+++ b/internal/services/toolkit/client/utils_v2.go
@@ -87,13 +87,13 @@ func getDefaultParamsV2(modelSlug string, toolRegistry *registry.ToolRegistryV2)
 	}
 }
 
-func CheckOpenAIWorksV2(oaiClient openaiv3.Client, logger *logger.Logger) {
-	logger.Info("[AI Client V2] checking if openai client works")
+func CheckOpenAIWorksV2(oaiClient openaiv3.Client, baseUrl string, model string, logger *logger.Logger) {
+	logger.Info("[AI Client V2] checking if openai client works with " + baseUrl + " ...")
 	chatCompletion, err := oaiClient.Chat.Completions.New(context.TODO(), openaiv3.ChatCompletionNewParams{
 		Messages: []openaiv3.ChatCompletionMessageParamUnion{
 			openaiv3.UserMessage("Say 'openai client works'"),
 		},
-		Model: "openai/gpt-5-nano",
+		Model: model,
 	})
 	if err != nil {
 		logger.Errorf("[AI Client V2] openai client does not work: %v", err)
diff --git a/webapp/_webapp/src/views/settings/setting-text-input.tsx b/webapp/_webapp/src/views/settings/setting-text-input.tsx
index 893df676..e2fe2c00 100644
--- a/webapp/_webapp/src/views/settings/setting-text-input.tsx
+++ b/webapp/_webapp/src/views/settings/setting-text-input.tsx
@@ -3,6 +3,8 @@
 import { Button, cn } from "@heroui/react";
 import { useSettingStore } from "../../stores/setting-store";
 import { Settings } from "../../pkg/gen/apiclient/user/v1/user_pb";
 import { PlainMessage } from "../../query/types";
+import { useConversationStore } from "../../stores/conversation/conversation-store";
+import { listSupportedModels } from "../../query/api";
 
 type SettingKey = keyof PlainMessage<Settings>;
 
@@ -27,6 +29,7 @@ export function createSettingsTextInput<K extends SettingKey>(settingKey: K) {
   password = false,
 }: SettingsTextInputProps) {
   const { settings, isUpdating, updateSettings } = useSettingStore();
+  const { setCurrentConversation } = useConversationStore();
   const [value, setValue] = useState("");
   const [originalValue, setOriginalValue] = useState("");
   const [isEditing, setIsEditing] = useState(false);
@@ -43,11 +46,39 @@ export function createSettingsTextInput<K extends SettingKey>(settingKey: K) {
   const valueChanged = value !== originalValue;
 
+  // Normalize a model slug to its last '/'-separated segment (lowercased and trimmed)
+  const normalizeModelId = (modelSlug: string) =>
+    modelSlug.toLowerCase().trim().split("/").filter(Boolean).pop()!;
+
   const saveSettings = useCallback(async () => {
     await updateSettings({ [settingKey]: value.trim() } as Partial<PlainMessage<Settings>>);
     setOriginalValue(value.trim());
     setIsEditing(false);
-  }, [value, updateSettings]); // settingKey is an outer scope value, not a dependency
+
+    // If openaiApiKey was updated, fetch the new model list and update the current model slug
+    if (settingKey === "openaiApiKey") {
+      const response = await listSupportedModels({});
+      if (response.models?.length) {
+        const { currentConversation: latest } = useConversationStore.getState();
+        // Look for a model matching the current slug. We don't compare full slugs;
+        // instead we compare the last segment, case-insensitively, since providers use
+        // different prefixes (e.g. openai/, qwen/) but the model name is assumed to be the suffix.
+        const currentId = normalizeModelId(latest.modelSlug);
+        const matchingModel = response.models.find(m =>
+          normalizeModelId(m.name) === currentId
+        );
+        // Fall back to the first model in the list
+        const newSlug = matchingModel?.slug ?? response.models[0].slug;
+
+        if (newSlug !== latest.modelSlug) {
+          setCurrentConversation({
+            ...latest,
+            modelSlug: newSlug,
+          });
+        }
+      }
+    }
+  }, [value, settingKey, updateSettings]);
 
   const handleEdit = useCallback(() => {
     setIsEditing(true);
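
Reviewer note (not part of the diff): a minimal standalone TypeScript sketch of the suffix-match convention the new saveSettings logic relies on. The SupportedModel shape and the pickModelSlug helper are hypothetical stand-ins for the listSupportedModels response entries and the inline matching code; only normalizeModelId mirrors the diff directly.

// Hypothetical shape of one entry from the listSupportedModels response (assumption).
interface SupportedModel {
  name: string; // e.g. "gpt-5-nano"
  slug: string; // e.g. "openai/gpt-5-nano" or "gpt-5-nano"
}

// Same normalization as in the diff: keep only the last '/'-separated segment,
// lowercased and trimmed, so "openai/gpt-5-nano" and "gpt-5-nano" compare equal.
const normalizeModelId = (modelSlug: string): string =>
  modelSlug.toLowerCase().trim().split("/").filter(Boolean).pop() ?? "";

// Hypothetical helper mirroring the inline logic: prefer a model whose normalized
// name matches the current slug's suffix, otherwise fall back to the first model.
function pickModelSlug(currentSlug: string, models: SupportedModel[]): string | undefined {
  if (!models.length) return undefined;
  const currentId = normalizeModelId(currentSlug);
  const match = models.find((m) => normalizeModelId(m.name) === currentId);
  return (match ?? models[0]).slug;
}

// Example: after switching to a user-provided OpenAI key, the OpenRouter-style
// slug "openai/gpt-5-nano" maps onto the direct slug "gpt-5-nano".
console.log(pickModelSlug("openai/gpt-5-nano", [{ name: "gpt-5-nano", slug: "gpt-5-nano" }]));
// -> "gpt-5-nano"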