diff --git a/internal/ghmcp/server_test.go b/internal/ghmcp/server_test.go index 04c0989d4..26c780c3e 100644 --- a/internal/ghmcp/server_test.go +++ b/internal/ghmcp/server_test.go @@ -1,8 +1,12 @@ package ghmcp import ( + "context" + "net/http" + "net/http/httptest" "testing" + "github.com/github/github-mcp-server/pkg/github" "github.com/github/github-mcp-server/pkg/translations" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -110,3 +114,127 @@ func TestResolveEnabledToolsets(t *testing.T) { }) } } + +// TestBearerAuthTransport_AddsGraphQLFeaturesHeader verifies that the bearerAuthTransport +// properly reads GraphQL features from context and adds them as a header. +func TestBearerAuthTransport_AddsGraphQLFeaturesHeader(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + features []string + expectHeader bool + expectedHeaderValue string + }{ + { + name: "single feature", + features: []string{"issues_copilot_assignment_api_support"}, + expectHeader: true, + expectedHeaderValue: "issues_copilot_assignment_api_support", + }, + { + name: "multiple features", + features: []string{"feature1", "feature2", "feature3"}, + expectHeader: true, + expectedHeaderValue: "feature1, feature2, feature3", + }, + { + name: "no features", + features: []string{}, + expectHeader: false, + }, + { + name: "nil features", + features: nil, + expectHeader: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Create a test server that records the request + var capturedRequest *http.Request + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedRequest = r + w.WriteHeader(http.StatusOK) + })) + defer testServer.Close() + + // Create the transport chain + transport := &bearerAuthTransport{ + transport: http.DefaultTransport, + token: "test-token", + } + + // Create an HTTP client with the transport + client := &http.Client{Transport: transport} + + // Create a context with GraphQL features + ctx := context.Background() + if tc.features != nil { + ctx = github.WithGraphQLFeatures(ctx, tc.features...) + } + + // Make a request with the context + req, err := http.NewRequestWithContext(ctx, "POST", testServer.URL, nil) + require.NoError(t, err) + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Verify the Authorization header is set + assert.Equal(t, "Bearer test-token", capturedRequest.Header.Get("Authorization")) + + // Verify the GraphQL-Features header + if tc.expectHeader { + assert.Equal(t, tc.expectedHeaderValue, capturedRequest.Header.Get("GraphQL-Features")) + } else { + assert.Empty(t, capturedRequest.Header.Get("GraphQL-Features")) + } + }) + } +} + +// TestUserAgentTransport_PreservesGraphQLFeatures verifies that the userAgentTransport +// doesn't interfere with GraphQL features set by bearerAuthTransport. +func TestUserAgentTransport_PreservesGraphQLFeatures(t *testing.T) { + t.Parallel() + + // Create a test server that records the request + var capturedRequest *http.Request + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedRequest = r + w.WriteHeader(http.StatusOK) + })) + defer testServer.Close() + + // Create the transport chain (same as in production) + // userAgentTransport -> bearerAuthTransport -> http.DefaultTransport + transport := &userAgentTransport{ + transport: &bearerAuthTransport{ + transport: http.DefaultTransport, + token: "test-token", + }, + agent: "test-agent/1.0.0", + } + + // Create an HTTP client with the transport chain + client := &http.Client{Transport: transport} + + // Create a context with GraphQL features + ctx := github.WithGraphQLFeatures(context.Background(), "issues_copilot_assignment_api_support") + + // Make a request with the context + req, err := http.NewRequestWithContext(ctx, "POST", testServer.URL, nil) + require.NoError(t, err) + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Verify all headers are set correctly + assert.Equal(t, "test-agent/1.0.0", capturedRequest.Header.Get("User-Agent")) + assert.Equal(t, "Bearer test-token", capturedRequest.Header.Get("Authorization")) + assert.Equal(t, "issues_copilot_assignment_api_support", capturedRequest.Header.Get("GraphQL-Features")) +} diff --git a/pkg/github/issues.go b/pkg/github/issues.go index 63174c9e9..d6af14abb 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -1789,7 +1789,7 @@ func AssignCopilotToIssue(t translations.TranslationHelperFunc) inventory.Server // Add the GraphQL-Features header for the agent assignment API // The header will be read by the HTTP transport if it's configured to do so - ctxWithFeatures := withGraphQLFeatures(ctx, "issues_copilot_assignment_api_support") + ctxWithFeatures := WithGraphQLFeatures(ctx, "issues_copilot_assignment_api_support") if err := client.Mutate( ctxWithFeatures, @@ -1917,8 +1917,8 @@ func AssignCodingAgentPrompt(t translations.TranslationHelperFunc) inventory.Ser // graphQLFeaturesKey is a context key for GraphQL feature flags type graphQLFeaturesKey struct{} -// withGraphQLFeatures adds GraphQL feature flags to the context -func withGraphQLFeatures(ctx context.Context, features ...string) context.Context { +// WithGraphQLFeatures adds GraphQL feature flags to the context +func WithGraphQLFeatures(ctx context.Context, features ...string) context.Context { return context.WithValue(ctx, graphQLFeaturesKey{}, features) }