diff --git a/internal/services/toolkit/client/client_v2.go b/internal/services/toolkit/client/client_v2.go
index 8279b6bc..eb83f793 100644
--- a/internal/services/toolkit/client/client_v2.go
+++ b/internal/services/toolkit/client/client_v2.go
@@ -63,11 +63,30 @@ func NewAIClientV2(
 	logger *logger.Logger,
 ) *AIClientV2 {
 	database := db.Database("paperdebugger")
+
+	llmProvider := &models.LLMProviderConfig{
+		APIKey: cfg.OpenAIAPIKey,
+	}
+
+	var baseURL string
+	var apiKey string
+	var modelSlug string
+
+	if llmProvider.IsCustom() {
+		baseURL = cfg.OpenAIBaseURL
+		apiKey = cfg.OpenAIAPIKey
+		modelSlug = "gpt-5-nano"
+	} else {
+		baseURL = cfg.InferenceBaseURL + "/openrouter"
+		apiKey = cfg.InferenceAPIKey
+		modelSlug = "openai/gpt-5-nano"
+	}
+
 	oaiClient := openai.NewClient(
-		option.WithBaseURL(cfg.InferenceBaseURL+"/openrouter"),
-		option.WithAPIKey(cfg.InferenceAPIKey),
+		option.WithBaseURL(baseURL),
+		option.WithAPIKey(apiKey),
 	)
-	CheckOpenAIWorksV2(oaiClient, logger)
+	CheckOpenAIWorksV2(oaiClient, baseURL, modelSlug, logger)
 	toolRegistry := initializeToolkitV2(db, projectService, cfg, logger)
 	toolCallHandler := handler.NewToolCallHandlerV2(toolRegistry)
 
diff --git a/internal/services/toolkit/client/completion_v2.go b/internal/services/toolkit/client/completion_v2.go
index e7e5b7b2..a428edc6 100644
--- a/internal/services/toolkit/client/completion_v2.go
+++ b/internal/services/toolkit/client/completion_v2.go
@@ -65,7 +65,8 @@ func (a *AIClientV2) ChatCompletionStreamV2(ctx context.Context, callbackStream
 	}()
 
 	oaiClient := a.GetOpenAIClient(llmProvider)
-	params := getDefaultParamsV2(modelSlug, a.toolCallHandler.Registry)
+	isCustomModel := llmProvider != nil && llmProvider.IsCustom()
+	params := getDefaultParamsV2(modelSlug, a.toolCallHandler.Registry, isCustomModel)
 
 	for {
 		params.Messages = openaiChatHistory
diff --git a/internal/services/toolkit/client/utils_v2.go b/internal/services/toolkit/client/utils_v2.go
index e502cb21..2fa04b94 100644
--- a/internal/services/toolkit/client/utils_v2.go
+++ b/internal/services/toolkit/client/utils_v2.go
@@ -6,6 +6,7 @@ This file contains utility functions for the client package. (Mainly miscellaneo
 It is used to append assistant responses to both OpenAI and in-app chat histories, and to create response items for chat interactions.
 */
 import (
 	"context"
 	"fmt"
+	"path"
 	"paperdebugger/internal/libs/cfg"
@@ -15,7 +16,6 @@ import (
 	"paperdebugger/internal/services/toolkit/registry"
 	"paperdebugger/internal/services/toolkit/tools/xtramcp"
 	chatv2 "paperdebugger/pkg/gen/api/chat/v2"
-	"strings"
 	"time"
 
 	openaiv3 "github.com/openai/openai-go/v3"
@@ -52,7 +52,12 @@ func appendAssistantTextResponseV2(openaiChatHistory *OpenAIChatHistory, inappCh
 	})
 }
 
-func getDefaultParamsV2(modelSlug string, toolRegistry *registry.ToolRegistryV2) openaiv3.ChatCompletionNewParams {
+func getDefaultParamsV2(modelSlug string, toolRegistry *registry.ToolRegistryV2, isCustomModel bool) openaiv3.ChatCompletionNewParams {
+	// If a custom model is used, strip the provider prefix (e.g. "openai/gpt-4o" -> "gpt-4o").
+	if isCustomModel {
+		modelSlug = path.Base(modelSlug)
+	}
+
 	var reasoningModels = []string{
 		"gpt-5",
 		"gpt-5-mini",
@@ -66,7 +71,7 @@ func getDefaultParamsV2(modelSlug string, toolRegistry *registry.ToolRegistryV2)
 		"codex-mini-latest",
 	}
 	for _, model := range reasoningModels {
-		if strings.Contains(modelSlug, model) {
+		if modelSlug == model {
 			return openaiv3.ChatCompletionNewParams{
 				Model:               modelSlug,
 				MaxCompletionTokens: openaiv3.Int(4000),
@@ -87,13 +92,13 @@
 	}
 }
 
-func CheckOpenAIWorksV2(oaiClient openaiv3.Client, logger *logger.Logger) {
-	logger.Info("[AI Client V2] checking if openai client works")
+func CheckOpenAIWorksV2(oaiClient openaiv3.Client, baseURL string, model string, logger *logger.Logger) {
+	logger.Info("[AI Client V2] checking if openai client works with " + baseURL + "..")
 	chatCompletion, err := oaiClient.Chat.Completions.New(context.TODO(), openaiv3.ChatCompletionNewParams{
 		Messages: []openaiv3.ChatCompletionMessageParamUnion{
 			openaiv3.UserMessage("Say 'openai client works'"),
 		},
-		Model: "openai/gpt-5-nano",
+		Model: model,
 	})
 	if err != nil {
 		logger.Errorf("[AI Client V2] openai client does not work: %v", err)