Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions pkg/vmcp/cli/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ func Serve(ctx context.Context, cfg ServeConfig) error {
return fmt.Errorf("failed to validate optimizer config: %w", err)
}

sessionFactory, err := createSessionFactory(outgoingRegistry, agg)
sessionFactory, err := createSessionFactory(&env.OSReader{}, outgoingRegistry, agg)
if err != nil {
return err
}
Expand Down Expand Up @@ -448,6 +448,7 @@ func discoverBackends(
// - If running in Kubernetes without secret: returns error (production safety requirement).
// - Otherwise: logs warning and creates factory with default insecure secret.
func createSessionFactory(
envReader env.Reader,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: Consider passing data instead of injecting env.Reader. createSessionFactory only needs two pieces of information from the environment: the HMAC secret string and whether we're running in Kubernetes. Passing those as plain values makes the function's contract explicit in its signature and eliminates mock boilerplate in tests:

func createSessionFactory(
    hmacSecret string,
    isKubernetes bool,
    outgoingRegistry vmcpauth.OutgoingAuthRegistry,
    agg aggregator.Aggregator,
) (vmcpsession.MultiSessionFactory, error) {

The caller (Serve) already has the env reader — it can resolve Getenv("VMCP_SESSION_HMAC_SECRET") and IsKubernetesRuntimeWithEnv(envReader) before calling this function.

This avoids tests needing to know which env vars the function checks internally (e.g., TOOLHIVE_RUNTIME, KUBERNETES_SERVICE_HOST), which makes them fragile if the runtime detection logic ever adds another variable. Tests become trivial: createSessionFactory("my-secret", false, registry, agg).

outgoingRegistry vmcpauth.OutgoingAuthRegistry,
agg aggregator.Aggregator,
) (vmcpsession.MultiSessionFactory, error) {
Expand All @@ -461,7 +462,7 @@ func createSessionFactory(
opts = append(opts, vmcpsession.WithAggregator(agg))
}

hmacSecret := os.Getenv(envKey)
hmacSecret := envReader.Getenv(envKey)

if hmacSecret != "" {
if secretLen := len(hmacSecret); secretLen < minRecommendedSecretLen {
Expand All @@ -478,7 +479,7 @@ func createSessionFactory(
}

// No secret provided — fail fast in Kubernetes (production environment).
if runtime.IsKubernetesRuntime() {
if runtime.IsKubernetesRuntimeWithEnv(envReader) {
return nil, fmt.Errorf(
"VMCP_SESSION_HMAC_SECRET environment variable is required when running in Kubernetes. " +
"Generate a secure secret with: openssl rand -base64 32",
Expand Down
210 changes: 210 additions & 0 deletions pkg/vmcp/cli/serve_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
// SPDX-License-Identifier: Apache-2.0

package cli

import (
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"gopkg.in/yaml.v3"

envmocks "github.com/stacklok/toolhive-core/env/mocks"
authserverconfig "github.com/stacklok/toolhive/pkg/authserver"
aggregatormocks "github.com/stacklok/toolhive/pkg/vmcp/aggregator/mocks"
clientmocks "github.com/stacklok/toolhive/pkg/vmcp/client/mocks"
)

// TestLoadAndValidateConfig covers all config-loading paths.
func TestLoadAndValidateConfig(t *testing.T) {
t.Parallel()

tests := []struct {
name string
content string
wantErr bool
errContains string
}{
{
name: "valid config",
content: validConfigYAML,
wantErr: false,
},
{
name: "non-existent file",
content: "", // file will not be created
wantErr: true,
errContains: "configuration loading failed",
},
{
name: "malformed YAML",
content: ":::invalid yaml:::",
wantErr: true,
errContains: "configuration loading failed",
},
{
name: "fails semantic validation — missing groupRef",
content: `
name: test-vmcp
incomingAuth:
type: anonymous
outgoingAuth:
source: inline
aggregation:
conflictResolution: prefix
`,
wantErr: true,
errContains: "validation failed",
Comment on lines +59 to +61
Copy link

Copilot AI Apr 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The "fails semantic validation — missing groupRef" case isn’t isolating the missing groupRef failure: this YAML also omits aggregation.conflictResolutionConfig, which is required and will cause validation to fail even if groupRef were present. To make the test actually cover the intended rule (and be more robust), include the required aggregation config in this fixture and assert the wrapped error contains the underlying validation message (e.g. "group reference is required").

Suggested change
`,
wantErr: true,
errContains: "validation failed",
conflictResolutionConfig:
fieldNamePrefix:
separator: "-"
`,
wantErr: true,
errContains: "group reference is required",

Copilot uses AI. Check for mistakes.
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
dir := t.TempDir()
path := filepath.Join(dir, "vmcp.yaml")
if tc.content != "" {
require.NoError(t, os.WriteFile(path, []byte(tc.content), 0o600))
}

cfg, err := loadAndValidateConfig(path)
if tc.wantErr {
require.Error(t, err)
require.ErrorContains(t, err, tc.errContains)
require.Nil(t, cfg)
} else {
require.NoError(t, err)
require.NotNil(t, cfg)
assert.Equal(t, "test-group", cfg.Group)
}
})
}
}

// TestLoadAuthServerConfig covers all auth-server-config side-loading paths.
// (Additional cases live in auth_server_config_test.go, moved from cmd/vmcp/app.)
func TestLoadAuthServerConfig_NestedDir(t *testing.T) {
t.Parallel()

// Config lives in a subdirectory; sibling authserver-config.yaml must be found correctly.
dir := t.TempDir()
subdir := filepath.Join(dir, "sub", "dir")
require.NoError(t, os.MkdirAll(subdir, 0o750))
configPath := filepath.Join(subdir, "vmcp-config.yaml")

want := &authserverconfig.RunConfig{
Issuer: "https://nested.example.com",
SchemaVersion: "1",
}
data, err := yaml.Marshal(want)
require.NoError(t, err)
require.NoError(t, os.WriteFile(filepath.Join(subdir, "authserver-config.yaml"), data, 0o600))

rc, err := loadAuthServerConfig(configPath)
require.NoError(t, err)
require.NotNil(t, rc)
assert.Equal(t, "https://nested.example.com", rc.Issuer)
}

// TestDiscoverBackends_StaticMode exercises the static-backend path without
// needing a live Kubernetes API.
func TestDiscoverBackends_StaticMode(t *testing.T) {
t.Parallel()

// Build a minimal config with one static backend.
dir := t.TempDir()
path := filepath.Join(dir, "vmcp.yaml")
require.NoError(t, os.WriteFile(path, []byte(`
name: test-vmcp
groupRef: test-group

incomingAuth:
type: anonymous

outgoingAuth:
source: inline
default:
type: unauthenticated

aggregation:
conflictResolution: prefix
conflictResolutionConfig:
prefixFormat: "{workload}_"

backends:
- name: backend-one
url: http://127.0.0.1:9001/sse
transport: sse
`), 0o600))

cfg, err := loadAndValidateConfig(path)
require.NoError(t, err)
require.Len(t, cfg.Backends, 1)

backends, client, registry, err := discoverBackends(t.Context(), cfg)
require.NoError(t, err)
assert.NotNil(t, client)
assert.NotNil(t, registry)
// Static mode: one backend discovered.
assert.Len(t, backends, 1)
}

func newSessionFactoryMocks(t *testing.T) (*envmocks.MockReader, *clientmocks.MockOutgoingAuthRegistry, *aggregatormocks.MockAggregator) {
t.Helper()
ctrl := gomock.NewController(t)
return envmocks.NewMockReader(ctrl), clientmocks.NewMockOutgoingAuthRegistry(ctrl), aggregatormocks.NewMockAggregator(ctrl)
}

func TestCreateSessionFactory_WithHMACSecret(t *testing.T) {
t.Parallel()
envReader, registry, agg := newSessionFactoryMocks(t)
envReader.EXPECT().Getenv("VMCP_SESSION_HMAC_SECRET").Return("a-sufficiently-long-hmac-secret-value-32b")
factory, err := createSessionFactory(envReader, registry, agg)
require.NoError(t, err)
require.NotNil(t, factory)
}

func TestCreateSessionFactory_HMACSecretExactly32Bytes(t *testing.T) {
t.Parallel()
envReader, registry, agg := newSessionFactoryMocks(t)
envReader.EXPECT().Getenv("VMCP_SESSION_HMAC_SECRET").Return("12345678901234567890123456789012")
factory, err := createSessionFactory(envReader, registry, agg)
require.NoError(t, err)
require.NotNil(t, factory)
}

func TestCreateSessionFactory_ShortHMACSecret(t *testing.T) {
t.Parallel()
envReader, registry, agg := newSessionFactoryMocks(t)
envReader.EXPECT().Getenv("VMCP_SESSION_HMAC_SECRET").Return("short")
factory, err := createSessionFactory(envReader, registry, agg)
require.NoError(t, err)
require.NotNil(t, factory)
}

func TestCreateSessionFactory_NoSecretNonKubernetes(t *testing.T) {
t.Parallel()
envReader, registry, agg := newSessionFactoryMocks(t)
envReader.EXPECT().Getenv("VMCP_SESSION_HMAC_SECRET").Return("")
envReader.EXPECT().Getenv("TOOLHIVE_RUNTIME").Return("")
envReader.EXPECT().Getenv("KUBERNETES_SERVICE_HOST").Return("")
factory, err := createSessionFactory(envReader, registry, agg)
require.NoError(t, err)
require.NotNil(t, factory)
}

func TestCreateSessionFactory_NoSecretKubernetes(t *testing.T) {
t.Parallel()
envReader, registry, agg := newSessionFactoryMocks(t)
envReader.EXPECT().Getenv("VMCP_SESSION_HMAC_SECRET").Return("")
envReader.EXPECT().Getenv("TOOLHIVE_RUNTIME").Return("")
envReader.EXPECT().Getenv("KUBERNETES_SERVICE_HOST").Return("10.0.0.1")
factory, err := createSessionFactory(envReader, registry, agg)
require.Error(t, err)
require.ErrorContains(t, err, "VMCP_SESSION_HMAC_SECRET environment variable is required")
require.Nil(t, factory)
}
115 changes: 115 additions & 0 deletions pkg/vmcp/cli/validate_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
// SPDX-License-Identifier: Apache-2.0

package cli

import (
"context"
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/require"
)

const validConfigYAML = `
name: test-vmcp
groupRef: test-group

incomingAuth:
type: anonymous

outgoingAuth:
source: inline
default:
type: unauthenticated

aggregation:
conflictResolution: prefix
conflictResolutionConfig:
prefixFormat: "{workload}_"
`

func TestValidate(t *testing.T) {
t.Parallel()

tests := []struct {
name string
setup func(t *testing.T) ValidateConfig
wantErr bool
errContains string
}{
{
name: "missing config path",
setup: func(_ *testing.T) ValidateConfig {
return ValidateConfig{}
},
wantErr: true,
errContains: "no configuration file specified",
},
{
name: "valid config file",
setup: func(t *testing.T) ValidateConfig {
t.Helper()
path := filepath.Join(t.TempDir(), "vmcp.yaml")
require.NoError(t, os.WriteFile(path, []byte(validConfigYAML), 0o600))
return ValidateConfig{ConfigPath: path}
},
wantErr: false,
},
{
name: "non-existent file",
setup: func(t *testing.T) ValidateConfig {
t.Helper()
return ValidateConfig{ConfigPath: filepath.Join(t.TempDir(), "nonexistent.yaml")}
},
wantErr: true,
errContains: "configuration loading failed",
},
{
name: "malformed YAML",
setup: func(t *testing.T) ValidateConfig {
t.Helper()
path := filepath.Join(t.TempDir(), "bad.yaml")
require.NoError(t, os.WriteFile(path, []byte(":::not valid yaml:::"), 0o600))
return ValidateConfig{ConfigPath: path}
},
wantErr: true,
errContains: "configuration loading failed",
},
{
name: "config missing required groupRef",
setup: func(t *testing.T) ValidateConfig {
t.Helper()
path := filepath.Join(t.TempDir(), "invalid.yaml")
// groupRef is required; omitting it must fail validation.
require.NoError(t, os.WriteFile(path, []byte(`
name: test-vmcp
incomingAuth:
type: anonymous
outgoingAuth:
source: inline
aggregation:
conflictResolution: prefix
`), 0o600))
return ValidateConfig{ConfigPath: path}
},
wantErr: true,
errContains: "validation failed",
},
Comment on lines +81 to +99
Copy link

Copilot AI Apr 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test case is named/commented as "missing required groupRef", but the YAML fixture also omits aggregation.conflictResolutionConfig, which is required and may be the reason validation fails. To ensure the test actually verifies the groupRef requirement, include the required aggregation config in the invalid YAML and assert the error contains the underlying validator message ("group reference is required") rather than only the generic wrapper text.

Copilot uses AI. Check for mistakes.
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
cfg := tc.setup(t)
err := Validate(context.Background(), cfg)
if tc.wantErr {
require.Error(t, err)
require.ErrorContains(t, err, tc.errContains)
} else {
require.NoError(t, err)
}
})
}
}
Loading