Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions cmd/docker-mcp/commands/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,16 +75,20 @@ func gatewayCommand(docker docker.Client, dockerCli command.Cli, features featur
len(mcpRegistryUrls) > 0 || len(options.OciRef) > 0 ||
(options.SecretsPath != "docker-desktop" && !strings.HasPrefix(options.SecretsPath, "docker-desktop:")) {
// We're in legacy mode, so we can't use the working set feature
if options.WorkingSet != "" {
return fmt.Errorf("cannot use --profile with --servers, --enable-all-servers, --catalog, --additional-catalog, --registry, --additional-registry, --config, --additional-config, --tools-config, --additional-tools-config, --secrets, --oci-ref, --mcp-registry flags")
if options.WorkingSet != "" || options.WorkingSetFile != "" {
return fmt.Errorf("cannot use --profile or --profile-file with --servers, --enable-all-servers, --catalog, --additional-catalog, --registry, --additional-registry, --config, --additional-config, --tools-config, --additional-tools-config, --secrets, --oci-ref, --mcp-registry flags")
}
// Make sure to default the options in legacy mode
setLegacyDefaults(&options)
} else if options.WorkingSet == "" {
} else if options.WorkingSet == "" && options.WorkingSetFile == "" {
// ELSE we're in working set mode,
// so IF no profile specified, use the default profile
options.WorkingSet = "default"
}

if options.WorkingSet != "" && options.WorkingSetFile != "" {
return fmt.Errorf("cannot use both --profile and --profile-file at the same time")
}
}

// Check if OAuth interceptor feature is enabled
Expand Down Expand Up @@ -183,7 +187,8 @@ func gatewayCommand(docker docker.Client, dockerCli command.Cli, features featur

runCmd.Flags().StringSliceVar(&options.ServerNames, "servers", nil, "Names of the servers to enable (if non empty, ignore --registry flag)")
if features.IsProfilesFeatureEnabled() {
runCmd.Flags().StringVar(&options.WorkingSet, "profile", "", "Profile ID to use (mutually exclusive with --servers and --enable-all-servers)")
runCmd.Flags().StringVar(&options.WorkingSet, "profile", "", "Profile ID to use (mutually exclusive with --profile-file, --servers and --enable-all-servers)")
runCmd.Flags().StringVar(&options.WorkingSetFile, "profile-file", "", "Profile file to use (mutually exclusive with --profile, --servers and --enable-all-servers)")
}
runCmd.Flags().BoolVar(&enableAllServers, "enable-all-servers", false, "Enable all servers in the catalog (instead of using individual --servers options)")
runCmd.Flags().StringSliceVar(&options.CatalogPath, "catalog", options.CatalogPath, "Paths to docker catalogs (absolute or relative to ~/.docker/mcp/catalogs/)")
Expand Down
1 change: 1 addition & 0 deletions pkg/gateway/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import "github.com/docker/mcp-gateway/pkg/catalog"
type Config struct {
Options
WorkingSet string
WorkingSetFile string
ServerNames []string
CatalogPath []string
ConfigPath []string
Expand Down
57 changes: 34 additions & 23 deletions pkg/gateway/configuration_workingset.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package gateway

import (
"context"
"database/sql"
"errors"
"fmt"
"time"
Expand All @@ -18,58 +17,70 @@ import (
)

type WorkingSetConfiguration struct {
WorkingSet string
ociService oci.Service
docker docker.Client
workingSetID string
workingSetFile string
ociService oci.Service
docker docker.Client
}

func NewWorkingSetConfiguration(workingSet string, ociService oci.Service, docker docker.Client) *WorkingSetConfiguration {
func NewWorkingSetConfiguration(workingSetID string, workingSetFile string, ociService oci.Service, docker docker.Client) *WorkingSetConfiguration {
return &WorkingSetConfiguration{
WorkingSet: workingSet,
ociService: ociService,
docker: docker,
workingSetID: workingSetID,
workingSetFile: workingSetFile,
ociService: ociService,
docker: docker,
}
}

func (c *WorkingSetConfiguration) Read(ctx context.Context) (Configuration, chan Configuration, func() error, error) {
dao, err := db.New()
if err != nil {
return Configuration{}, nil, nil, fmt.Errorf("failed to create database client: %w", err)
}
var loader WorkingSetLoader
cleanup := func() error { return nil }

if c.workingSetID != "" {
dao, err := db.New()
if err != nil {
return Configuration{}, nil, nil, fmt.Errorf("failed to create database client: %w", err)
}
cleanup = func() error { return dao.Close() }

// Do migration from legacy files
migrate.MigrateConfig(ctx, c.docker, dao)
// Do migration from legacy files
migrate.MigrateConfig(ctx, c.docker, dao)

configuration, err := c.readOnce(ctx, dao)
loader = NewWorkingSetDatabaseLoader(c.workingSetID, dao)
} else if c.workingSetFile != "" {
loader = NewWorkingSetFileLoader(c.workingSetFile, c.ociService)
} else {
return Configuration{}, nil, nil, fmt.Errorf("no working set ID or file provided")
}

configuration, err := c.readOnce(ctx, loader)
if err != nil {
return Configuration{}, nil, nil, err
}

// TODO(cody): Stub for now
updates := make(chan Configuration)

return configuration, updates, func() error { return nil }, nil
return configuration, updates, cleanup, nil
}

func (c *WorkingSetConfiguration) readOnce(ctx context.Context, dao db.DAO) (Configuration, error) {
func (c *WorkingSetConfiguration) readOnce(ctx context.Context, loader WorkingSetLoader) (Configuration, error) {
start := time.Now()
log.Log("- Reading profile configuration...")

dbWorkingSet, err := dao.GetWorkingSet(ctx, c.WorkingSet)
workingSet, err := loader.ReadWorkingSet(ctx)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
if errors.Is(err, ErrWorkingSetNotFound) {
// Special case for the default profile, we're okay with it not existing
if c.WorkingSet == "default" {
if c.workingSetID == "default" {
log.Log(" - Default profile not found, using empty configuration")
return c.emptyConfiguration()
}
return Configuration{}, fmt.Errorf("profile %s not found", c.WorkingSet)
return Configuration{}, err
}
return Configuration{}, fmt.Errorf("failed to get profile: %w", err)
}

workingSet := workingset.NewFromDb(dbWorkingSet)

if err := workingSet.EnsureSnapshotsResolved(ctx, c.ociService); err != nil {
return Configuration{}, fmt.Errorf("failed to resolve snapshots: %w", err)
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/gateway/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ type Gateway struct {

func NewGateway(config Config, docker docker.Client) *Gateway {
var configurator Configurator
if config.WorkingSet != "" {
configurator = NewWorkingSetConfiguration(config.WorkingSet, oci.NewService(), docker)
if config.WorkingSet != "" || config.WorkingSetFile != "" {
configurator = NewWorkingSetConfiguration(config.WorkingSet, config.WorkingSetFile, oci.NewService(), docker)
} else {
configurator = &FileBasedConfiguration{
ServerNames: config.ServerNames,
Expand Down
64 changes: 64 additions & 0 deletions pkg/gateway/workingset_loader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package gateway

import (
"context"
"database/sql"
"errors"
"fmt"
"os"

"github.com/docker/mcp-gateway/pkg/db"
"github.com/docker/mcp-gateway/pkg/oci"
"github.com/docker/mcp-gateway/pkg/workingset"
)

var ErrWorkingSetNotFound = errors.New("profile not found")

type WorkingSetLoader interface {
ReadWorkingSet(ctx context.Context) (workingset.WorkingSet, error)
}

type databaseLoader struct {
workingSet string
dao db.DAO
}

func NewWorkingSetDatabaseLoader(workingSet string, dao db.DAO) WorkingSetLoader {
return &databaseLoader{workingSet: workingSet, dao: dao}
}

func (l *databaseLoader) ReadWorkingSet(ctx context.Context) (workingset.WorkingSet, error) {
workingSet, err := l.dao.GetWorkingSet(ctx, l.workingSet)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return workingset.WorkingSet{}, fmt.Errorf("%w: %s", ErrWorkingSetNotFound, l.workingSet)
}
return workingset.WorkingSet{}, err
}
return workingset.NewFromDb(workingSet), nil
}

type fileLoader struct {
workingSetFile string
ociService oci.Service
}

func NewWorkingSetFileLoader(workingSetFile string, ociService oci.Service) WorkingSetLoader {
return &fileLoader{workingSetFile: workingSetFile, ociService: ociService}
}

func (l *fileLoader) ReadWorkingSet(ctx context.Context) (workingset.WorkingSet, error) {
workingSet, err := workingset.ReadFromFile(ctx, l.ociService, l.workingSetFile)
if err != nil {
if os.IsNotExist(err) {
return workingset.WorkingSet{}, fmt.Errorf("%w: %s", ErrWorkingSetNotFound, l.workingSetFile)
}
return workingset.WorkingSet{}, err
}

if err := workingSet.Validate(); err != nil {
return workingset.WorkingSet{}, fmt.Errorf("invalid profile: %w", err)
}

return workingSet, nil
}
61 changes: 35 additions & 26 deletions pkg/workingset/import.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,33 +16,9 @@ import (
)

func Import(ctx context.Context, dao db.DAO, ociService oci.Service, filename string) error {
workingSetBuf, err := os.ReadFile(filename)
workingSet, err := ReadFromFile(ctx, ociService, filename)
if err != nil {
return fmt.Errorf("failed to read profile file: %w", err)
}

var workingSet WorkingSet
if strings.HasSuffix(strings.ToLower(filename), ".yaml") {
if err := yaml.Unmarshal(workingSetBuf, &workingSet); err != nil {
return fmt.Errorf("failed to unmarshal profile: %w", err)
}
} else if strings.HasSuffix(strings.ToLower(filename), ".json") {
if err := json.Unmarshal(workingSetBuf, &workingSet); err != nil {
return fmt.Errorf("failed to unmarshal profile: %w", err)
}
} else {
return fmt.Errorf("unsupported file extension: %s, must be .yaml or .json", filename)
}

// Resolve snapshots for each server before saving
for i := range len(workingSet.Servers) {
if workingSet.Servers[i].Snapshot == nil {
snapshot, err := ResolveSnapshot(ctx, ociService, workingSet.Servers[i])
if err != nil {
return fmt.Errorf("failed to resolve snapshot for server[%d]: %w", i, err)
}
workingSet.Servers[i].Snapshot = snapshot
}
return err
}

if err := workingSet.Validate(); err != nil {
Expand Down Expand Up @@ -72,3 +48,36 @@ func Import(ctx context.Context, dao db.DAO, ociService oci.Service, filename st

return nil
}

func ReadFromFile(ctx context.Context, ociService oci.Service, filename string) (WorkingSet, error) {
workingSetBuf, err := os.ReadFile(filename)
if err != nil {
return WorkingSet{}, fmt.Errorf("failed to read profile file: %w", err)
}

var workingSet WorkingSet
if strings.HasSuffix(strings.ToLower(filename), ".yaml") {
if err := yaml.Unmarshal(workingSetBuf, &workingSet); err != nil {
return WorkingSet{}, fmt.Errorf("failed to unmarshal profile: %w", err)
}
} else if strings.HasSuffix(strings.ToLower(filename), ".json") {
if err := json.Unmarshal(workingSetBuf, &workingSet); err != nil {
return WorkingSet{}, fmt.Errorf("failed to unmarshal profile: %w", err)
}
} else {
return WorkingSet{}, fmt.Errorf("unsupported file extension: %s, must be .yaml or .json", filename)
}

// Resolve snapshots for each server before saving
for i := range len(workingSet.Servers) {
if workingSet.Servers[i].Snapshot == nil {
snapshot, err := ResolveSnapshot(ctx, ociService, workingSet.Servers[i])
if err != nil {
return WorkingSet{}, fmt.Errorf("failed to resolve snapshot for server[%d]: %w", i, err)
}
workingSet.Servers[i].Snapshot = snapshot
}
}

return workingSet, nil
}
12 changes: 10 additions & 2 deletions pkg/workingset/workingset.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ type Server struct {
Type ServerType `yaml:"type" json:"type" validate:"required,oneof=registry image remote"`
Config map[string]any `yaml:"config,omitempty" json:"config,omitempty"`
Secrets string `yaml:"secrets,omitempty" json:"secrets,omitempty"`
Tools []string `yaml:"tools" json:"tools"`
Tools ToolList `yaml:"tools,omitempty" json:"tools"` // See IsZero() below

// ServerTypeRegistry only
Source string `yaml:"source,omitempty" json:"source,omitempty" validate:"required_if=Type registry"`
Expand Down Expand Up @@ -77,6 +77,14 @@ type ServerSnapshot struct {
Server catalog.Server `yaml:"server" json:"server"`
}

type ToolList []string

// Needed for proper YAML encoding with omitempty. YAML defaults IsZero to true when a slice is empty, but we only want it on nil.
// This IsZero() + omitempty matches json behavior without omitempty.
func (tools ToolList) IsZero() bool {
return tools == nil
}

func NewFromDb(dbSet *db.WorkingSet) WorkingSet {
servers := make([]Server, len(dbSet.Servers))
for i, server := range dbSet.Servers {
Expand Down Expand Up @@ -496,7 +504,7 @@ func mapCatalogServersToWorkingSetServers(dbServers []db.CatalogServer, secrets
for i, server := range dbServers {
servers[i] = Server{
Type: ServerType(server.ServerType),
Tools: server.Tools,
Tools: ToolList(server.Tools),
Config: map[string]any{},
Source: server.Source,
Image: server.Image,
Expand Down
4 changes: 2 additions & 2 deletions pkg/workingset/workingset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func TestNewFromDb(t *testing.T) {
assert.Equal(t, ServerTypeRegistry, workingSet.Servers[0].Type)
assert.Equal(t, "https://example.com/server", workingSet.Servers[0].Source)
assert.Equal(t, map[string]any{"key": "value"}, workingSet.Servers[0].Config)
assert.Equal(t, []string{"tool1", "tool2"}, workingSet.Servers[0].Tools)
assert.Equal(t, ToolList([]string{"tool1", "tool2"}), workingSet.Servers[0].Tools)

// Check image server
assert.Equal(t, ServerTypeImage, workingSet.Servers[1].Type)
Expand Down Expand Up @@ -170,7 +170,7 @@ func TestNewFromDbWithRemoteServer(t *testing.T) {
// Check remote server
assert.Equal(t, ServerTypeRemote, workingSet.Servers[0].Type)
assert.Equal(t, "https://mcp.example.com/sse", workingSet.Servers[0].Endpoint)
assert.Equal(t, []string{"tool1", "tool2"}, workingSet.Servers[0].Tools)
assert.Equal(t, ToolList([]string{"tool1", "tool2"}), workingSet.Servers[0].Tools)
}

func TestWorkingSetToDbWithRemoteServer(t *testing.T) {
Expand Down
Loading