Skip to content

Commit e343cd7

Browse files
authored
Merge pull request #1073 from dgageot/gemini-auth
Support Vertex API
2 parents 0954300 + f5967a7 commit e343cd7

File tree

4 files changed

+57
-7
lines changed

4 files changed

+57
-7
lines changed

pkg/config/gather.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,9 @@ func GatherEnvVarsForModels(cfg *latest.Config) []string {
7171
case "anthropic":
7272
requiredEnv["ANTHROPIC_API_KEY"] = true
7373
case "google":
74-
requiredEnv["GOOGLE_API_KEY"] = true
74+
if model.ProviderOpts["project"] == nil && model.ProviderOpts["location"] == nil {
75+
requiredEnv["GOOGLE_API_KEY"] = true
76+
}
7577
case "mistral":
7678
requiredEnv["MISTRAL_API_KEY"] = true
7779
}

pkg/config/latest/types.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ type ModelConfig struct {
5353
BaseURL string `json:"base_url,omitempty"`
5454
ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
5555
TokenKey string `json:"token_key,omitempty"`
56-
// ProviderOpts allows provider-specific options. Currently used for "dmr" provider only.
56+
// ProviderOpts allows provider-specific options.
5757
ProviderOpts map[string]any `json:"provider_opts,omitempty"`
5858
TrackUsage *bool `json:"track_usage,omitempty"`
5959
// ThinkingBudget controls reasoning effort/budget:

pkg/config/testdata/env/google_model.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,7 @@ models:
99
gemini:
1010
provider: google
1111
model: gemini-2-0
12+
# For Vertex AI
13+
# provider_opts:
14+
# project: my-gcp-project
15+
# location: us-central1

pkg/model/provider/gemini/client.go

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,51 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
4848

4949
var clientFn func(context.Context) (*genai.Client, error)
5050
if gateway := globalOptions.Gateway(); gateway == "" {
51-
apiKey := env.Get(ctx, "GOOGLE_API_KEY")
52-
if apiKey == "" {
53-
return nil, errors.New("GOOGLE_API_KEY environment variable is required")
51+
var (
52+
httpClient *http.Client
53+
backend genai.Backend
54+
apiKey string
55+
project string
56+
location string
57+
)
58+
// project/location take priority over API key, like in the genai client.
59+
if cfg.ProviderOpts["project"] != nil || cfg.ProviderOpts["location"] != nil {
60+
var err error
61+
62+
project, err = environment.Expand(ctx, providerOption(cfg, "project"), env)
63+
if err != nil {
64+
return nil, fmt.Errorf("expanding project: %w", err)
65+
}
66+
if project == "" {
67+
return nil, errors.New("project must be set")
68+
}
69+
70+
location, err = environment.Expand(ctx, providerOption(cfg, "location"), env)
71+
if err != nil {
72+
return nil, fmt.Errorf("expanding location: %w", err)
73+
}
74+
if location == "" {
75+
return nil, errors.New("location must be set")
76+
}
77+
78+
backend = genai.BackendVertexAI
79+
httpClient = nil // Use default client
80+
} else {
81+
apiKey = env.Get(ctx, "GOOGLE_API_KEY")
82+
if apiKey == "" {
83+
return nil, errors.New("GOOGLE_API_KEY environment variable is required")
84+
}
85+
86+
backend = genai.BackendGeminiAPI
87+
httpClient = httpclient.NewHTTPClient()
5488
}
5589

5690
client, err := genai.NewClient(ctx, &genai.ClientConfig{
5791
APIKey: apiKey,
58-
Backend: genai.BackendGeminiAPI,
59-
HTTPClient: httpclient.NewHTTPClient(),
92+
Project: project,
93+
Location: location,
94+
Backend: backend,
95+
HTTPClient: httpClient,
6096
HTTPOptions: genai.HTTPOptions{
6197
BaseURL: cfg.BaseURL,
6298
},
@@ -548,3 +584,11 @@ func defaultsTo(value, defaultValue string) string {
548584
}
549585
return defaultValue
550586
}
587+
588+
func providerOption(cfg *latest.ModelConfig, name string) string {
589+
v := cfg.ProviderOpts[name]
590+
if v, ok := v.(string); ok {
591+
return v
592+
}
593+
return ""
594+
}

0 commit comments

Comments
 (0)