internal/services/toolkit/client/client_v2.go: 22 additions & 3 deletions
@@ -63,11 +63,30 @@ func NewAIClientV2(
     logger *logger.Logger,
 ) *AIClientV2 {
     database := db.Database("paperdebugger")
+
+    llmProvider := &models.LLMProviderConfig{
+        APIKey: cfg.OpenAIAPIKey,
+    }
+
+    var baseUrl string
+    var apiKey string
+    var modelSlug string
+
+    if llmProvider != nil && llmProvider.IsCustom() {
+        baseUrl = cfg.OpenAIBaseURL
+        apiKey = cfg.OpenAIAPIKey
+        modelSlug = "gpt-5-nano"
+    } else {
+        baseUrl = cfg.InferenceBaseURL + "/openrouter"
+        apiKey = cfg.InferenceAPIKey
+        modelSlug = "openai/gpt-5-nano"
+    }
+
     oaiClient := openai.NewClient(
-        option.WithBaseURL(cfg.InferenceBaseURL+"/openrouter"),
-        option.WithAPIKey(cfg.InferenceAPIKey),
+        option.WithBaseURL(baseUrl),
+        option.WithAPIKey(apiKey),
     )
-    CheckOpenAIWorksV2(oaiClient, logger)
+    CheckOpenAIWorksV2(oaiClient, baseUrl, modelSlug, logger)

     toolRegistry := initializeToolkitV2(db, projectService, cfg, logger)
     toolCallHandler := handler.NewToolCallHandlerV2(toolRegistry)
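The core of the client_v2.go change is a single decision: which base URL, API key, and model slug to use, depending on whether a custom provider is configured. As a reading aid, the sketch below restates that decision in isolation. The helper name, the endpointConfig struct, and the main function are illustrative stand-ins only; the real code reads these fields from the repository's cfg object and derives the flag from models.LLMProviderConfig.IsCustom(). The point it highlights is that the OpenRouter path needs a provider-prefixed slug ("openai/gpt-5-nano"), while a direct OpenAI-compatible endpoint takes the bare slug ("gpt-5-nano").

```go
package main

import "fmt"

// endpointConfig is a stand-in for the relevant fields of the project's cfg type.
type endpointConfig struct {
	OpenAIBaseURL    string
	OpenAIAPIKey     string
	InferenceBaseURL string
	InferenceAPIKey  string
}

// selectProvider mirrors the branch added in NewAIClientV2: a custom provider is
// addressed directly with a bare model slug, while the default path goes through
// the OpenRouter proxy and therefore uses a provider-prefixed slug.
func selectProvider(cfg endpointConfig, isCustom bool) (baseURL, apiKey, modelSlug string) {
	if isCustom {
		return cfg.OpenAIBaseURL, cfg.OpenAIAPIKey, "gpt-5-nano"
	}
	return cfg.InferenceBaseURL + "/openrouter", cfg.InferenceAPIKey, "openai/gpt-5-nano"
}

func main() {
	cfg := endpointConfig{
		OpenAIBaseURL:    "https://my-openai-compatible.example/v1",
		OpenAIAPIKey:     "sk-custom",
		InferenceBaseURL: "https://inference.example",
		InferenceAPIKey:  "sk-inference",
	}
	for _, custom := range []bool{true, false} {
		base, _, slug := selectProvider(cfg, custom)
		fmt.Printf("custom=%v -> base=%s model=%s\n", custom, base, slug)
	}
}
```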
internal/services/toolkit/client/completion_v2.go: 2 additions & 1 deletion
@@ -65,7 +65,8 @@ func (a *AIClientV2) ChatCompletionStreamV2(ctx context.Context, callbackStream
     }()

     oaiClient := a.GetOpenAIClient(llmProvider)
-    params := getDefaultParamsV2(modelSlug, a.toolCallHandler.Registry)
+    var isCustomModel bool = llmProvider != nil && llmProvider.IsCustom()
+    params := getDefaultParamsV2(modelSlug, a.toolCallHandler.Registry, isCustomModel)

     for {
         params.Messages = openaiChatHistory
internal/services/toolkit/client/utils_v2.go: 11 additions & 6 deletions
@@ -6,6 +6,7 @@ This file contains utility functions for the client package. (Mainly miscellaneo
 It is used to append assistant responses to both OpenAI and in-app chat histories, and to create response items for chat interactions.
 */
 import (
+    "path"
     "context"
     "fmt"
     "paperdebugger/internal/libs/cfg"
@@ -15,7 +16,6 @@ import (
     "paperdebugger/internal/services/toolkit/registry"
     "paperdebugger/internal/services/toolkit/tools/xtramcp"
     chatv2 "paperdebugger/pkg/gen/api/chat/v2"
-    "strings"
     "time"

     openaiv3 "github.com/openai/openai-go/v3"
@@ -52,7 +52,12 @@ func appendAssistantTextResponseV2(openaiChatHistory *OpenAIChatHistory, inappCh
     })
 }

-func getDefaultParamsV2(modelSlug string, toolRegistry *registry.ToolRegistryV2) openaiv3.ChatCompletionNewParams {
+func getDefaultParamsV2(modelSlug string, toolRegistry *registry.ToolRegistryV2, isCustomModel bool) openaiv3.ChatCompletionNewParams {
+    // If custom model is used, strip prefix (eg "openai/gpt-4o" -> "gpt-4o")
+    if isCustomModel {
+        modelSlug = path.Base(modelSlug)
+    }
+
     var reasoningModels = []string{
         "gpt-5",
         "gpt-5-mini",
@@ -66,7 +71,7 @@
         "codex-mini-latest",
     }
     for _, model := range reasoningModels {
-        if strings.Contains(modelSlug, model) {
+        if modelSlug == model {
             return openaiv3.ChatCompletionNewParams{
                 Model: modelSlug,
                 MaxCompletionTokens: openaiv3.Int(4000),
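Two details of the model-slug handling above are easy to miss. path.Base returns the last path element, so it strips a provider prefix such as "openai/" and is a no-op on an already-bare slug; and switching from strings.Contains to == tightens the reasoning-model check, since the old substring test classified any slug that merely contains a listed name (for example "gpt-5-nano" contains "gpt-5"). A minimal illustration, independent of the repository code:

```go
package main

import (
	"fmt"
	"path"
	"strings"
)

func main() {
	// path.Base strips a provider prefix and leaves a bare slug unchanged.
	fmt.Println(path.Base("openai/gpt-5-nano")) // gpt-5-nano
	fmt.Println(path.Base("gpt-5-nano"))        // gpt-5-nano

	// Old check: substring match. "gpt-5-nano" would match the "gpt-5" entry.
	fmt.Println(strings.Contains("gpt-5-nano", "gpt-5")) // true

	// New check: exact match only.
	fmt.Println("gpt-5-nano" == "gpt-5") // false
}
```

One consequence is that a slug now gets the reasoning-model parameters only if it appears verbatim in reasoningModels; the entries in the collapsed part of this hunk are not visible here, so whether every intended variant is listed cannot be checked from the diff alone.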
@@ -87,13 +92,13 @@ }
     }
 }

-func CheckOpenAIWorksV2(oaiClient openaiv3.Client, logger *logger.Logger) {
-    logger.Info("[AI Client V2] checking if openai client works")
+func CheckOpenAIWorksV2(oaiClient openaiv3.Client, baseUrl string, model string, logger *logger.Logger) {
+    logger.Info("[AI Client V2] checking if openai client works with " + baseUrl + "..")
     chatCompletion, err := oaiClient.Chat.Completions.New(context.TODO(), openaiv3.ChatCompletionNewParams{
         Messages: []openaiv3.ChatCompletionMessageParamUnion{
             openaiv3.UserMessage("Say 'openai client works'"),
         },
-        Model: "openai/gpt-5-nano",
+        Model: model,
     })
     if err != nil {
         logger.Errorf("[AI Client V2] openai client does not work: %v", err)