diff --git a/pkg/config/gather.go b/pkg/config/gather.go index 459747084..b4aabbc89 100644 --- a/pkg/config/gather.go +++ b/pkg/config/gather.go @@ -71,7 +71,9 @@ func GatherEnvVarsForModels(cfg *latest.Config) []string { case "anthropic": requiredEnv["ANTHROPIC_API_KEY"] = true case "google": - requiredEnv["GOOGLE_API_KEY"] = true + if model.ProviderOpts["project"] == nil && model.ProviderOpts["location"] == nil { + requiredEnv["GOOGLE_API_KEY"] = true + } case "mistral": requiredEnv["MISTRAL_API_KEY"] = true } diff --git a/pkg/config/latest/types.go b/pkg/config/latest/types.go index 6d3289b26..1965badbc 100644 --- a/pkg/config/latest/types.go +++ b/pkg/config/latest/types.go @@ -52,7 +52,7 @@ type ModelConfig struct { BaseURL string `json:"base_url,omitempty"` ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` TokenKey string `json:"token_key,omitempty"` - // ProviderOpts allows provider-specific options. Currently used for "dmr" provider only. + // ProviderOpts allows provider-specific options. ProviderOpts map[string]any `json:"provider_opts,omitempty"` TrackUsage *bool `json:"track_usage,omitempty"` // ThinkingBudget controls reasoning effort/budget: diff --git a/pkg/config/testdata/env/google_model.yaml b/pkg/config/testdata/env/google_model.yaml index d37eb6251..d20d55714 100755 --- a/pkg/config/testdata/env/google_model.yaml +++ b/pkg/config/testdata/env/google_model.yaml @@ -9,3 +9,7 @@ models: gemini: provider: google model: gemini-2-0 + # For Vertex AI + # provider_opts: + # project: my-gcp-project + # location: us-central1 diff --git a/pkg/model/provider/gemini/client.go b/pkg/model/provider/gemini/client.go index 73c82d3c6..9842c608d 100644 --- a/pkg/model/provider/gemini/client.go +++ b/pkg/model/provider/gemini/client.go @@ -48,15 +48,51 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro var clientFn func(context.Context) (*genai.Client, error) if gateway := globalOptions.Gateway(); gateway == "" { - apiKey := env.Get(ctx, "GOOGLE_API_KEY") - if apiKey == "" { - return nil, errors.New("GOOGLE_API_KEY environment variable is required") + var ( + httpClient *http.Client + backend genai.Backend + apiKey string + project string + location string + ) + // project/location take priority over API key, like in the genai client. + if cfg.ProviderOpts["project"] != nil || cfg.ProviderOpts["location"] != nil { + var err error + + project, err = environment.Expand(ctx, providerOption(cfg, "project"), env) + if err != nil { + return nil, fmt.Errorf("expanding project: %w", err) + } + if project == "" { + return nil, errors.New("project must be set") + } + + location, err = environment.Expand(ctx, providerOption(cfg, "location"), env) + if err != nil { + return nil, fmt.Errorf("expanding location: %w", err) + } + if location == "" { + return nil, errors.New("location must be set") + } + + backend = genai.BackendVertexAI + httpClient = nil // Use default client + } else { + apiKey = env.Get(ctx, "GOOGLE_API_KEY") + if apiKey == "" { + return nil, errors.New("GOOGLE_API_KEY environment variable is required") + } + + backend = genai.BackendGeminiAPI + httpClient = httpclient.NewHTTPClient() } client, err := genai.NewClient(ctx, &genai.ClientConfig{ APIKey: apiKey, - Backend: genai.BackendGeminiAPI, - HTTPClient: httpclient.NewHTTPClient(), + Project: project, + Location: location, + Backend: backend, + HTTPClient: httpClient, HTTPOptions: genai.HTTPOptions{ BaseURL: cfg.BaseURL, }, @@ -548,3 +584,11 @@ func defaultsTo(value, defaultValue string) string { } return defaultValue } + +func providerOption(cfg *latest.ModelConfig, name string) string { + v := cfg.ProviderOpts[name] + if v, ok := v.(string); ok { + return v + } + return "" +}