diff --git a/cmd/docker-mcp/commands/gateway.go b/cmd/docker-mcp/commands/gateway.go index 1c1c6fdd..284d2262 100644 --- a/cmd/docker-mcp/commands/gateway.go +++ b/cmd/docker-mcp/commands/gateway.go @@ -75,16 +75,20 @@ func gatewayCommand(docker docker.Client, dockerCli command.Cli, features featur len(mcpRegistryUrls) > 0 || len(options.OciRef) > 0 || (options.SecretsPath != "docker-desktop" && !strings.HasPrefix(options.SecretsPath, "docker-desktop:")) { // We're in legacy mode, so we can't use the working set feature - if options.WorkingSet != "" { - return fmt.Errorf("cannot use --profile with --servers, --enable-all-servers, --catalog, --additional-catalog, --registry, --additional-registry, --config, --additional-config, --tools-config, --additional-tools-config, --secrets, --oci-ref, --mcp-registry flags") + if options.WorkingSet != "" || options.WorkingSetFile != "" { + return fmt.Errorf("cannot use --profile or --profile-file with --servers, --enable-all-servers, --catalog, --additional-catalog, --registry, --additional-registry, --config, --additional-config, --tools-config, --additional-tools-config, --secrets, --oci-ref, --mcp-registry flags") } // Make sure to default the options in legacy mode setLegacyDefaults(&options) - } else if options.WorkingSet == "" { + } else if options.WorkingSet == "" && options.WorkingSetFile == "" { // ELSE we're in working set mode, // so IF no profile specified, use the default profile options.WorkingSet = "default" } + + if options.WorkingSet != "" && options.WorkingSetFile != "" { + return fmt.Errorf("cannot use both --profile and --profile-file at the same time") + } } // Check if OAuth interceptor feature is enabled @@ -183,7 +187,8 @@ func gatewayCommand(docker docker.Client, dockerCli command.Cli, features featur runCmd.Flags().StringSliceVar(&options.ServerNames, "servers", nil, "Names of the servers to enable (if non empty, ignore --registry flag)") if features.IsProfilesFeatureEnabled() { - runCmd.Flags().StringVar(&options.WorkingSet, "profile", "", "Profile ID to use (mutually exclusive with --servers and --enable-all-servers)") + runCmd.Flags().StringVar(&options.WorkingSet, "profile", "", "Profile ID to use (mutually exclusive with --profile-file, --servers and --enable-all-servers)") + runCmd.Flags().StringVar(&options.WorkingSetFile, "profile-file", "", "Profile file to use (mutually exclusive with --profile, --servers and --enable-all-servers)") } runCmd.Flags().BoolVar(&enableAllServers, "enable-all-servers", false, "Enable all servers in the catalog (instead of using individual --servers options)") runCmd.Flags().StringSliceVar(&options.CatalogPath, "catalog", options.CatalogPath, "Paths to docker catalogs (absolute or relative to ~/.docker/mcp/catalogs/)") diff --git a/pkg/gateway/config.go b/pkg/gateway/config.go index 3c60d557..50358349 100644 --- a/pkg/gateway/config.go +++ b/pkg/gateway/config.go @@ -5,6 +5,7 @@ import "github.com/docker/mcp-gateway/pkg/catalog" type Config struct { Options WorkingSet string + WorkingSetFile string ServerNames []string CatalogPath []string ConfigPath []string diff --git a/pkg/gateway/configuration_workingset.go b/pkg/gateway/configuration_workingset.go index 59cc27b9..fd396b5c 100644 --- a/pkg/gateway/configuration_workingset.go +++ b/pkg/gateway/configuration_workingset.go @@ -2,7 +2,6 @@ package gateway import ( "context" - "database/sql" "errors" "fmt" "time" @@ -18,29 +17,43 @@ import ( ) type WorkingSetConfiguration struct { - WorkingSet string - ociService oci.Service - docker docker.Client + workingSetID string + workingSetFile string + ociService oci.Service + docker docker.Client } -func NewWorkingSetConfiguration(workingSet string, ociService oci.Service, docker docker.Client) *WorkingSetConfiguration { +func NewWorkingSetConfiguration(workingSetID string, workingSetFile string, ociService oci.Service, docker docker.Client) *WorkingSetConfiguration { return &WorkingSetConfiguration{ - WorkingSet: workingSet, - ociService: ociService, - docker: docker, + workingSetID: workingSetID, + workingSetFile: workingSetFile, + ociService: ociService, + docker: docker, } } func (c *WorkingSetConfiguration) Read(ctx context.Context) (Configuration, chan Configuration, func() error, error) { - dao, err := db.New() - if err != nil { - return Configuration{}, nil, nil, fmt.Errorf("failed to create database client: %w", err) - } + var loader WorkingSetLoader + cleanup := func() error { return nil } + + if c.workingSetID != "" { + dao, err := db.New() + if err != nil { + return Configuration{}, nil, nil, fmt.Errorf("failed to create database client: %w", err) + } + cleanup = func() error { return dao.Close() } - // Do migration from legacy files - migrate.MigrateConfig(ctx, c.docker, dao) + // Do migration from legacy files + migrate.MigrateConfig(ctx, c.docker, dao) - configuration, err := c.readOnce(ctx, dao) + loader = NewWorkingSetDatabaseLoader(c.workingSetID, dao) + } else if c.workingSetFile != "" { + loader = NewWorkingSetFileLoader(c.workingSetFile, c.ociService) + } else { + return Configuration{}, nil, nil, fmt.Errorf("no working set ID or file provided") + } + + configuration, err := c.readOnce(ctx, loader) if err != nil { return Configuration{}, nil, nil, err } @@ -48,28 +61,26 @@ func (c *WorkingSetConfiguration) Read(ctx context.Context) (Configuration, chan // TODO(cody): Stub for now updates := make(chan Configuration) - return configuration, updates, func() error { return nil }, nil + return configuration, updates, cleanup, nil } -func (c *WorkingSetConfiguration) readOnce(ctx context.Context, dao db.DAO) (Configuration, error) { +func (c *WorkingSetConfiguration) readOnce(ctx context.Context, loader WorkingSetLoader) (Configuration, error) { start := time.Now() log.Log("- Reading profile configuration...") - dbWorkingSet, err := dao.GetWorkingSet(ctx, c.WorkingSet) + workingSet, err := loader.ReadWorkingSet(ctx) if err != nil { - if errors.Is(err, sql.ErrNoRows) { + if errors.Is(err, ErrWorkingSetNotFound) { // Special case for the default profile, we're okay with it not existing - if c.WorkingSet == "default" { + if c.workingSetID == "default" { log.Log(" - Default profile not found, using empty configuration") return c.emptyConfiguration() } - return Configuration{}, fmt.Errorf("profile %s not found", c.WorkingSet) + return Configuration{}, err } return Configuration{}, fmt.Errorf("failed to get profile: %w", err) } - workingSet := workingset.NewFromDb(dbWorkingSet) - if err := workingSet.EnsureSnapshotsResolved(ctx, c.ociService); err != nil { return Configuration{}, fmt.Errorf("failed to resolve snapshots: %w", err) } diff --git a/pkg/gateway/run.go b/pkg/gateway/run.go index ebe3db5e..fe69d8cc 100644 --- a/pkg/gateway/run.go +++ b/pkg/gateway/run.go @@ -84,8 +84,8 @@ type Gateway struct { func NewGateway(config Config, docker docker.Client) *Gateway { var configurator Configurator - if config.WorkingSet != "" { - configurator = NewWorkingSetConfiguration(config.WorkingSet, oci.NewService(), docker) + if config.WorkingSet != "" || config.WorkingSetFile != "" { + configurator = NewWorkingSetConfiguration(config.WorkingSet, config.WorkingSetFile, oci.NewService(), docker) } else { configurator = &FileBasedConfiguration{ ServerNames: config.ServerNames, diff --git a/pkg/gateway/workingset_loader.go b/pkg/gateway/workingset_loader.go new file mode 100644 index 00000000..dc9ca1d1 --- /dev/null +++ b/pkg/gateway/workingset_loader.go @@ -0,0 +1,64 @@ +package gateway + +import ( + "context" + "database/sql" + "errors" + "fmt" + "os" + + "github.com/docker/mcp-gateway/pkg/db" + "github.com/docker/mcp-gateway/pkg/oci" + "github.com/docker/mcp-gateway/pkg/workingset" +) + +var ErrWorkingSetNotFound = errors.New("profile not found") + +type WorkingSetLoader interface { + ReadWorkingSet(ctx context.Context) (workingset.WorkingSet, error) +} + +type databaseLoader struct { + workingSet string + dao db.DAO +} + +func NewWorkingSetDatabaseLoader(workingSet string, dao db.DAO) WorkingSetLoader { + return &databaseLoader{workingSet: workingSet, dao: dao} +} + +func (l *databaseLoader) ReadWorkingSet(ctx context.Context) (workingset.WorkingSet, error) { + workingSet, err := l.dao.GetWorkingSet(ctx, l.workingSet) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return workingset.WorkingSet{}, fmt.Errorf("%w: %s", ErrWorkingSetNotFound, l.workingSet) + } + return workingset.WorkingSet{}, err + } + return workingset.NewFromDb(workingSet), nil +} + +type fileLoader struct { + workingSetFile string + ociService oci.Service +} + +func NewWorkingSetFileLoader(workingSetFile string, ociService oci.Service) WorkingSetLoader { + return &fileLoader{workingSetFile: workingSetFile, ociService: ociService} +} + +func (l *fileLoader) ReadWorkingSet(ctx context.Context) (workingset.WorkingSet, error) { + workingSet, err := workingset.ReadFromFile(ctx, l.ociService, l.workingSetFile) + if err != nil { + if os.IsNotExist(err) { + return workingset.WorkingSet{}, fmt.Errorf("%w: %s", ErrWorkingSetNotFound, l.workingSetFile) + } + return workingset.WorkingSet{}, err + } + + if err := workingSet.Validate(); err != nil { + return workingset.WorkingSet{}, fmt.Errorf("invalid profile: %w", err) + } + + return workingSet, nil +} diff --git a/pkg/workingset/import.go b/pkg/workingset/import.go index 261c53c1..e8dff282 100644 --- a/pkg/workingset/import.go +++ b/pkg/workingset/import.go @@ -16,33 +16,9 @@ import ( ) func Import(ctx context.Context, dao db.DAO, ociService oci.Service, filename string) error { - workingSetBuf, err := os.ReadFile(filename) + workingSet, err := ReadFromFile(ctx, ociService, filename) if err != nil { - return fmt.Errorf("failed to read profile file: %w", err) - } - - var workingSet WorkingSet - if strings.HasSuffix(strings.ToLower(filename), ".yaml") { - if err := yaml.Unmarshal(workingSetBuf, &workingSet); err != nil { - return fmt.Errorf("failed to unmarshal profile: %w", err) - } - } else if strings.HasSuffix(strings.ToLower(filename), ".json") { - if err := json.Unmarshal(workingSetBuf, &workingSet); err != nil { - return fmt.Errorf("failed to unmarshal profile: %w", err) - } - } else { - return fmt.Errorf("unsupported file extension: %s, must be .yaml or .json", filename) - } - - // Resolve snapshots for each server before saving - for i := range len(workingSet.Servers) { - if workingSet.Servers[i].Snapshot == nil { - snapshot, err := ResolveSnapshot(ctx, ociService, workingSet.Servers[i]) - if err != nil { - return fmt.Errorf("failed to resolve snapshot for server[%d]: %w", i, err) - } - workingSet.Servers[i].Snapshot = snapshot - } + return err } if err := workingSet.Validate(); err != nil { @@ -72,3 +48,36 @@ func Import(ctx context.Context, dao db.DAO, ociService oci.Service, filename st return nil } + +func ReadFromFile(ctx context.Context, ociService oci.Service, filename string) (WorkingSet, error) { + workingSetBuf, err := os.ReadFile(filename) + if err != nil { + return WorkingSet{}, fmt.Errorf("failed to read profile file: %w", err) + } + + var workingSet WorkingSet + if strings.HasSuffix(strings.ToLower(filename), ".yaml") { + if err := yaml.Unmarshal(workingSetBuf, &workingSet); err != nil { + return WorkingSet{}, fmt.Errorf("failed to unmarshal profile: %w", err) + } + } else if strings.HasSuffix(strings.ToLower(filename), ".json") { + if err := json.Unmarshal(workingSetBuf, &workingSet); err != nil { + return WorkingSet{}, fmt.Errorf("failed to unmarshal profile: %w", err) + } + } else { + return WorkingSet{}, fmt.Errorf("unsupported file extension: %s, must be .yaml or .json", filename) + } + + // Resolve snapshots for each server before saving + for i := range len(workingSet.Servers) { + if workingSet.Servers[i].Snapshot == nil { + snapshot, err := ResolveSnapshot(ctx, ociService, workingSet.Servers[i]) + if err != nil { + return WorkingSet{}, fmt.Errorf("failed to resolve snapshot for server[%d]: %w", i, err) + } + workingSet.Servers[i].Snapshot = snapshot + } + } + + return workingSet, nil +} diff --git a/pkg/workingset/workingset.go b/pkg/workingset/workingset.go index 036a97e9..cbad301d 100644 --- a/pkg/workingset/workingset.go +++ b/pkg/workingset/workingset.go @@ -47,7 +47,7 @@ type Server struct { Type ServerType `yaml:"type" json:"type" validate:"required,oneof=registry image remote"` Config map[string]any `yaml:"config,omitempty" json:"config,omitempty"` Secrets string `yaml:"secrets,omitempty" json:"secrets,omitempty"` - Tools []string `yaml:"tools" json:"tools"` + Tools ToolList `yaml:"tools,omitempty" json:"tools"` // See IsZero() below // ServerTypeRegistry only Source string `yaml:"source,omitempty" json:"source,omitempty" validate:"required_if=Type registry"` @@ -77,6 +77,14 @@ type ServerSnapshot struct { Server catalog.Server `yaml:"server" json:"server"` } +type ToolList []string + +// Needed for proper YAML encoding with omitempty. YAML defaults IsZero to true when a slice is empty, but we only want it on nil. +// This IsZero() + omitempty matches json behavior without omitempty. +func (tools ToolList) IsZero() bool { + return tools == nil +} + func NewFromDb(dbSet *db.WorkingSet) WorkingSet { servers := make([]Server, len(dbSet.Servers)) for i, server := range dbSet.Servers { @@ -496,7 +504,7 @@ func mapCatalogServersToWorkingSetServers(dbServers []db.CatalogServer, secrets for i, server := range dbServers { servers[i] = Server{ Type: ServerType(server.ServerType), - Tools: server.Tools, + Tools: ToolList(server.Tools), Config: map[string]any{}, Source: server.Source, Image: server.Image, diff --git a/pkg/workingset/workingset_test.go b/pkg/workingset/workingset_test.go index 071dc270..a95c78e6 100644 --- a/pkg/workingset/workingset_test.go +++ b/pkg/workingset/workingset_test.go @@ -62,7 +62,7 @@ func TestNewFromDb(t *testing.T) { assert.Equal(t, ServerTypeRegistry, workingSet.Servers[0].Type) assert.Equal(t, "https://example.com/server", workingSet.Servers[0].Source) assert.Equal(t, map[string]any{"key": "value"}, workingSet.Servers[0].Config) - assert.Equal(t, []string{"tool1", "tool2"}, workingSet.Servers[0].Tools) + assert.Equal(t, ToolList([]string{"tool1", "tool2"}), workingSet.Servers[0].Tools) // Check image server assert.Equal(t, ServerTypeImage, workingSet.Servers[1].Type) @@ -170,7 +170,7 @@ func TestNewFromDbWithRemoteServer(t *testing.T) { // Check remote server assert.Equal(t, ServerTypeRemote, workingSet.Servers[0].Type) assert.Equal(t, "https://mcp.example.com/sse", workingSet.Servers[0].Endpoint) - assert.Equal(t, []string{"tool1", "tool2"}, workingSet.Servers[0].Tools) + assert.Equal(t, ToolList([]string{"tool1", "tool2"}), workingSet.Servers[0].Tools) } func TestWorkingSetToDbWithRemoteServer(t *testing.T) {