Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions pkg/config/commands_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
)

func TestV2Commands_AllForms(t *testing.T) {
cfg, err := Load(t.Context(), NewFileSource("testdata/commands_v2.yaml"))
cfg, err := Load(t.Context(), NewFileSource("testdata/commands_v2.yaml", nil))
require.NoError(t, err)

// Test simple map format
Expand Down Expand Up @@ -38,7 +38,7 @@ func TestV2Commands_AllForms(t *testing.T) {
}

func TestV2Commands_DisplayText(t *testing.T) {
cfg, err := Load(t.Context(), NewFileSource("testdata/commands_v2.yaml"))
cfg, err := Load(t.Context(), NewFileSource("testdata/commands_v2.yaml", nil))
require.NoError(t, err)

// Simple format: DisplayText returns the instruction
Expand All @@ -51,7 +51,7 @@ func TestV2Commands_DisplayText(t *testing.T) {
}

func TestMigrate_v1_Commands_AllForms(t *testing.T) {
cfg, err := Load(t.Context(), NewFileSource("testdata/commands_v1.yaml"))
cfg, err := Load(t.Context(), NewFileSource("testdata/commands_v1.yaml", nil))
require.NoError(t, err)

require.Equal(t, "root", cfg.Agents[0].Name)
Expand All @@ -69,7 +69,7 @@ func TestMigrate_v1_Commands_AllForms(t *testing.T) {
}

func TestMigrate_v0_Commands_AllForms(t *testing.T) {
cfg, err := Load(t.Context(), NewFileSource("testdata/commands_v0.yaml"))
cfg, err := Load(t.Context(), NewFileSource("testdata/commands_v0.yaml", nil))
require.NoError(t, err)

require.Equal(t, "root", cfg.Agents[0].Name)
Expand Down
62 changes: 12 additions & 50 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,11 @@ import (
"log/slog"
"maps"
"net/url"
"path/filepath"
"slices"
"strings"

"github.com/goccy/go-yaml"

hclconv "github.com/docker/docker-agent/pkg/config/hcl"
"github.com/docker/docker-agent/pkg/config/latest"
"github.com/docker/docker-agent/pkg/environment"
)
Expand All @@ -25,17 +23,6 @@ func Load(ctx context.Context, source Source) (*latest.Config, error) {
return nil, err
}

// Configurations may be authored in HCL as an alternative to YAML.
// Detect the format from the source name extension or, when no hint is
// available (OCI artifacts, etc.), from the content itself, then
// transparently convert to YAML for the rest of the pipeline.
if isHCLSource(source.Name(), data) {
data, err = hclconv.ToYAML(data, source.Name())
if err != nil {
return nil, fmt.Errorf("parsing HCL config file: %w", err)
}
}

var raw struct {
Version string `yaml:"version,omitempty"`
}
Expand All @@ -56,6 +43,10 @@ func Load(ctx context.Context, source Source) (*latest.Config, error) {

config.Version = raw.Version

if err := ExpandConfig(ctx, &config, source.EnvProvider()); err != nil {
return nil, fmt.Errorf("expanding config: %w", err)
}

if err := validateConfig(&config); err != nil {
return nil, err
}
Expand All @@ -76,7 +67,7 @@ func CheckRequiredEnvVars(ctx context.Context, cfg *latest.Config, modelsGateway
missing, err := gatherMissingEnvVars(ctx, cfg, modelsGateway, env)
if err != nil {
// If there's a tool preflight error, log it but continue
slog.WarnContext(ctx, "Failed to preflight toolset environment variables; continuing", "error", err)
slog.Warn("Failed to preflight toolset environment variables; continuing", "error", err)
}

// Return error if there are missing environment variables
Expand Down Expand Up @@ -179,16 +170,6 @@ func validateConfig(cfg *latest.Config) error {
return nil
}

// isHCLSource reports whether the configuration data should be parsed as HCL
// rather than YAML. The decision is based first on the source name extension,
// and then on a content-based heuristic when no extension hint is available.
func isHCLSource(name string, data []byte) bool {
if strings.EqualFold(filepath.Ext(name), ".hcl") {
return true
}
return hclconv.LooksLikeHCL(data)
}

// providerAPITypes are the allowed values for api_type in provider configs
var providerAPITypes = map[string]bool{
"": true, // empty is allowed (defaults to openai_chatcompletions)
Expand All @@ -208,19 +189,17 @@ func validateProviders(cfg *latest.Config) error {
return fmt.Errorf("provider '%s': %w", name, err)
}

// Validate api_type if set
// Validate api_type
if !providerAPITypes[provCfg.APIType] {
return fmt.Errorf("provider '%s': invalid api_type '%s' (must be one of: openai_chatcompletions, openai_responses)", name, provCfg.APIType)
}

// base_url is required for OpenAI-compatible providers (the default)
// but optional for native providers like anthropic, google, amazon-bedrock
if provCfg.BaseURL != "" {
if _, err := url.Parse(provCfg.BaseURL); err != nil {
return fmt.Errorf("provider '%s': invalid base_url '%s': %w", name, provCfg.BaseURL, err)
}
} else if isOpenAICustomProvider(provCfg) {
return fmt.Errorf("provider '%s': base_url is required for OpenAI-compatible providers", name)
// base_url is required for custom providers
if provCfg.BaseURL == "" {
return fmt.Errorf("provider '%s': base_url is required", name)
}
if _, err := url.Parse(provCfg.BaseURL); err != nil {
return fmt.Errorf("provider '%s': invalid base_url '%s': %w", name, provCfg.BaseURL, err)
}

// token_key is optional - if not set, requests will be sent without bearer token
Expand All @@ -229,18 +208,6 @@ func validateProviders(cfg *latest.Config) error {
return nil
}

// isOpenAICustomProvider returns true if the provider config describes an OpenAI-compatible
// custom provider (i.e., Provider is empty or "openai", or api_type is explicitly set to an
// OpenAI schema). These providers require a base_url because they don't have a built-in default.
func isOpenAICustomProvider(cfg latest.ProviderConfig) bool {
// If api_type is explicitly set, it's an OpenAI-compatible provider
if cfg.APIType != "" {
return true
}
// If provider is empty (defaults to openai) or explicitly "openai"
return cfg.Provider == "" || cfg.Provider == "openai"
}

// validateProviderName validates that a provider name is valid
func validateProviderName(name string) error {
trimmed := strings.TrimSpace(name)
Expand Down Expand Up @@ -270,10 +237,5 @@ func validateSkillsConfiguration(_ string, agent *latest.AgentConfig) error {
return fmt.Errorf("agent '%s' has unknown skills source '%s' (must be 'local' or an HTTP/HTTPS URL)", agent.Name, source)
}
}
for _, name := range agent.Skills.Include {
if strings.TrimSpace(name) == "" {
return fmt.Errorf("agent '%s' has an empty skills entry", agent.Name)
}
}
return nil
}
76 changes: 18 additions & 58 deletions pkg/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
func TestAutoRegisterModels(t *testing.T) {
t.Parallel()

cfg, err := Load(t.Context(), NewFileSource("testdata/autoregister.yaml"))
cfg, err := Load(t.Context(), NewFileSource("testdata/autoregister.yaml", nil))
require.NoError(t, err)

assert.Len(t, cfg.Models, 2)
Expand All @@ -28,7 +28,7 @@ func TestAutoRegisterModels(t *testing.T) {
func TestAutoRegisterAlloy(t *testing.T) {
t.Parallel()

cfg, err := Load(t.Context(), NewFileSource("testdata/autoregister_alloy.yaml"))
cfg, err := Load(t.Context(), NewFileSource("testdata/autoregister_alloy.yaml", nil))
require.NoError(t, err)

assert.Len(t, cfg.Models, 2)
Expand All @@ -41,7 +41,7 @@ func TestAutoRegisterAlloy(t *testing.T) {
func TestAlloyModelComposition(t *testing.T) {
t.Parallel()

cfg, err := Load(t.Context(), NewFileSource("testdata/alloy_model_composition.yaml"))
cfg, err := Load(t.Context(), NewFileSource("testdata/alloy_model_composition.yaml", nil))
require.NoError(t, err)

// The alloy model should be expanded to its constituent models
Expand All @@ -57,7 +57,7 @@ func TestAlloyModelComposition(t *testing.T) {
func TestAlloyModelNestedComposition(t *testing.T) {
t.Parallel()

cfg, err := Load(t.Context(), NewFileSource("testdata/alloy_model_nested.yaml"))
cfg, err := Load(t.Context(), NewFileSource("testdata/alloy_model_nested.yaml", nil))
require.NoError(t, err)

// The nested alloy should be fully expanded to all constituent models
Expand All @@ -72,7 +72,7 @@ func TestAlloyModelNestedComposition(t *testing.T) {
func TestMigrate_v0_v1_provider(t *testing.T) {
t.Parallel()

cfg, err := Load(t.Context(), NewFileSource("testdata/provider_v0.yaml"))
cfg, err := Load(t.Context(), NewFileSource("testdata/provider_v0.yaml", nil))
require.NoError(t, err)

assert.Equal(t, "openai", cfg.Models["gpt"].Provider)
Expand All @@ -81,7 +81,7 @@ func TestMigrate_v0_v1_provider(t *testing.T) {
func TestMigrate_v1_provider(t *testing.T) {
t.Parallel()

cfg, err := Load(t.Context(), NewFileSource("testdata/provider_v1.yaml"))
cfg, err := Load(t.Context(), NewFileSource("testdata/provider_v1.yaml", nil))
require.NoError(t, err)

assert.Equal(t, "openai", cfg.Models["gpt"].Provider)
Expand All @@ -90,7 +90,7 @@ func TestMigrate_v1_provider(t *testing.T) {
func TestMigrate_v0_v1_todo(t *testing.T) {
t.Parallel()

cfg, err := Load(t.Context(), NewFileSource("testdata/todo_v0.yaml"))
cfg, err := Load(t.Context(), NewFileSource("testdata/todo_v0.yaml", nil))
require.NoError(t, err)

assert.Len(t, cfg.Agents.First().Toolsets, 2)
Expand All @@ -102,7 +102,7 @@ func TestMigrate_v0_v1_todo(t *testing.T) {
func TestMigrate_v1_todo(t *testing.T) {
t.Parallel()

cfg, err := Load(t.Context(), NewFileSource("testdata/todo_v1.yaml"))
cfg, err := Load(t.Context(), NewFileSource("testdata/todo_v1.yaml", nil))
require.NoError(t, err)

assert.Len(t, cfg.Agents.First().Toolsets, 2)
Expand All @@ -114,7 +114,7 @@ func TestMigrate_v1_todo(t *testing.T) {
func TestMigrate_v0_v1_shared_todo(t *testing.T) {
t.Parallel()

cfg, err := Load(t.Context(), NewFileSource("testdata/shared_todo_v0.yaml"))
cfg, err := Load(t.Context(), NewFileSource("testdata/shared_todo_v0.yaml", nil))
require.NoError(t, err)

assert.Len(t, cfg.Agents.First().Toolsets, 2)
Expand All @@ -126,7 +126,7 @@ func TestMigrate_v0_v1_shared_todo(t *testing.T) {
func TestMigrate_v1_shared_todo(t *testing.T) {
t.Parallel()

cfg, err := Load(t.Context(), NewFileSource("testdata/shared_todo_v1.yaml"))
cfg, err := Load(t.Context(), NewFileSource("testdata/shared_todo_v1.yaml", nil))
require.NoError(t, err)

assert.Len(t, cfg.Agents.First().Toolsets, 2)
Expand All @@ -138,7 +138,7 @@ func TestMigrate_v1_shared_todo(t *testing.T) {
func TestMigrate_v0_v1_think(t *testing.T) {
t.Parallel()

cfg, err := Load(t.Context(), NewFileSource("testdata/think_v0.yaml"))
cfg, err := Load(t.Context(), NewFileSource("testdata/think_v0.yaml", nil))
require.NoError(t, err)

assert.Len(t, cfg.Agents.First().Toolsets, 2)
Expand All @@ -149,7 +149,7 @@ func TestMigrate_v0_v1_think(t *testing.T) {
func TestMigrate_v1_think(t *testing.T) {
t.Parallel()

cfg, err := Load(t.Context(), NewFileSource("testdata/think_v1.yaml"))
cfg, err := Load(t.Context(), NewFileSource("testdata/think_v1.yaml", nil))
require.NoError(t, err)

assert.Len(t, cfg.Agents.First().Toolsets, 2)
Expand All @@ -160,7 +160,7 @@ func TestMigrate_v1_think(t *testing.T) {
func TestMigrate_v0_v1_memory(t *testing.T) {
t.Parallel()

cfg, err := Load(t.Context(), NewFileSource("testdata/memory_v0.yaml"))
cfg, err := Load(t.Context(), NewFileSource("testdata/memory_v0.yaml", nil))
require.NoError(t, err)

assert.Len(t, cfg.Agents.First().Toolsets, 2)
Expand All @@ -172,7 +172,7 @@ func TestMigrate_v0_v1_memory(t *testing.T) {
func TestMigrate_v1_memory(t *testing.T) {
t.Parallel()

cfg, err := Load(t.Context(), NewFileSource("testdata/memory_v1.yaml"))
cfg, err := Load(t.Context(), NewFileSource("testdata/memory_v1.yaml", nil))
require.NoError(t, err)

assert.Len(t, cfg.Agents.First().Toolsets, 2)
Expand All @@ -184,7 +184,7 @@ func TestMigrate_v1_memory(t *testing.T) {
func TestMigrate_v1(t *testing.T) {
t.Parallel()

_, err := Load(t.Context(), NewFileSource("testdata/v1.yaml"))
_, err := Load(t.Context(), NewFileSource("testdata/v1.yaml", nil))
require.NoError(t, err)
}

Expand Down Expand Up @@ -271,7 +271,7 @@ func TestCheckRequiredEnvVars(t *testing.T) {
t.Run(test.yaml, func(t *testing.T) {
t.Parallel()

cfg, err := Load(t.Context(), NewFileSource("testdata/env/"+test.yaml))
cfg, err := Load(t.Context(), NewFileSource("testdata/env/"+test.yaml, nil))
require.NoError(t, err)

err = CheckRequiredEnvVars(t.Context(), cfg, "", &noEnvProvider{})
Expand All @@ -294,7 +294,7 @@ func TestCheckRequiredEnvVarsWithModelGateway(t *testing.T) {
t.Run("with token", func(t *testing.T) {
t.Parallel()

cfg, err := Load(t.Context(), NewFileSource("testdata/env/all.yaml"))
cfg, err := Load(t.Context(), NewFileSource("testdata/env/all.yaml", nil))
require.NoError(t, err)

env := &fakeEnvProvider{vars: map[string]string{
Expand All @@ -308,7 +308,7 @@ func TestCheckRequiredEnvVarsWithModelGateway(t *testing.T) {
t.Run("without token", func(t *testing.T) {
t.Parallel()

cfg, err := Load(t.Context(), NewFileSource("testdata/env/all.yaml"))
cfg, err := Load(t.Context(), NewFileSource("testdata/env/all.yaml", nil))
require.NoError(t, err)

err = CheckRequiredEnvVars(t.Context(), cfg, "gateway:8080", &noEnvProvider{})
Expand Down Expand Up @@ -680,46 +680,6 @@ func TestProviders_Validation(t *testing.T) {
},
wantErr: "name cannot contain '/'",
},
{
name: "valid anthropic provider without base_url",
providers: map[string]latest.ProviderConfig{
"my_anthropic": {
Provider: "anthropic",
TokenKey: "MY_ANTHROPIC_KEY",
},
},
wantErr: "",
},
{
name: "valid google provider with defaults",
providers: map[string]latest.ProviderConfig{
"my_google": {
Provider: "google",
},
},
wantErr: "",
},
{
name: "openai provider without base_url requires it",
providers: map[string]latest.ProviderConfig{
"my_openai": {
Provider: "openai",
},
},
wantErr: "base_url is required",
},
{
name: "provider with model defaults",
providers: map[string]latest.ProviderConfig{
"my_anthropic": {
Provider: "anthropic",
TokenKey: "MY_KEY",
MaxTokens: new(int64),
Temperature: new(float64),
},
},
wantErr: "",
},
}

for _, tt := range tests {
Expand Down
Loading