diff --git a/pkg/config/commands_test.go b/pkg/config/commands_test.go index 9969bc4c7..aa792a687 100644 --- a/pkg/config/commands_test.go +++ b/pkg/config/commands_test.go @@ -7,7 +7,7 @@ import ( ) func TestV2Commands_AllForms(t *testing.T) { - cfg, err := Load(t.Context(), NewFileSource("testdata/commands_v2.yaml")) + cfg, err := Load(t.Context(), NewFileSource("testdata/commands_v2.yaml", nil)) require.NoError(t, err) // Test simple map format @@ -38,7 +38,7 @@ func TestV2Commands_AllForms(t *testing.T) { } func TestV2Commands_DisplayText(t *testing.T) { - cfg, err := Load(t.Context(), NewFileSource("testdata/commands_v2.yaml")) + cfg, err := Load(t.Context(), NewFileSource("testdata/commands_v2.yaml", nil)) require.NoError(t, err) // Simple format: DisplayText returns the instruction @@ -51,7 +51,7 @@ func TestV2Commands_DisplayText(t *testing.T) { } func TestMigrate_v1_Commands_AllForms(t *testing.T) { - cfg, err := Load(t.Context(), NewFileSource("testdata/commands_v1.yaml")) + cfg, err := Load(t.Context(), NewFileSource("testdata/commands_v1.yaml", nil)) require.NoError(t, err) require.Equal(t, "root", cfg.Agents[0].Name) @@ -69,7 +69,7 @@ func TestMigrate_v1_Commands_AllForms(t *testing.T) { } func TestMigrate_v0_Commands_AllForms(t *testing.T) { - cfg, err := Load(t.Context(), NewFileSource("testdata/commands_v0.yaml")) + cfg, err := Load(t.Context(), NewFileSource("testdata/commands_v0.yaml", nil)) require.NoError(t, err) require.Equal(t, "root", cfg.Agents[0].Name) diff --git a/pkg/config/config.go b/pkg/config/config.go index 721d4171b..e42b47053 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -8,13 +8,11 @@ import ( "log/slog" "maps" "net/url" - "path/filepath" "slices" "strings" "github.com/goccy/go-yaml" - hclconv "github.com/docker/docker-agent/pkg/config/hcl" "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/environment" ) @@ -25,17 +23,6 @@ func Load(ctx context.Context, source Source) (*latest.Config, error) { return nil, err } - // Configurations may be authored in HCL as an alternative to YAML. - // Detect the format from the source name extension or, when no hint is - // available (OCI artifacts, etc.), from the content itself, then - // transparently convert to YAML for the rest of the pipeline. - if isHCLSource(source.Name(), data) { - data, err = hclconv.ToYAML(data, source.Name()) - if err != nil { - return nil, fmt.Errorf("parsing HCL config file: %w", err) - } - } - var raw struct { Version string `yaml:"version,omitempty"` } @@ -56,6 +43,10 @@ func Load(ctx context.Context, source Source) (*latest.Config, error) { config.Version = raw.Version + if err := ExpandConfig(ctx, &config, source.EnvProvider()); err != nil { + return nil, fmt.Errorf("expanding config: %w", err) + } + if err := validateConfig(&config); err != nil { return nil, err } @@ -76,7 +67,7 @@ func CheckRequiredEnvVars(ctx context.Context, cfg *latest.Config, modelsGateway missing, err := gatherMissingEnvVars(ctx, cfg, modelsGateway, env) if err != nil { // If there's a tool preflight error, log it but continue - slog.WarnContext(ctx, "Failed to preflight toolset environment variables; continuing", "error", err) + slog.Warn("Failed to preflight toolset environment variables; continuing", "error", err) } // Return error if there are missing environment variables @@ -179,16 +170,6 @@ func validateConfig(cfg *latest.Config) error { return nil } -// isHCLSource reports whether the configuration data should be parsed as HCL -// rather than YAML. The decision is based first on the source name extension, -// and then on a content-based heuristic when no extension hint is available. -func isHCLSource(name string, data []byte) bool { - if strings.EqualFold(filepath.Ext(name), ".hcl") { - return true - } - return hclconv.LooksLikeHCL(data) -} - // providerAPITypes are the allowed values for api_type in provider configs var providerAPITypes = map[string]bool{ "": true, // empty is allowed (defaults to openai_chatcompletions) @@ -208,19 +189,17 @@ func validateProviders(cfg *latest.Config) error { return fmt.Errorf("provider '%s': %w", name, err) } - // Validate api_type if set + // Validate api_type if !providerAPITypes[provCfg.APIType] { return fmt.Errorf("provider '%s': invalid api_type '%s' (must be one of: openai_chatcompletions, openai_responses)", name, provCfg.APIType) } - // base_url is required for OpenAI-compatible providers (the default) - // but optional for native providers like anthropic, google, amazon-bedrock - if provCfg.BaseURL != "" { - if _, err := url.Parse(provCfg.BaseURL); err != nil { - return fmt.Errorf("provider '%s': invalid base_url '%s': %w", name, provCfg.BaseURL, err) - } - } else if isOpenAICustomProvider(provCfg) { - return fmt.Errorf("provider '%s': base_url is required for OpenAI-compatible providers", name) + // base_url is required for custom providers + if provCfg.BaseURL == "" { + return fmt.Errorf("provider '%s': base_url is required", name) + } + if _, err := url.Parse(provCfg.BaseURL); err != nil { + return fmt.Errorf("provider '%s': invalid base_url '%s': %w", name, provCfg.BaseURL, err) } // token_key is optional - if not set, requests will be sent without bearer token @@ -229,18 +208,6 @@ func validateProviders(cfg *latest.Config) error { return nil } -// isOpenAICustomProvider returns true if the provider config describes an OpenAI-compatible -// custom provider (i.e., Provider is empty or "openai", or api_type is explicitly set to an -// OpenAI schema). These providers require a base_url because they don't have a built-in default. -func isOpenAICustomProvider(cfg latest.ProviderConfig) bool { - // If api_type is explicitly set, it's an OpenAI-compatible provider - if cfg.APIType != "" { - return true - } - // If provider is empty (defaults to openai) or explicitly "openai" - return cfg.Provider == "" || cfg.Provider == "openai" -} - // validateProviderName validates that a provider name is valid func validateProviderName(name string) error { trimmed := strings.TrimSpace(name) @@ -270,10 +237,5 @@ func validateSkillsConfiguration(_ string, agent *latest.AgentConfig) error { return fmt.Errorf("agent '%s' has unknown skills source '%s' (must be 'local' or an HTTP/HTTPS URL)", agent.Name, source) } } - for _, name := range agent.Skills.Include { - if strings.TrimSpace(name) == "" { - return fmt.Errorf("agent '%s' has an empty skills entry", agent.Name) - } - } return nil } diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index d80d80490..0b8cb541a 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -15,7 +15,7 @@ import ( func TestAutoRegisterModels(t *testing.T) { t.Parallel() - cfg, err := Load(t.Context(), NewFileSource("testdata/autoregister.yaml")) + cfg, err := Load(t.Context(), NewFileSource("testdata/autoregister.yaml", nil)) require.NoError(t, err) assert.Len(t, cfg.Models, 2) @@ -28,7 +28,7 @@ func TestAutoRegisterModels(t *testing.T) { func TestAutoRegisterAlloy(t *testing.T) { t.Parallel() - cfg, err := Load(t.Context(), NewFileSource("testdata/autoregister_alloy.yaml")) + cfg, err := Load(t.Context(), NewFileSource("testdata/autoregister_alloy.yaml", nil)) require.NoError(t, err) assert.Len(t, cfg.Models, 2) @@ -41,7 +41,7 @@ func TestAutoRegisterAlloy(t *testing.T) { func TestAlloyModelComposition(t *testing.T) { t.Parallel() - cfg, err := Load(t.Context(), NewFileSource("testdata/alloy_model_composition.yaml")) + cfg, err := Load(t.Context(), NewFileSource("testdata/alloy_model_composition.yaml", nil)) require.NoError(t, err) // The alloy model should be expanded to its constituent models @@ -57,7 +57,7 @@ func TestAlloyModelComposition(t *testing.T) { func TestAlloyModelNestedComposition(t *testing.T) { t.Parallel() - cfg, err := Load(t.Context(), NewFileSource("testdata/alloy_model_nested.yaml")) + cfg, err := Load(t.Context(), NewFileSource("testdata/alloy_model_nested.yaml", nil)) require.NoError(t, err) // The nested alloy should be fully expanded to all constituent models @@ -72,7 +72,7 @@ func TestAlloyModelNestedComposition(t *testing.T) { func TestMigrate_v0_v1_provider(t *testing.T) { t.Parallel() - cfg, err := Load(t.Context(), NewFileSource("testdata/provider_v0.yaml")) + cfg, err := Load(t.Context(), NewFileSource("testdata/provider_v0.yaml", nil)) require.NoError(t, err) assert.Equal(t, "openai", cfg.Models["gpt"].Provider) @@ -81,7 +81,7 @@ func TestMigrate_v0_v1_provider(t *testing.T) { func TestMigrate_v1_provider(t *testing.T) { t.Parallel() - cfg, err := Load(t.Context(), NewFileSource("testdata/provider_v1.yaml")) + cfg, err := Load(t.Context(), NewFileSource("testdata/provider_v1.yaml", nil)) require.NoError(t, err) assert.Equal(t, "openai", cfg.Models["gpt"].Provider) @@ -90,7 +90,7 @@ func TestMigrate_v1_provider(t *testing.T) { func TestMigrate_v0_v1_todo(t *testing.T) { t.Parallel() - cfg, err := Load(t.Context(), NewFileSource("testdata/todo_v0.yaml")) + cfg, err := Load(t.Context(), NewFileSource("testdata/todo_v0.yaml", nil)) require.NoError(t, err) assert.Len(t, cfg.Agents.First().Toolsets, 2) @@ -102,7 +102,7 @@ func TestMigrate_v0_v1_todo(t *testing.T) { func TestMigrate_v1_todo(t *testing.T) { t.Parallel() - cfg, err := Load(t.Context(), NewFileSource("testdata/todo_v1.yaml")) + cfg, err := Load(t.Context(), NewFileSource("testdata/todo_v1.yaml", nil)) require.NoError(t, err) assert.Len(t, cfg.Agents.First().Toolsets, 2) @@ -114,7 +114,7 @@ func TestMigrate_v1_todo(t *testing.T) { func TestMigrate_v0_v1_shared_todo(t *testing.T) { t.Parallel() - cfg, err := Load(t.Context(), NewFileSource("testdata/shared_todo_v0.yaml")) + cfg, err := Load(t.Context(), NewFileSource("testdata/shared_todo_v0.yaml", nil)) require.NoError(t, err) assert.Len(t, cfg.Agents.First().Toolsets, 2) @@ -126,7 +126,7 @@ func TestMigrate_v0_v1_shared_todo(t *testing.T) { func TestMigrate_v1_shared_todo(t *testing.T) { t.Parallel() - cfg, err := Load(t.Context(), NewFileSource("testdata/shared_todo_v1.yaml")) + cfg, err := Load(t.Context(), NewFileSource("testdata/shared_todo_v1.yaml", nil)) require.NoError(t, err) assert.Len(t, cfg.Agents.First().Toolsets, 2) @@ -138,7 +138,7 @@ func TestMigrate_v1_shared_todo(t *testing.T) { func TestMigrate_v0_v1_think(t *testing.T) { t.Parallel() - cfg, err := Load(t.Context(), NewFileSource("testdata/think_v0.yaml")) + cfg, err := Load(t.Context(), NewFileSource("testdata/think_v0.yaml", nil)) require.NoError(t, err) assert.Len(t, cfg.Agents.First().Toolsets, 2) @@ -149,7 +149,7 @@ func TestMigrate_v0_v1_think(t *testing.T) { func TestMigrate_v1_think(t *testing.T) { t.Parallel() - cfg, err := Load(t.Context(), NewFileSource("testdata/think_v1.yaml")) + cfg, err := Load(t.Context(), NewFileSource("testdata/think_v1.yaml", nil)) require.NoError(t, err) assert.Len(t, cfg.Agents.First().Toolsets, 2) @@ -160,7 +160,7 @@ func TestMigrate_v1_think(t *testing.T) { func TestMigrate_v0_v1_memory(t *testing.T) { t.Parallel() - cfg, err := Load(t.Context(), NewFileSource("testdata/memory_v0.yaml")) + cfg, err := Load(t.Context(), NewFileSource("testdata/memory_v0.yaml", nil)) require.NoError(t, err) assert.Len(t, cfg.Agents.First().Toolsets, 2) @@ -172,7 +172,7 @@ func TestMigrate_v0_v1_memory(t *testing.T) { func TestMigrate_v1_memory(t *testing.T) { t.Parallel() - cfg, err := Load(t.Context(), NewFileSource("testdata/memory_v1.yaml")) + cfg, err := Load(t.Context(), NewFileSource("testdata/memory_v1.yaml", nil)) require.NoError(t, err) assert.Len(t, cfg.Agents.First().Toolsets, 2) @@ -184,7 +184,7 @@ func TestMigrate_v1_memory(t *testing.T) { func TestMigrate_v1(t *testing.T) { t.Parallel() - _, err := Load(t.Context(), NewFileSource("testdata/v1.yaml")) + _, err := Load(t.Context(), NewFileSource("testdata/v1.yaml", nil)) require.NoError(t, err) } @@ -271,7 +271,7 @@ func TestCheckRequiredEnvVars(t *testing.T) { t.Run(test.yaml, func(t *testing.T) { t.Parallel() - cfg, err := Load(t.Context(), NewFileSource("testdata/env/"+test.yaml)) + cfg, err := Load(t.Context(), NewFileSource("testdata/env/"+test.yaml, nil)) require.NoError(t, err) err = CheckRequiredEnvVars(t.Context(), cfg, "", &noEnvProvider{}) @@ -294,7 +294,7 @@ func TestCheckRequiredEnvVarsWithModelGateway(t *testing.T) { t.Run("with token", func(t *testing.T) { t.Parallel() - cfg, err := Load(t.Context(), NewFileSource("testdata/env/all.yaml")) + cfg, err := Load(t.Context(), NewFileSource("testdata/env/all.yaml", nil)) require.NoError(t, err) env := &fakeEnvProvider{vars: map[string]string{ @@ -308,7 +308,7 @@ func TestCheckRequiredEnvVarsWithModelGateway(t *testing.T) { t.Run("without token", func(t *testing.T) { t.Parallel() - cfg, err := Load(t.Context(), NewFileSource("testdata/env/all.yaml")) + cfg, err := Load(t.Context(), NewFileSource("testdata/env/all.yaml", nil)) require.NoError(t, err) err = CheckRequiredEnvVars(t.Context(), cfg, "gateway:8080", &noEnvProvider{}) @@ -680,46 +680,6 @@ func TestProviders_Validation(t *testing.T) { }, wantErr: "name cannot contain '/'", }, - { - name: "valid anthropic provider without base_url", - providers: map[string]latest.ProviderConfig{ - "my_anthropic": { - Provider: "anthropic", - TokenKey: "MY_ANTHROPIC_KEY", - }, - }, - wantErr: "", - }, - { - name: "valid google provider with defaults", - providers: map[string]latest.ProviderConfig{ - "my_google": { - Provider: "google", - }, - }, - wantErr: "", - }, - { - name: "openai provider without base_url requires it", - providers: map[string]latest.ProviderConfig{ - "my_openai": { - Provider: "openai", - }, - }, - wantErr: "base_url is required", - }, - { - name: "provider with model defaults", - providers: map[string]latest.ProviderConfig{ - "my_anthropic": { - Provider: "anthropic", - TokenKey: "MY_KEY", - MaxTokens: new(int64), - Temperature: new(float64), - }, - }, - wantErr: "", - }, } for _, tt := range tests { diff --git a/pkg/config/examples_test.go b/pkg/config/examples_test.go index 5a9aee6ac..ea587d5e2 100644 --- a/pkg/config/examples_test.go +++ b/pkg/config/examples_test.go @@ -2,9 +2,7 @@ package config import ( "io/fs" - "os" "path/filepath" - "strings" "testing" "github.com/goccy/go-yaml" @@ -23,11 +21,8 @@ func collectExamples(t *testing.T) []string { if err != nil { return err } - if !d.IsDir() { - ext := filepath.Ext(path) - if ext == ".yaml" || ext == ".hcl" { - files = append(files, path) - } + if !d.IsDir() && filepath.Ext(path) == ".yaml" { + files = append(files, path) } return nil }) @@ -45,7 +40,7 @@ func TestParseExamples(t *testing.T) { t.Run(file, func(t *testing.T) { t.Parallel() - cfg, err := Load(t.Context(), NewFileSource(file)) + cfg, err := Load(t.Context(), NewFileSource(file, nil)) require.NoError(t, err) require.Equal(t, latest.Version, cfg.Version, "Version should be %d in %s", latest.Version, file) @@ -86,7 +81,8 @@ func TestParseExamplesAfterMarshalling(t *testing.T) { t.Run(file, func(t *testing.T) { t.Parallel() - cfg, err := Load(t.Context(), NewFileSource(file)) + src := NewFileSource(file, nil) + cfg, err := Load(t.Context(), NewFileSource(file, nil)) require.NoError(t, err) // Make sure that a config can be marshalled and parsed again. @@ -94,36 +90,8 @@ func TestParseExamplesAfterMarshalling(t *testing.T) { buf, err := yaml.Marshal(cfg) require.NoError(t, err) - // The marshalled bytes are always YAML, so re-load them under a - // .yaml-named source even when the original example was HCL. - name := strings.TrimSuffix(file, filepath.Ext(file)) + ".yaml" - _, err = Load(t.Context(), NewBytesSource(name, buf)) - require.NoError(t, err) - }) - } -} - -// TestHCLExamplesMatchYAML verifies that every .hcl example file produces a -// configuration identical to its .yaml sibling, ensuring the HCL surface -// stays in sync with the YAML schema. -func TestHCLExamplesMatchYAML(t *testing.T) { - for _, file := range collectExamples(t) { - if filepath.Ext(file) != ".hcl" { - continue - } - yamlFile := strings.TrimSuffix(file, ".hcl") + ".yaml" - if _, err := os.Stat(yamlFile); err != nil { - continue - } - t.Run(file, func(t *testing.T) { - t.Parallel() - - cfgHCL, err := Load(t.Context(), NewFileSource(file)) - require.NoError(t, err) - cfgYAML, err := Load(t.Context(), NewFileSource(yamlFile)) + _, err = Load(t.Context(), NewBytesSource(src.Name(), buf, nil)) require.NoError(t, err) - - require.Equal(t, cfgYAML, cfgHCL, "HCL config %s differs from YAML sibling %s", file, yamlFile) }) } } diff --git a/pkg/config/expand.go b/pkg/config/expand.go new file mode 100644 index 000000000..c90a3d899 --- /dev/null +++ b/pkg/config/expand.go @@ -0,0 +1,73 @@ +package config + +import ( + "context" + "fmt" + "reflect" + + "github.com/docker/docker-agent/pkg/environment" +) + +// ExpandConfig walks all exported string fields in a config struct (recursively) +// and applies environment.Expand to each. Struct fields tagged `expand:"false"` +// are skipped (useful for fields that are JS-only and handled by Layer 2). +// +// v must be a pointer to a struct. +func ExpandConfig(ctx context.Context, v any, env environment.Provider) error { + return walkStrings(ctx, reflect.ValueOf(v), env) +} + +func walkStrings(ctx context.Context, v reflect.Value, env environment.Provider) error { + // Dereference pointer. + for v.Kind() == reflect.Pointer { + if v.IsNil() { + return nil + } + v = v.Elem() + } + + switch v.Kind() { + case reflect.Struct: + t := v.Type() + for i := range v.NumField() { + field := t.Field(i) + // Skip unexported fields. + if !field.IsExported() { + continue + } + // Skip fields explicitly opted out of Layer 1 expansion. + if field.Tag.Get("expand") == "false" { + continue + } + if err := walkStrings(ctx, v.Field(i), env); err != nil { + return fmt.Errorf("field %s: %w", field.Name, err) + } + } + + case reflect.String: + if !v.CanSet() { + return nil + } + expanded, _ := environment.Expand(ctx, v.String(), env) + v.SetString(expanded) + + case reflect.Slice: + for i := range v.Len() { + if err := walkStrings(ctx, v.Index(i), env); err != nil { + return fmt.Errorf("[%d]: %w", i, err) + } + } + + case reflect.Map: + for _, key := range v.MapKeys() { + elem := v.MapIndex(key) + // Map values aren't addressable; copy, expand, set. + if elem.Kind() == reflect.String { + expanded, _ := environment.Expand(ctx, elem.String(), env) + v.SetMapIndex(key, reflect.ValueOf(expanded)) + } + } + } + + return nil +} diff --git a/pkg/config/mcps_test.go b/pkg/config/mcps_test.go index 6965d9a1b..6b9215912 100644 --- a/pkg/config/mcps_test.go +++ b/pkg/config/mcps_test.go @@ -10,7 +10,7 @@ import ( func TestMCPDefinitions_BasicRef(t *testing.T) { t.Parallel() - cfg, err := Load(t.Context(), NewFileSource("testdata/mcp_definitions.yaml")) + cfg, err := Load(t.Context(), NewFileSource("testdata/mcp_definitions.yaml", nil)) require.NoError(t, err) root, ok := cfg.Agents.Lookup("root") @@ -40,7 +40,7 @@ func TestMCPDefinitions_BasicRef(t *testing.T) { func TestMCPDefinitions_OverrideFields(t *testing.T) { t.Parallel() - cfg, err := Load(t.Context(), NewFileSource("testdata/mcp_definitions_override.yaml")) + cfg, err := Load(t.Context(), NewFileSource("testdata/mcp_definitions_override.yaml", nil)) require.NoError(t, err) root, ok := cfg.Agents.Lookup("root") @@ -65,7 +65,7 @@ func TestMCPDefinitions_OverrideFields(t *testing.T) { func TestMCPDefinitions_InvalidRef(t *testing.T) { t.Parallel() - _, err := Load(t.Context(), NewFileSource("testdata/mcp_definitions_invalid_ref.yaml")) + _, err := Load(t.Context(), NewFileSource("testdata/mcp_definitions_invalid_ref.yaml", nil)) require.Error(t, err) assert.Contains(t, err.Error(), "non-existent MCP definition 'nonexistent'") } @@ -73,7 +73,7 @@ func TestMCPDefinitions_InvalidRef(t *testing.T) { func TestMCPDefinitions_InvalidDefinition(t *testing.T) { t.Parallel() - _, err := Load(t.Context(), NewFileSource("testdata/mcp_definitions_invalid_def.yaml")) + _, err := Load(t.Context(), NewFileSource("testdata/mcp_definitions_invalid_def.yaml", nil)) require.Error(t, err) assert.Contains(t, err.Error(), "either command, remote or ref must be set") } @@ -81,7 +81,7 @@ func TestMCPDefinitions_InvalidDefinition(t *testing.T) { func TestMCPDefinitions_Remote(t *testing.T) { t.Parallel() - cfg, err := Load(t.Context(), NewFileSource("testdata/mcp_definitions_remote.yaml")) + cfg, err := Load(t.Context(), NewFileSource("testdata/mcp_definitions_remote.yaml", nil)) require.NoError(t, err) root, ok := cfg.Agents.Lookup("root") @@ -97,7 +97,7 @@ func TestMCPDefinitions_Remote(t *testing.T) { func TestMCPDefinitions_NoMCPsSection(t *testing.T) { t.Parallel() - cfg, err := Load(t.Context(), NewFileSource("testdata/autoregister.yaml")) + cfg, err := Load(t.Context(), NewFileSource("testdata/autoregister.yaml", nil)) require.NoError(t, err) assert.Nil(t, cfg.MCPs) } @@ -105,7 +105,7 @@ func TestMCPDefinitions_NoMCPsSection(t *testing.T) { func TestMCPDefinitions_RejectsNonsenseFields(t *testing.T) { t.Parallel() - _, err := Load(t.Context(), NewFileSource("testdata/mcp_definitions_invalid_fields.yaml")) + _, err := Load(t.Context(), NewFileSource("testdata/mcp_definitions_invalid_fields.yaml", nil)) require.Error(t, err) assert.Contains(t, err.Error(), "shared can only be used with type 'todo'") } @@ -113,7 +113,7 @@ func TestMCPDefinitions_RejectsNonsenseFields(t *testing.T) { func TestMCPDefinitions_RejectsMultipleSources(t *testing.T) { t.Parallel() - _, err := Load(t.Context(), NewFileSource("testdata/mcp_definitions_multiple_sources.yaml")) + _, err := Load(t.Context(), NewFileSource("testdata/mcp_definitions_multiple_sources.yaml", nil)) require.Error(t, err) assert.Contains(t, err.Error(), "either command, remote or ref must be set, but only one of those") } @@ -121,7 +121,7 @@ func TestMCPDefinitions_RejectsMultipleSources(t *testing.T) { func TestMCPDefinitions_EnvMerge(t *testing.T) { t.Parallel() - cfg, err := Load(t.Context(), NewFileSource("testdata/mcp_definitions_env_merge.yaml")) + cfg, err := Load(t.Context(), NewFileSource("testdata/mcp_definitions_env_merge.yaml", nil)) require.NoError(t, err) root, ok := cfg.Agents.Lookup("root") @@ -136,25 +136,3 @@ func TestMCPDefinitions_EnvMerge(t *testing.T) { // Toolset-only key is preserved assert.Equal(t, "from_toolset", ts.Env["EXTRA"]) } - -func TestMCPDefinitions_WorkingDir(t *testing.T) { - t.Parallel() - - cfg, err := Load(t.Context(), NewFileSource("testdata/mcp_definitions_working_dir.yaml")) - require.NoError(t, err) - - // WorkingDir from the definition is inherited by the referencing toolset. - root, ok := cfg.Agents.Lookup("root") - require.True(t, ok) - require.Len(t, root.Toolsets, 1) - ts := root.Toolsets[0] - assert.Equal(t, "my-mcp-server", ts.Command) - assert.Equal(t, "./tools/mcp", ts.WorkingDir) - - // A toolset-level working_dir overrides the definition's value. - override, ok := cfg.Agents.Lookup("override") - require.True(t, ok) - require.Len(t, override.Toolsets, 1) - tsOverride := override.Toolsets[0] - assert.Equal(t, "./override/path", tsOverride.WorkingDir) -} diff --git a/pkg/config/resolve.go b/pkg/config/resolve.go index 8c4038dc9..dc1c5a3ba 100644 --- a/pkg/config/resolve.go +++ b/pkg/config/resolve.go @@ -6,7 +6,6 @@ import ( "fmt" "log/slog" "maps" - "net/url" "os" "path/filepath" "slices" @@ -64,7 +63,7 @@ func ResolveSources(agentsPath string, envProvider environment.Provider) (Source // resolve() only fails for non-OCI, non-URL, non-builtin references // that can't be made absolute. Try OCI as last resort. if IsOCIReference(agentsPath) { - return singleSource(reference.OciRefToFilename(agentsPath), NewOCISource(agentsPath)), nil + return singleSource(reference.OciRefToFilename(agentsPath), NewOCISource(agentsPath, envProvider)), nil } return nil, err } @@ -87,7 +86,7 @@ func Resolve(agentFilename string, envProvider environment.Provider) (Source, er resolvedPath, err := resolve(agentFilename) if err != nil { if IsOCIReference(agentFilename) { - return NewOCISource(agentFilename), nil + return NewOCISource(agentFilename, envProvider), nil } return nil, err } @@ -103,14 +102,13 @@ func Resolve(agentFilename string, envProvider environment.Provider) (Source, er func resolveOne(resolvedPath string, envProvider environment.Provider) (string, Source) { switch { case builtinAgents[resolvedPath] != nil: - return resolvedPath, NewBytesSource(resolvedPath, builtinAgents[resolvedPath]) + return resolvedPath, NewBytesSource(resolvedPath, builtinAgents[resolvedPath], envProvider) case IsURLReference(resolvedPath): - // URL-encode the URL to make it safe for use as a map key - return url.QueryEscape(resolvedPath), NewURLSource(resolvedPath, envProvider) + return resolvedPath, NewURLSource(resolvedPath, envProvider) case isLocalFile(resolvedPath): - return fileNameWithoutExt(resolvedPath), NewFileSource(resolvedPath) + return fileNameWithoutExt(resolvedPath), NewFileSource(resolvedPath, envProvider) default: - return reference.OciRefToFilename(resolvedPath), NewOCISource(resolvedPath) + return reference.OciRefToFilename(resolvedPath), NewOCISource(resolvedPath, envProvider) } } @@ -127,7 +125,7 @@ func resolveDirectory(dirPath string, envProvider environment.Provider) (Sources continue } ext := strings.ToLower(filepath.Ext(entry.Name())) - if ext != ".yaml" && ext != ".yml" && ext != ".hcl" { + if ext != ".yaml" && ext != ".yml" { continue } a := filepath.Join(dirPath, entry.Name()) @@ -200,8 +198,8 @@ func IsOCIReference(input string) bool { // isLocalFile checks if the input is a local file func isLocalFile(input string) bool { ext := strings.ToLower(filepath.Ext(input)) - // Check for known config file extensions or file descriptors - if ext == ".yaml" || ext == ".yml" || ext == ".hcl" || strings.HasPrefix(input, "/dev/fd/") { + // Check for YAML file extensions or file descriptors + if ext == ".yaml" || ext == ".yml" || strings.HasPrefix(input, "/dev/fd/") { return true } // Check if it exists as a file on disk diff --git a/pkg/config/sources.go b/pkg/config/sources.go index 62256ff50..970071b16 100644 --- a/pkg/config/sources.go +++ b/pkg/config/sources.go @@ -14,7 +14,6 @@ import ( "path/filepath" "slices" "strings" - "time" "github.com/docker/docker-agent/pkg/content" "github.com/docker/docker-agent/pkg/environment" @@ -27,21 +26,28 @@ type Source interface { Name() string ParentDir() string Read(ctx context.Context) ([]byte, error) + EnvProvider() environment.Provider } type Sources map[string]Source // fileSource is used to load an agent configuration from a YAML file. type fileSource struct { - path string + path string + envProvider environment.Provider } -func NewFileSource(path string) Source { +func NewFileSource(path string, envProvider environment.Provider) Source { return fileSource{ - path: path, + path: path, + envProvider: envProvider, } } +func (a fileSource) EnvProvider() environment.Provider { + return a.envProvider +} + func (a fileSource) Name() string { return a.path } @@ -69,17 +75,23 @@ func (a fileSource) Read(context.Context) ([]byte, error) { // bytesSource is used to load an agent configuration from a []byte. type bytesSource struct { - name string - data []byte + name string + data []byte + envProvider environment.Provider } -func NewBytesSource(name string, data []byte) Source { +func NewBytesSource(name string, data []byte, envProvider environment.Provider) Source { return bytesSource{ - name: name, - data: data, + name: name, + data: data, + envProvider: envProvider, } } +func (a bytesSource) EnvProvider() environment.Provider { + return a.envProvider +} + func (a bytesSource) Name() string { return a.name } @@ -94,15 +106,21 @@ func (a bytesSource) Read(context.Context) ([]byte, error) { // ociSource is used to load an agent configuration from an OCI artifact. type ociSource struct { - reference string + reference string + envProvider environment.Provider } -func NewOCISource(reference string) Source { +func NewOCISource(reference string, envProvider environment.Provider) Source { return ociSource{ - reference: reference, + reference: reference, + envProvider: envProvider, } } +func (a ociSource) EnvProvider() environment.Provider { + return a.envProvider +} + func (a ociSource) Name() string { return a.reference } @@ -134,7 +152,7 @@ func (a ociSource) Read(ctx context.Context) ([]byte, error) { // the artifact locally, serve it directly without any network call. if remote.IsDigestReference(a.reference) { if data, loadErr := loadArtifact(store, storeKey); loadErr == nil { - slog.DebugContext(ctx, "Serving digest-pinned OCI artifact from cache", "ref", a.reference) + slog.Debug("Serving digest-pinned OCI artifact from cache", "ref", a.reference) return data, nil } } @@ -147,7 +165,7 @@ func (a ociSource) Read(ctx context.Context) ([]byte, error) { if !hasLocal { return nil, fmt.Errorf("failed to pull OCI image %s: %w", a.reference, pullErr) } - slog.DebugContext(ctx, "Failed to check for OCI reference updates, using cached version", + slog.Debug("Failed to check for OCI reference updates, using cached version", "ref", a.reference, "error", pullErr) } @@ -162,7 +180,7 @@ func (a ociSource) Read(ctx context.Context) ([]byte, error) { return nil, fmt.Errorf("failed to load agent from OCI source %s: %w", a.reference, err) } - slog.WarnContext(ctx, "Local OCI store corrupted, forcing re-pull", "ref", a.reference) + slog.Warn("Local OCI store corrupted, forcing re-pull", "ref", a.reference) if _, pullErr := remote.Pull(ctx, a.reference, true); pullErr != nil { return nil, fmt.Errorf("failed to force re-pull OCI image %s: %w", a.reference, pullErr) } @@ -193,11 +211,6 @@ func hasLocalArtifact(store *content.Store, storeKey string) bool { type urlSource struct { url string envProvider environment.Provider - // unsafe disables the HTTPS-only and SSRF dial-time checks. It is set - // only by the test-only constructor newURLSourceForTest (defined in - // sources_test.go), which exists because tests use httptest.NewServer - // (plain HTTP, 127.0.0.1). - unsafe bool } // NewURLSource creates a new URL source. If envProvider is non-nil, it will be used @@ -209,6 +222,10 @@ func NewURLSource(rawURL string, envProvider environment.Provider) Source { } } +func (a urlSource) EnvProvider() environment.Provider { + return a.envProvider +} + func (a urlSource) Name() string { return a.url } @@ -223,12 +240,6 @@ func getURLCacheDir() string { } func (a urlSource) Read(ctx context.Context) ([]byte, error) { - if !a.unsafe { - if err := validateAgentURL(a.url); err != nil { - return nil, err - } - } - cacheDir := getURLCacheDir() urlHash := hashURL(a.url) cachePath := filepath.Join(cacheDir, urlHash) @@ -253,20 +264,11 @@ func (a urlSource) Read(ctx context.Context) ([]byte, error) { // Add GitHub token authorization for GitHub URLs a.addGitHubAuth(ctx, req) - client := httpclient.NewHTTPClient(ctx) - if !a.unsafe { - client = &http.Client{ - Timeout: 60 * time.Second, - Transport: httpclient.NewSSRFSafeTransport(), - CheckRedirect: httpclient.HTTPSOnlyRedirects(10), - } - } - - resp, err := client.Do(req) + resp, err := httpclient.NewHTTPClient(ctx).Do(req) if err != nil { // Network error - try to use cached version if cachedData, cacheErr := os.ReadFile(cachePath); cacheErr == nil { - slog.DebugContext(ctx, "Network error fetching URL, using cached version", "url", a.url, "error", err) + slog.Debug("Network error fetching URL, using cached version", "url", a.url, "error", err) return cachedData, nil } return nil, fmt.Errorf("fetching %s: %w", a.url, err) @@ -276,7 +278,7 @@ func (a urlSource) Read(ctx context.Context) ([]byte, error) { // 304 Not Modified - return cached content if resp.StatusCode == http.StatusNotModified { if cachedData, cacheErr := os.ReadFile(cachePath); cacheErr == nil { - slog.DebugContext(ctx, "URL not modified, using cached version", "url", a.url) + slog.Debug("URL not modified, using cached version", "url", a.url) return cachedData, nil } // Cache file missing despite 304, fall through to fetch again @@ -285,7 +287,7 @@ func (a urlSource) Read(ctx context.Context) ([]byte, error) { if resp.StatusCode != http.StatusOK { // HTTP error - try to use cached version if cachedData, cacheErr := os.ReadFile(cachePath); cacheErr == nil { - slog.DebugContext(ctx, "HTTP error fetching URL, using cached version", "url", a.url, "status", resp.Status) + slog.Debug("HTTP error fetching URL, using cached version", "url", a.url, "status", resp.Status) return cachedData, nil } return nil, fmt.Errorf("fetching %s: %s", a.url, resp.Status) @@ -297,15 +299,15 @@ func (a urlSource) Read(ctx context.Context) ([]byte, error) { } // Cache the response - if err := os.MkdirAll(cacheDir, 0o700); err == nil { - if err := os.WriteFile(cachePath, data, 0o600); err != nil { - slog.DebugContext(ctx, "Failed to cache URL content", "url", a.url, "error", err) + if err := os.MkdirAll(cacheDir, 0o755); err == nil { + if err := os.WriteFile(cachePath, data, 0o644); err != nil { + slog.Debug("Failed to cache URL content", "url", a.url, "error", err) } // Save ETag if present if etag := resp.Header.Get("ETag"); etag != "" { - if err := os.WriteFile(etagPath, []byte(etag), 0o600); err != nil { - slog.DebugContext(ctx, "Failed to cache ETag", "url", a.url, "error", err) + if err := os.WriteFile(etagPath, []byte(etag), 0o644); err != nil { + slog.Debug("Failed to cache ETag", "url", a.url, "error", err) } } else { // Remove stale ETag file if server no longer provides ETag @@ -352,7 +354,7 @@ func (a urlSource) addGitHubAuth(ctx context.Context, req *http.Request) { } req.Header.Set("Authorization", "Bearer "+token) - slog.DebugContext(ctx, "Added GitHub token authorization to request", "url", a.url) + slog.Debug("Added GitHub token authorization to request", "url", a.url) } // hashURL creates a safe filename from a URL. @@ -365,21 +367,3 @@ func hashURL(rawURL string) string { func IsURLReference(input string) bool { return strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://") } - -// validateAgentURL enforces that an agent URL uses HTTPS. SSRF protection -// (rejecting connections to loopback / private / link-local addresses) is -// done at dial time by [httpclient.NewSSRFSafeTransport] so that DNS -// rebinding cannot be used to bypass it. -func validateAgentURL(rawURL string) error { - u, err := url.Parse(rawURL) - if err != nil { - return fmt.Errorf("invalid URL %q: %w", rawURL, err) - } - if u.Scheme != "https" { - return fmt.Errorf("refusing to load agent from %q: only https:// URLs are allowed (got scheme %q)", rawURL, u.Scheme) - } - if u.Host == "" { - return fmt.Errorf("invalid URL %q: missing host", rawURL) - } - return nil -} diff --git a/pkg/config/sources_test.go b/pkg/config/sources_test.go index c786b6ef3..b0552045d 100644 --- a/pkg/config/sources_test.go +++ b/pkg/config/sources_test.go @@ -29,7 +29,6 @@ func newURLSourceForTest(rawURL string, envProvider environment.Provider) Source return &urlSource{ url: rawURL, envProvider: envProvider, - unsafe: true, } } diff --git a/pkg/config/validation_test.go b/pkg/config/validation_test.go index 0a0ca5f48..4b82cb496 100644 --- a/pkg/config/validation_test.go +++ b/pkg/config/validation_test.go @@ -21,11 +21,11 @@ agents: err := tmpRoot.WriteFile("valid.yaml", []byte(validConfig), 0o644) require.NoError(t, err) - cfg, err := Load(t.Context(), NewFileSource(filepath.Join(tmp, "valid.yaml"))) + cfg, err := Load(t.Context(), NewFileSource(filepath.Join(tmp, "valid.yaml"), nil)) require.NoError(t, err) require.NotNil(t, cfg) - _, err = Load(t.Context(), NewFileSource(filepath.Join(tmp, "../../../etc/passwd"))) //nolint: gocritic // testing invalid path + _, err = Load(t.Context(), NewFileSource(filepath.Join(tmp, "../../../etc/passwd"), nil)) //nolint: gocritic // testing invalid path require.Error(t, err) } @@ -58,7 +58,7 @@ func TestValidationErrors(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - _, err := Load(t.Context(), NewFileSource(filepath.Join("testdata", tt.path))) + _, err := Load(t.Context(), NewFileSource(filepath.Join("testdata", tt.path), nil)) require.Error(t, err) }) } @@ -72,7 +72,7 @@ agents: root: model: openai/gpt-4 ` - _, err := Load(t.Context(), NewBytesSource("test", []byte(cfg))) + _, err := Load(t.Context(), NewBytesSource("test", []byte(cfg), nil)) require.Error(t, err) assert.Contains(t, err.Error(), "unsupported config version: 99") assert.Contains(t, err.Error(), "valid versions") @@ -114,53 +114,26 @@ func TestValidSkillsConfiguration(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - cfg, err := Load(t.Context(), NewFileSource(filepath.Join("testdata", tt.path))) + cfg, err := Load(t.Context(), NewFileSource(filepath.Join("testdata", tt.path), nil)) require.NoError(t, err) require.NotNil(t, cfg) }) } } -func TestSkillsConfigRejectsEmptyEntry(t *testing.T) { +func TestInvalidSkillsSources(t *testing.T) { t.Parallel() - // Empty entries in the skills list should be rejected. cfgStr := `version: "5" agents: root: model: openai/gpt-4o skills: - - local - - "" + - invalid_source toolsets: - type: filesystem ` - _, err := Load(t.Context(), NewBytesSource("test", []byte(cfgStr))) + _, err := Load(t.Context(), NewBytesSource("test", []byte(cfgStr), nil)) require.Error(t, err) - assert.Contains(t, err.Error(), "empty skills entry") -} - -func TestSkillsNameFilter(t *testing.T) { - t.Parallel() - - // A string that is not "local" and not a URL is interpreted as a skill - // name to include. This must load successfully — the filter simply keeps - // only matching skills at runtime. - cfgStr := `version: "7" -agents: - root: - model: openai/gpt-4o - skills: - - git - - docker - toolsets: - - type: filesystem -` - cfg, err := Load(t.Context(), NewBytesSource("test", []byte(cfgStr))) - require.NoError(t, err) - agent, ok := cfg.Agents.Lookup("root") - require.True(t, ok) - require.True(t, agent.Skills.Enabled()) - require.True(t, agent.Skills.HasLocal()) - assert.Equal(t, []string{"git", "docker"}, agent.Skills.Include) + assert.Contains(t, err.Error(), "unknown skills source") } diff --git a/pkg/creator/agent.go b/pkg/creator/agent.go index f261229bb..b345431ea 100644 --- a/pkg/creator/agent.go +++ b/pkg/creator/agent.go @@ -44,7 +44,7 @@ func Agent(ctx context.Context, runConfig *config.RuntimeConfig, modelNameOverri return teamloader.Load( ctx, - config.NewBytesSource("creator", configYAML), + config.NewBytesSource("creator", configYAML, nil), runConfig, teamloader.WithModelOverrides([]string{modelNameOverride}), ) diff --git a/pkg/creator/agent_test.go b/pkg/creator/agent_test.go index 6a40d0ba8..95e861af5 100644 --- a/pkg/creator/agent_test.go +++ b/pkg/creator/agent_test.go @@ -19,7 +19,7 @@ func TestBuildCreatorConfigYAML(t *testing.T) { require.NoError(t, err) // Verify it can be loaded by the config loader - cfg, err := config.Load(t.Context(), config.NewBytesSource("test", data)) + cfg, err := config.Load(t.Context(), config.NewBytesSource("test", data, nil)) require.NoError(t, err) // Verify the config structure @@ -113,7 +113,7 @@ func TestBuildCreatorConfigYAML_MultilineStrings(t *testing.T) { t.Logf("YAML output:\n%s", yamlStr) // Verify the YAML can be parsed - cfg, err := config.Load(t.Context(), config.NewBytesSource("test", data)) + cfg, err := config.Load(t.Context(), config.NewBytesSource("test", data, nil)) require.NoError(t, err) // Verify the instruction is preserved correctly diff --git a/pkg/environment/expand.go b/pkg/environment/expand.go index bf0db0c21..2a3c54253 100644 --- a/pkg/environment/expand.go +++ b/pkg/environment/expand.go @@ -4,9 +4,15 @@ import ( "context" "fmt" "os" + "regexp" "slices" + "strings" ) +// envDotPattern matches ${env.VAR} and captures VAR. +// This normalizes JS-style env access to POSIX-style before os.Expand. +var envDotPattern = regexp.MustCompile(`\$\{env\.([a-zA-Z_][a-zA-Z0-9_]*)\}`) + func ExpandAll(ctx context.Context, values []string, env Provider) ([]string, error) { var expandedEnv []string @@ -22,23 +28,63 @@ func ExpandAll(ctx context.Context, values []string, env Provider) ([]string, er return expandedEnv, nil } +// Expand resolves environment variable references in value using the provided Provider. +// It accepts three equivalent syntaxes: $VAR, ${VAR}, and ${env.VAR}. +// ~ expansion is intentionally excluded; use path.ExpandHome for path fields. +// +// If a referenced variable is not set, Expand returns an ErrMissingVarsError error +// wrapping all missing names, but still returns the partially-expanded string +// so callers can decide whether to hard-fail or warn. func Expand(ctx context.Context, value string, env Provider) (string, error) { - var err error + if env == nil { + return value, nil + } + + // Normalize ${env.VAR} → ${VAR} so os.Expand handles both uniformly, + // but only for simple cases without additional JS logic. + normalized := envDotPattern.ReplaceAllString(value, `${$1}`) + + var missing []string + expanded := os.Expand(normalized, func(name string) string { + if name == "" { + return "" + } + if name == "$" { + return "$" + } + + // If it's a complex JS expression (contains spaces, dots, quotes, etc), + // we leave it untouched for Layer 2 (JS engine) to handle. + // Valid POSIX env names only contain alphanumeric characters and underscores. + if strings.ContainsAny(name, " .|'\"(){}[],+-*/=!<>?^&%@~`\\#") { + return "${" + name + "}" + } - expanded := os.Expand(value, func(name string) string { v, found := env.Get(ctx, name) if !found { - err = fmt.Errorf("environment variable %q not set", name) + missing = append(missing, name) + return "" // match os.ExpandEnv behavior: empty on missing } return v }) - if err != nil { - return "", err + + if len(missing) > 0 { + return expanded, &ErrMissingVarsError{Names: missing} } return expanded, nil } +// ErrMissingVarsError is returned when one or more referenced variables are not set. +// It is a distinct type so callers can decide severity (hard fail vs warn). +type ErrMissingVarsError struct { + Names []string +} + +func (e *ErrMissingVarsError) Error() string { + return "environment variables not set: " + strings.Join(e.Names, ", ") +} + func ToValues(envMap map[string]string) []string { var values []string for k, v := range envMap { diff --git a/pkg/js/eval.go b/pkg/js/eval.go new file mode 100644 index 000000000..51ea6d200 --- /dev/null +++ b/pkg/js/eval.go @@ -0,0 +1,94 @@ +package js + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "slices" + "strings" + + "github.com/docker/docker-agent/pkg/tools" +) + +// Evaluator handles JavaScript expression evaluation in strings. +type Evaluator struct { + tools []tools.Tool +} + +// NewEvaluator creates a new Evaluator with the given tools. +func NewEvaluator(agentTools []tools.Tool) *Evaluator { + return &Evaluator{ + tools: agentTools, + } +} + +// Evaluate finds and evaluates ${...} JavaScript expressions in the input string. +// args are available as the 'args' array in JavaScript. +func (e *Evaluator) Evaluate(ctx context.Context, input string, args []string) string { + if !strings.Contains(input, "${") { + return input + } + + vm := newVM() + if args == nil { + args = []string{} + } + _ = vm.Set("args", args) + + // Bind tools to VM + for _, tool := range e.tools { + _ = vm.Set(tool.Name, e.createToolCaller(ctx, tool)) + } + + slog.Debug("Evaluating JS template", "input", input) + + return runExpansion(vm, input) +} + +// createToolCaller creates a JavaScript function that calls the given tool. +func (e *Evaluator) createToolCaller(ctx context.Context, tool tools.Tool) func(args map[string]any) (string, error) { + return func(args map[string]any) (string, error) { + var toolArgs struct { + Required []string `json:"required"` + } + + if err := tools.ConvertSchema(tool.Parameters, &toolArgs); err != nil { + return "", err + } + + // Filter out nil values for non-required arguments + nonNilArgs := make(map[string]any) + for k, v := range args { + if slices.Contains(toolArgs.Required, k) || v != nil { + nonNilArgs[k] = v + } + } + + arguments, err := json.Marshal(nonNilArgs) + if err != nil { + return "", err + } + + toolCall := tools.ToolCall{ + ID: "jseval_" + tool.Name, + Type: "function", + Function: tools.FunctionCall{ + Name: tool.Name, + Arguments: string(arguments), + }, + } + + if tool.Handler == nil { + return "", fmt.Errorf("tool '%s' has no handler", tool.Name) + } + + // Use the parent context directly, relying on its cancellation/timeout + result, err := tool.Handler(ctx, toolCall) + if err != nil { + return "", err + } + + return result.Output, nil + } +} diff --git a/pkg/js/expand.go b/pkg/js/expand.go index cda94aa1e..b77ace98d 100644 --- a/pkg/js/expand.go +++ b/pkg/js/expand.go @@ -2,97 +2,45 @@ package js import ( "context" - "encoding/json" - "fmt" - "log/slog" - "slices" "strings" "github.com/dop251/goja" "github.com/docker/docker-agent/pkg/config/types" "github.com/docker/docker-agent/pkg/environment" - "github.com/docker/docker-agent/pkg/tools" ) // newVM creates a new Goja JavaScript runtime. var newVM = goja.New -// Expander expands JavaScript template literals in strings. -// It can be configured with an environment provider for ${env.X} access -// and/or agent tools for ${tool({...})} calls. +// Expander evaluates JavaScript expressions. type Expander struct { - env environment.Provider - tools []tools.Tool + env environment.Provider } -// NewJsExpander creates a new Expander with the given environment provider. +// NewJsExpander creates a new Expander. func NewJsExpander(env environment.Provider) *Expander { return &Expander{env: env} } -// NewEvaluator creates a new Expander with the given tools (for command evaluation). -func NewEvaluator(agentTools []tools.Tool) *Expander { - return &Expander{tools: agentTools} -} - -// dynamicLookup implements goja.DynamicObject for lazy key-value access. -type dynamicLookup struct { - vm *goja.Runtime - lookup func(string) string -} - -func (d *dynamicLookup) Get(k string) goja.Value { return d.vm.ToValue(d.lookup(k)) } -func (*dynamicLookup) Set(string, goja.Value) bool { return false } -func (*dynamicLookup) Has(string) bool { return true } -func (*dynamicLookup) Delete(string) bool { return true } -func (*dynamicLookup) Keys() []string { return nil } - -// newVMWithBindings creates a new JS runtime with env and tools pre-bound. -func (exp *Expander) newVMWithBindings(ctx context.Context) *goja.Runtime { +func (exp *Expander) newVM(ctx context.Context) *goja.Runtime { vm := newVM() - if exp.env != nil { _ = vm.Set("env", vm.NewDynamicObject(&dynamicLookup{ vm: vm, lookup: func(k string) string { v, _ := exp.env.Get(ctx, k); return v }, })) } - - for _, tool := range exp.tools { - _ = vm.Set(tool.Name, createToolCaller(ctx, tool)) - } - return vm } -// Evaluate finds and evaluates ${...} JavaScript expressions in the input string. -// args are available as the 'args' array in JavaScript. -func (exp *Expander) Evaluate(ctx context.Context, input string, args []string) string { - if !strings.Contains(input, "${") { - return input - } - - vm := exp.newVMWithBindings(ctx) - if args == nil { - args = []string{} - } - _ = vm.Set("args", args) - - slog.DebugContext(ctx, "Evaluating JS template", "input", input) - - return runExpansion(vm, input) -} - -// Expand expands JavaScript template literals using the provided values map. -// The values are bound as top-level variables in the JS runtime alongside -// env and tools bindings. +// Expand evaluates JavaScript template literals. func (exp *Expander) Expand(ctx context.Context, text string, values map[string]string) string { if !strings.Contains(text, "${") { return text } - vm := exp.newVMWithBindings(ctx) + vm := exp.newVM(ctx) for k, v := range values { _ = vm.Set(k, v) } @@ -100,13 +48,13 @@ func (exp *Expander) Expand(ctx context.Context, text string, values map[string] return runExpansion(vm, text) } -// ExpandMap expands JavaScript template literals in all values of the given map. +// ExpandMap evaluates JavaScript template literals in all values of the given map. func (exp *Expander) ExpandMap(ctx context.Context, kv map[string]string) map[string]string { if kv == nil { return nil } - vm := exp.newVMWithBindings(ctx) + vm := exp.newVM(ctx) expanded := make(map[string]string, len(kv)) for k, v := range kv { @@ -115,13 +63,13 @@ func (exp *Expander) ExpandMap(ctx context.Context, kv map[string]string) map[st return expanded } -// ExpandCommands expands JavaScript template literals in all command fields. +// ExpandCommands evaluates JavaScript template literals in all command fields. func (exp *Expander) ExpandCommands(ctx context.Context, cmds types.Commands) types.Commands { if cmds == nil { return nil } - vm := exp.newVMWithBindings(ctx) + vm := exp.newVM(ctx) expanded := make(types.Commands, len(cmds)) for k, cmd := range cmds { @@ -133,16 +81,15 @@ func (exp *Expander) ExpandCommands(ctx context.Context, cmds types.Commands) ty return expanded } -// ExpandMapFunc expands JavaScript template literals in map values. -// It binds a dynamic object with the given name to the JS runtime, -// using lookup to resolve property accesses. Each value is optionally -// preprocessed with preprocess before expansion (pass nil to skip). +// ExpandMapFunc evaluates JavaScript template literals in map values. func ExpandMapFunc(values map[string]string, objName string, lookup, preprocess func(string) string) map[string]string { vm := newVM() - _ = vm.Set(objName, vm.NewDynamicObject(&dynamicLookup{ - vm: vm, - lookup: lookup, - })) + if lookup != nil { + _ = vm.Set(objName, vm.NewDynamicObject(&dynamicLookup{ + vm: vm, + lookup: lookup, + })) + } resolved := make(map[string]string, len(values)) for k, v := range values { @@ -154,56 +101,19 @@ func ExpandMapFunc(values map[string]string, objName string, lookup, preprocess return resolved } -// createToolCaller creates a JavaScript function that calls the given tool. -func createToolCaller(ctx context.Context, tool tools.Tool) func(args map[string]any) (string, error) { - return func(args map[string]any) (string, error) { - var toolArgs struct { - Required []string `json:"required"` - } - - if err := tools.ConvertSchema(tool.Parameters, &toolArgs); err != nil { - return "", err - } - - // Filter out nil values for non-required arguments - nonNilArgs := make(map[string]any) - for k, v := range args { - if slices.Contains(toolArgs.Required, k) || v != nil { - nonNilArgs[k] = v - } - } - - arguments, err := json.Marshal(nonNilArgs) - if err != nil { - return "", err - } - - toolCall := tools.ToolCall{ - ID: "jseval_" + tool.Name, - Type: "function", - Function: tools.FunctionCall{ - Name: tool.Name, - Arguments: string(arguments), - }, - } - - if tool.Handler == nil { - return "", fmt.Errorf("tool '%s' has no handler", tool.Name) - } - - result, err := tool.Handler(ctx, toolCall) - if err != nil { - return "", err - } - - return result.Output, nil - } +// dynamicLookup implements goja.DynamicObject for lazy key-value access. +type dynamicLookup struct { + vm *goja.Runtime + lookup func(string) string } +func (d *dynamicLookup) Get(k string) goja.Value { return d.vm.ToValue(d.lookup(k)) } +func (*dynamicLookup) Set(string, goja.Value) bool { return false } +func (*dynamicLookup) Has(string) bool { return true } +func (*dynamicLookup) Delete(string) bool { return true } +func (*dynamicLookup) Keys() []string { return nil } + // runExpansion executes the template string using the provided Goja runtime. -// If the full template literal evaluation fails (e.g. because one expression -// references an undefined variable), it falls back to evaluating each ${...} -// expression independently so that successful expressions are still expanded. func runExpansion(vm *goja.Runtime, text string) string { // Escape backslashes first, then backticks escaped := strings.ReplaceAll(text, "\\", "\\\\") diff --git a/pkg/js/expand_test.go b/pkg/js/expand_test.go index 85b7ebbaf..1f018f488 100644 --- a/pkg/js/expand_test.go +++ b/pkg/js/expand_test.go @@ -1,13 +1,9 @@ package js import ( - "context" "testing" "github.com/stretchr/testify/assert" - - "github.com/docker/docker-agent/pkg/config/types" - "github.com/docker/docker-agent/pkg/tools" ) func TestExpand(t *testing.T) { @@ -16,121 +12,29 @@ func TestExpand(t *testing.T) { tests := []struct { name string commands string - envVars map[string]string + values map[string]string expected string }{ { name: "no placeholder", commands: "List all files", - envVars: map[string]string{}, + values: map[string]string{}, expected: "List all files", }, { - name: "single placeholder", - commands: "Say hello to ${env.USER}", - envVars: map[string]string{"USER": "alice"}, + name: "simple substitution", + commands: "Say hello to ${USER}", + values: map[string]string{"USER": "alice"}, expected: "Say hello to alice", }, - { - name: "multiple placeholders", - commands: "Analyze ${env.PROJECT_NAME} in ${env.ENVIRONMENT}", - envVars: map[string]string{"PROJECT_NAME": "myproject", "ENVIRONMENT": "production"}, - expected: "Analyze myproject in production", - }, - { - name: "default value", - commands: "Say hello to ${env.USER || 'Bob'}", - envVars: map[string]string{}, - expected: "Say hello to Bob", - }, - { - name: "missing env var expands to empty string", - commands: "Check ${env.MISSING_VAR} status", - envVars: map[string]string{}, - expected: "Check status", - }, - { - name: "ternary operator", - commands: "${env.NAME == 'bob' ? 'Yes' : 'No'}", - envVars: map[string]string{"NAME": "bob"}, - expected: "Yes", - }, - { - name: "default value (found)", - commands: "${env.NAME || 'UNKNOWN'}", - envVars: map[string]string{"NAME": "bob"}, - expected: "bob", - }, - { - name: "default value (not found)", - commands: "${env.NAME || 'UNKNOWN'}", - envVars: map[string]string{}, - expected: "UNKNOWN", - }, - { - name: "backticks in template (markdown code fence)", - commands: "Here is code:\n```\n${env.CODE}\n```\nEnd.", - envVars: map[string]string{"CODE": "fmt.Println()"}, - expected: "Here is code:\n```\nfmt.Println()\n```\nEnd.", - }, - { - name: "multiple backticks", - commands: "Use `inline` and ```block``` code", - envVars: map[string]string{}, - expected: "Use `inline` and ```block``` code", - }, - { - name: "single backslash", - commands: "test\\value", - envVars: map[string]string{}, - expected: "test\\value", - }, - { - name: "backslash n (not newline)", - commands: "test\\nvalue", - envVars: map[string]string{}, - expected: "test\\nvalue", - }, - { - name: "backslash t (not tab)", - commands: "test\\tvalue", - envVars: map[string]string{}, - expected: "test\\tvalue", - }, - { - name: "windows path", - commands: "C:\\Users\\Alice\\Documents", - envVars: map[string]string{}, - expected: "C:\\Users\\Alice\\Documents", - }, - { - name: "network path", - commands: "\\\\server\\share\\file", - envVars: map[string]string{}, - expected: "\\\\server\\share\\file", - }, - { - name: "multiple backslashes", - commands: "test\\\\value", - envVars: map[string]string{}, - expected: "test\\\\value", - }, - { - name: "regex pattern with backslashes", - commands: "\\d+\\.\\d+", - envVars: map[string]string{}, - expected: "\\d+\\.\\d+", - }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - env := testEnvProvider(tt.envVars) - - expander := NewJsExpander(&env) - result := expander.Expand(t.Context(), tt.commands, nil) + expander := NewJsExpander(nil) + result := expander.Expand(t.Context(), tt.commands, tt.values) assert.Equal(t, tt.expected, result) }) @@ -140,30 +44,22 @@ func TestExpand(t *testing.T) { func TestExpandMap(t *testing.T) { t.Parallel() - env := testEnvProvider(map[string]string{ - "USER": "alice", - }) - - expander := NewJsExpander(&env) + expander := NewJsExpander(nil) result := expander.ExpandMap(t.Context(), map[string]string{ "none": "List all files", - "simple": "Say hello to ${env.USER}", + "simple": "Say hello to ${USER}", }) assert.Equal(t, map[string]string{ "none": "List all files", - "simple": "Say hello to alice", + "simple": "Say hello to ${USER}", // values is nil, so no expansion }, result) } func TestExpandMap_Reuse(t *testing.T) { t.Parallel() - env := testEnvProvider(map[string]string{ - "USER": "alice", - }) - - expander := NewJsExpander(&env) + expander := NewJsExpander(nil) result := expander.ExpandMap(t.Context(), map[string]string{ "none": "List all files", @@ -173,19 +69,17 @@ func TestExpandMap_Reuse(t *testing.T) { }, result) result = expander.ExpandMap(t.Context(), map[string]string{ - "simple": "Say hello to ${env.USER}", + "simple": "Say hello to ${USER}", }) assert.Equal(t, map[string]string{ - "simple": "Say hello to alice", + "simple": "Say hello to ${USER}", // values is nil }, result) } func TestExpandMap_Empty(t *testing.T) { t.Parallel() - env := testEnvProvider(map[string]string{}) - - expander := NewJsExpander(&env) + expander := NewJsExpander(nil) result := expander.ExpandMap(t.Context(), map[string]string{}) assert.Empty(t, result) @@ -236,118 +130,9 @@ func TestExpandString(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - env := testEnvProvider(map[string]string{}) - expander := NewJsExpander(&env) + expander := NewJsExpander(nil) result := expander.Expand(t.Context(), tt.template, tt.values) assert.Equal(t, tt.expected, result) }) } } - -type testEnvProvider map[string]string - -func (p *testEnvProvider) Get(_ context.Context, name string) (string, bool) { - val, found := (*p)[name] - return val, found -} - -// TestExpandCommandsThenEvaluate verifies the two-phase flow that slash commands go through: -// 1. ExpandCommands at config load time (env-only, no tools) -// 2. Evaluate at runtime (tools available) -// This catches regressions where one phase corrupts expressions needed by the other. -func TestExpandCommandsThenEvaluate(t *testing.T) { - t.Parallel() - - env := testEnvProvider(map[string]string{"USER": "alice"}) - - mockTools := []tools.Tool{ - { - Name: "shell", - Handler: func(_ context.Context, tc tools.ToolCall) (*tools.ToolCallResult, error) { - return tools.ResultSuccess("lint output"), nil - }, - }, - } - - cmds := types.Commands{ - "fix-lint": { - Description: "Fix lint", - Instruction: "User: ${env.USER}\nLint: ${shell({cmd: \"task lint\"})}\n${unknown_mcp_tool()}", - }, - } - - // Phase 1: ExpandCommands with env only (no tools) - expander := NewJsExpander(&env) - expanded := expander.ExpandCommands(t.Context(), cmds) - - // env.USER should be expanded, tool calls should be preserved - assert.Contains(t, expanded["fix-lint"].Instruction, "User: alice") - assert.Contains(t, expanded["fix-lint"].Instruction, "${shell({cmd: \"task lint\"})}") // preserved - assert.Contains(t, expanded["fix-lint"].Instruction, "${unknown_mcp_tool()}") // preserved - - // Phase 2: Evaluate with tools (no env) - evaluator := NewEvaluator(mockTools) - result := evaluator.Evaluate(t.Context(), expanded["fix-lint"].Instruction, nil) - - // shell should now be expanded, unknown tool should be preserved - assert.Contains(t, result, "User: alice") - assert.Contains(t, result, "lint output") - assert.Contains(t, result, "${unknown_mcp_tool()}") - assert.NotContains(t, result, "${shell") -} - -func TestFindClosingBrace(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - text string - pos int - expected int - }{ - { - name: "simple", - text: "${foo}", - pos: 2, - expected: 5, - }, - { - name: "nested braces", - text: "${shell({cmd: \"ls\"})}", - pos: 2, - expected: len("${shell({cmd: \"ls\"})}") - 1, - }, - { - name: "closing brace inside quotes", - text: `${shell({cmd: "echo }"})}`, - pos: 2, - expected: len(`${shell({cmd: "echo }"})}`) - 1, - }, - { - name: "escaped quote inside string", - text: `${shell({cmd: "echo \"}"})}`, - pos: 2, - expected: len(`${shell({cmd: "echo \"}"})}`) - 1, - }, - { - name: "unclosed", - text: "${foo", - pos: 2, - expected: -1, - }, - { - name: "unclosed nested", - text: "${foo({bar: 1}", - pos: 2, - expected: -1, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - result := findClosingBrace(tt.text, tt.pos) - assert.Equal(t, tt.expected, result) - }) - } -} diff --git a/pkg/server/source_loader.go b/pkg/server/source_loader.go index b78a5c4e3..8489018e8 100644 --- a/pkg/server/source_loader.go +++ b/pkg/server/source_loader.go @@ -7,6 +7,7 @@ import ( "time" "github.com/docker/docker-agent/pkg/config" + "github.com/docker/docker-agent/pkg/environment" ) type sourceLoader struct { @@ -46,6 +47,10 @@ func (sl *sourceLoader) ParentDir() string { return sl.inner.ParentDir() } +func (sl *sourceLoader) EnvProvider() environment.Provider { + return sl.inner.EnvProvider() +} + func (sl *sourceLoader) Read(_ context.Context) ([]byte, error) { sl.mu.RLock() defer sl.mu.RUnlock() diff --git a/pkg/server/source_loader_test.go b/pkg/server/source_loader_test.go index 33e76cc59..b0cccd05d 100644 --- a/pkg/server/source_loader_test.go +++ b/pkg/server/source_loader_test.go @@ -10,6 +10,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/environment" ) type mockSource struct { @@ -39,6 +41,10 @@ func (m *mockSource) Read(context.Context) ([]byte, error) { return m.data, nil } +func (m *mockSource) EnvProvider() environment.Provider { + return nil +} + func (m *mockSource) setData(data []byte) { m.mu.Lock() defer m.mu.Unlock() diff --git a/pkg/teamloader/teamloader_test.go b/pkg/teamloader/teamloader_test.go index 760fc18d0..4f1e6eb83 100644 --- a/pkg/teamloader/teamloader_test.go +++ b/pkg/teamloader/teamloader_test.go @@ -103,7 +103,7 @@ func TestLoadExamples(t *testing.T) { // Use a bytes source (ParentDir == "") plus a temp WorkingDir so // toolsets that write to disk (memory, RAG, cache, ...) land in // the temp dir instead of the examples/ tree. - agentSource := config.NewBytesSource(agentFilename, data) + agentSource := config.NewBytesSource(agentFilename, data, nil) runConfig := &config.RuntimeConfig{} runConfig.WorkingDir = t.TempDir() @@ -229,7 +229,7 @@ func TestInstructionExpansion(t *testing.T) { t.Setenv("OPENAI_API_KEY", "dummy") t.Setenv("USER", "alice") - agentSource, err := config.Resolve("testdata/instruction-expansion.yaml", nil) + agentSource, err := config.Resolve("testdata/instruction-expansion.yaml", environment.NewDefaultProvider()) require.NoError(t, err) team, err := Load(t.Context(), agentSource, &config.RuntimeConfig{}) @@ -538,7 +538,7 @@ agents: ` require.NoError(t, os.WriteFile(configPath, []byte(configContent), 0o600)) - source := config.NewFileSource(configPath) + source := config.NewFileSource(configPath, nil) runConfig := &config.RuntimeConfig{} runConfig.WorkingDir = tmpDir