diff --git a/pkg/github/issues.go b/pkg/github/issues.go index 63174c9e9..d6af14abb 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -1789,7 +1789,7 @@ func AssignCopilotToIssue(t translations.TranslationHelperFunc) inventory.Server // Add the GraphQL-Features header for the agent assignment API // The header will be read by the HTTP transport if it's configured to do so - ctxWithFeatures := withGraphQLFeatures(ctx, "issues_copilot_assignment_api_support") + ctxWithFeatures := WithGraphQLFeatures(ctx, "issues_copilot_assignment_api_support") if err := client.Mutate( ctxWithFeatures, @@ -1917,8 +1917,8 @@ func AssignCodingAgentPrompt(t translations.TranslationHelperFunc) inventory.Ser // graphQLFeaturesKey is a context key for GraphQL feature flags type graphQLFeaturesKey struct{} -// withGraphQLFeatures adds GraphQL feature flags to the context -func withGraphQLFeatures(ctx context.Context, features ...string) context.Context { +// WithGraphQLFeatures adds GraphQL feature flags to the context +func WithGraphQLFeatures(ctx context.Context, features ...string) context.Context { return context.WithValue(ctx, graphQLFeaturesKey{}, features) } diff --git a/pkg/github/issues_test.go b/pkg/github/issues_test.go index 21e78874a..e8bdd4b86 100644 --- a/pkg/github/issues_test.go +++ b/pkg/github/issues_test.go @@ -3723,3 +3723,35 @@ func Test_ListIssueTypes(t *testing.T) { }) } } + +func TestWithGraphQLFeatures(t *testing.T) { + t.Parallel() + + t.Run("adds features to context", func(t *testing.T) { + ctx := context.Background() + features := []string{"feature1", "feature2"} + + ctx = WithGraphQLFeatures(ctx, features...) + retrievedFeatures := GetGraphQLFeatures(ctx) + + assert.Equal(t, features, retrievedFeatures) + }) + + t.Run("returns nil for context without features", func(t *testing.T) { + ctx := context.Background() + features := GetGraphQLFeatures(ctx) + + assert.Nil(t, features) + }) + + t.Run("can add multiple features", func(t *testing.T) { + ctx := context.Background() + ctx = WithGraphQLFeatures(ctx, "feature1", "feature2", "feature3") + features := GetGraphQLFeatures(ctx) + + assert.Len(t, features, 3) + assert.Contains(t, features, "feature1") + assert.Contains(t, features, "feature2") + assert.Contains(t, features, "feature3") + }) +}