diff --git a/auth/authorization_code.go b/auth/authorization_code.go index 9b41e368..9cafb8df 100644 --- a/auth/authorization_code.go +++ b/auth/authorization_code.go @@ -14,6 +14,7 @@ import ( "net/url" "slices" "strings" + "sync" "github.com/modelcontextprotocol/go-sdk/internal/util" "github.com/modelcontextprotocol/go-sdk/oauthex" @@ -128,6 +129,9 @@ type AuthorizationCodeHandler struct { // tokenSource is the token source to use for authorization. tokenSource oauth2.TokenSource + + mu sync.Mutex + requestedScopes []string } var _ OAuthHandler = (*AuthorizationCodeHandler)(nil) @@ -288,6 +292,14 @@ func (h *AuthorizationCodeHandler) Authorize(ctx context.Context, req *http.Requ scps = append(scps, "offline_access") } + // Accumulate scopes: union previously requested scopes with the newly + // challenged scopes so that step-up authorization does not lose + // permissions granted in earlier rounds (SEP-2350). + h.mu.Lock() + scps = unionScopes(h.requestedScopes, scps) + h.requestedScopes = scps + h.mu.Unlock() + cfg := &oauth2.Config{ ClientID: resolvedClientConfig.clientID, ClientSecret: resolvedClientConfig.clientSecret, @@ -343,6 +355,25 @@ func errorFromChallenges(cs []oauthex.Challenge) string { return "" } +// unionScopes returns the union of existing and challenged scopes, +// preserving order (existing first, then new challenged scopes). +func unionScopes(existing, challenged []string) []string { + if len(existing) == 0 { + return challenged + } + if len(challenged) == 0 { + return existing + } + result := make([]string, len(existing), len(existing)+len(challenged)) + copy(result, existing) + for _, s := range challenged { + if !slices.Contains(result, s) { + result = append(result, s) + } + } + return result +} + // getProtectedResourceMetadata returns the protected resource metadata. // If no metadata was found or the fetched metadata fails security checks, // it returns an error. diff --git a/auth/authorization_code_test.go b/auth/authorization_code_test.go index d30cbf36..419a8389 100644 --- a/auth/authorization_code_test.go +++ b/auth/authorization_code_test.go @@ -113,6 +113,95 @@ func TestAuthorize(t *testing.T) { } } +func TestAuthorize_ScopeAccumulation(t *testing.T) { + authServer := oauthtest.NewFakeAuthorizationServer(oauthtest.Config{ + RegistrationConfig: &oauthtest.RegistrationConfig{ + PreregisteredClients: map[string]oauthtest.ClientInfo{ + "test_client_id": { + Secret: "test_client_secret", + RedirectURIs: []string{"http://localhost:12345/callback"}, + }, + }, + }, + }) + authServer.Start(t) + + resourceMux := http.NewServeMux() + resourceServer := httptest.NewServer(resourceMux) + t.Cleanup(resourceServer.Close) + resourceURL := resourceServer.URL + "/resource" + + resourceMux.Handle("/.well-known/oauth-protected-resource/resource", ProtectedResourceMetadataHandler(&oauthex.ProtectedResourceMetadata{ + Resource: resourceURL, + AuthorizationServers: []string{authServer.URL()}, + })) + + var capturedAuthURLs []string + + handler, err := NewAuthorizationCodeHandler(&AuthorizationCodeHandlerConfig{ + RedirectURL: "http://localhost:12345/callback", + PreregisteredClient: &oauthex.ClientCredentials{ + ClientID: "test_client_id", + ClientSecretAuth: &oauthex.ClientSecretAuth{ + ClientSecret: "test_client_secret", + }, + }, + AuthorizationCodeFetcher: func(ctx context.Context, args *AuthorizationArgs) (*AuthorizationResult, error) { + capturedAuthURLs = append(capturedAuthURLs, args.URL) + return nil, fmt.Errorf("stop after capturing URL") + }, + }) + if err != nil { + t.Fatalf("NewAuthorizationCodeHandler failed: %v", err) + } + + // First authorization: 401 with scope="read" + req := httptest.NewRequest(http.MethodGet, resourceURL, nil) + resp := &http.Response{ + StatusCode: http.StatusUnauthorized, + Header: make(http.Header), + Body: http.NoBody, + } + resp.Header.Set("WWW-Authenticate", + fmt.Sprintf(`Bearer scope="read", resource_metadata="%s/.well-known/oauth-protected-resource/resource"`, resourceServer.URL)) + err = handler.Authorize(context.Background(), req, resp) + if err == nil || !strings.Contains(err.Error(), "stop after capturing URL") { + t.Fatalf("First Authorize expected error containing 'stop after capturing URL', got: %v", err) + } + + // Verify first auth URL requested only "read". + firstURL, err := url.Parse(capturedAuthURLs[0]) + if err != nil { + t.Fatalf("Failed to parse first auth URL: %v", err) + } + if got := firstURL.Query().Get("scope"); got != "read" { + t.Errorf("First auth scope = %q, want %q", got, "read") + } + + // Second authorization: 403 insufficient_scope with scope="write" + req2 := httptest.NewRequest(http.MethodGet, resourceURL, nil) + resp2 := &http.Response{ + StatusCode: http.StatusForbidden, + Header: make(http.Header), + Body: http.NoBody, + } + resp2.Header.Set("WWW-Authenticate", + fmt.Sprintf(`Bearer error="insufficient_scope", scope="write", resource_metadata="%s/.well-known/oauth-protected-resource/resource"`, resourceServer.URL)) + err = handler.Authorize(context.Background(), req2, resp2) + if err == nil || !strings.Contains(err.Error(), "stop after capturing URL") { + t.Fatalf("Second Authorize expected error containing 'stop after capturing URL', got: %v", err) + } + + // Verify second auth URL accumulated both scopes. + secondURL, err := url.Parse(capturedAuthURLs[1]) + if err != nil { + t.Fatalf("Failed to parse second auth URL: %v", err) + } + if got := secondURL.Query().Get("scope"); got != "read write" { + t.Errorf("Second auth scope = %q, want %q", got, "read write") + } +} + func TestAuthorize_ForbiddenUnhandledError(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "http://example.com/resource", nil) resp := &http.Response{ @@ -740,6 +829,67 @@ func TestApplicationTypeInference(t *testing.T) { } } +func TestUnionScopes(t *testing.T) { + tests := []struct { + name string + existing []string + challenged []string + want []string + }{ + { + name: "both empty", + existing: nil, + challenged: nil, + want: nil, + }, + { + name: "existing only", + existing: []string{"read"}, + challenged: nil, + want: []string{"read"}, + }, + { + name: "challenged only", + existing: nil, + challenged: []string{"write"}, + want: []string{"write"}, + }, + { + name: "disjoint scopes", + existing: []string{"read"}, + challenged: []string{"write"}, + want: []string{"read", "write"}, + }, + { + name: "overlapping scopes", + existing: []string{"read", "write"}, + challenged: []string{"write", "admin"}, + want: []string{"read", "write", "admin"}, + }, + { + name: "identical scopes", + existing: []string{"read", "write"}, + challenged: []string{"read", "write"}, + want: []string{"read", "write"}, + }, + { + name: "preserves order", + existing: []string{"b", "a"}, + challenged: []string{"c", "a"}, + want: []string{"b", "a", "c"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := unionScopes(tt.existing, tt.challenged) + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Errorf("unionScopes() mismatch (-want +got):\n%s", diff) + } + }) + } +} + func TestAuthorize_OfflineAccessScope(t *testing.T) { tests := []struct { name string diff --git a/auth/extauth/client_credentials.go b/auth/extauth/client_credentials.go index fefa29b9..34bcd627 100644 --- a/auth/extauth/client_credentials.go +++ b/auth/extauth/client_credentials.go @@ -12,6 +12,7 @@ import ( "net/url" "slices" "strings" + "sync" "github.com/modelcontextprotocol/go-sdk/auth" "github.com/modelcontextprotocol/go-sdk/oauthex" @@ -47,6 +48,9 @@ type ClientCredentialsHandlerConfig struct { type ClientCredentialsHandler struct { config *ClientCredentialsHandlerConfig tokenSource oauth2.TokenSource + + mu sync.Mutex + requestedScopes []string } // Compile-time check that ClientCredentialsHandler implements auth.OAuthHandler. @@ -127,6 +131,14 @@ func (h *ClientCredentialsHandler) Authorize(ctx context.Context, req *http.Requ scopes = prm.ScopesSupported } + // Accumulate scopes: union previously requested scopes with the newly + // challenged scopes so that step-up authorization does not lose + // permissions granted in earlier rounds (SEP-2350). + h.mu.Lock() + scopes = unionScopes(h.requestedScopes, scopes) + h.requestedScopes = scopes + h.mu.Unlock() + // Step 3: Exchange client credentials for an access token. creds := h.config.Credentials cfg := &clientcredentials.Config{ @@ -229,6 +241,25 @@ func scopesFromChallenges(cs []oauthex.Challenge) []string { return nil } +// unionScopes returns the union of existing and challenged scopes, +// preserving order (existing first, then new challenged scopes). +func unionScopes(existing, challenged []string) []string { + if len(existing) == 0 { + return challenged + } + if len(challenged) == 0 { + return existing + } + result := make([]string, len(existing), len(existing)+len(challenged)) + copy(result, existing) + for _, s := range challenged { + if !slices.Contains(result, s) { + result = append(result, s) + } + } + return result +} + // selectTokenAuthMethod selects the preferred token endpoint auth method based on // the authorization server's supported methods. Prefers client_secret_post over // client_secret_basic per the OAuth 2.1 draft. diff --git a/auth/extauth/client_credentials_test.go b/auth/extauth/client_credentials_test.go index 91ae7c04..fe46c48e 100644 --- a/auth/extauth/client_credentials_test.go +++ b/auth/extauth/client_credentials_test.go @@ -10,6 +10,7 @@ import ( "strings" "testing" + "github.com/google/go-cmp/cmp" "github.com/modelcontextprotocol/go-sdk/auth" "github.com/modelcontextprotocol/go-sdk/internal/oauthtest" "github.com/modelcontextprotocol/go-sdk/oauthex" @@ -255,6 +256,135 @@ func TestClientCredentialsHandler_Authorize(t *testing.T) { }) } +func TestClientCredentialsHandler_ScopeAccumulation(t *testing.T) { + authServer := oauthtest.NewFakeAuthorizationServer(oauthtest.Config{ + MetadataEndpointConfig: &oauthtest.MetadataEndpointConfig{ + ServeOAuthInsertedEndpoint: true, + }, + RegistrationConfig: &oauthtest.RegistrationConfig{ + PreregisteredClients: map[string]oauthtest.ClientInfo{ + "test-client": {Secret: "test-secret"}, + }, + }, + ClientCredentialsConfig: &oauthtest.ClientCredentialsConfig{ + Enabled: true, + }, + }) + authServer.Start(t) + + resourceMux := http.NewServeMux() + resourceServer := httptest.NewServer(resourceMux) + t.Cleanup(resourceServer.Close) + resourceURL := resourceServer.URL + "/resource" + + resourceMux.Handle("/.well-known/oauth-protected-resource/resource", auth.ProtectedResourceMetadataHandler(&oauthex.ProtectedResourceMetadata{ + Resource: resourceURL, + AuthorizationServers: []string{authServer.URL()}, + })) + + handler, err := NewClientCredentialsHandler(validClientCredentialsConfig()) + if err != nil { + t.Fatal(err) + } + + // First authorization: 401 with scope="read" + resp := &http.Response{ + StatusCode: http.StatusUnauthorized, + Header: make(http.Header), + Body: http.NoBody, + } + resp.Header.Set("WWW-Authenticate", `Bearer scope="read"`) + req := httptest.NewRequest("GET", resourceURL, nil) + if err := handler.Authorize(t.Context(), req, resp); err != nil { + t.Fatalf("First Authorize failed: %v", err) + } + + // Verify handler tracked the requested scopes. + handler.mu.Lock() + firstScopes := append([]string{}, handler.requestedScopes...) + handler.mu.Unlock() + wantFirst := []string{"read"} + if diff := cmp.Diff(wantFirst, firstScopes); diff != "" { + t.Errorf("After first Authorize, requestedScopes mismatch (-want +got):\n%s", diff) + } + + // Second authorization: 401 with scope="write" (simulating step-up) + resp2 := &http.Response{ + StatusCode: http.StatusUnauthorized, + Header: make(http.Header), + Body: http.NoBody, + } + resp2.Header.Set("WWW-Authenticate", `Bearer scope="write"`) + req2 := httptest.NewRequest("GET", resourceURL, nil) + if err := handler.Authorize(t.Context(), req2, resp2); err != nil { + t.Fatalf("Second Authorize failed: %v", err) + } + + // Verify handler accumulated both scopes. + handler.mu.Lock() + secondScopes := append([]string{}, handler.requestedScopes...) + handler.mu.Unlock() + wantSecond := []string{"read", "write"} + if diff := cmp.Diff(wantSecond, secondScopes); diff != "" { + t.Errorf("After second Authorize, requestedScopes mismatch (-want +got):\n%s", diff) + } +} + +func TestUnionScopes(t *testing.T) { + tests := []struct { + name string + existing []string + challenged []string + want []string + }{ + { + name: "both empty", + existing: nil, + challenged: nil, + want: nil, + }, + { + name: "existing only", + existing: []string{"read"}, + challenged: nil, + want: []string{"read"}, + }, + { + name: "challenged only", + existing: nil, + challenged: []string{"write"}, + want: []string{"write"}, + }, + { + name: "disjoint scopes", + existing: []string{"read"}, + challenged: []string{"write"}, + want: []string{"read", "write"}, + }, + { + name: "overlapping scopes", + existing: []string{"read", "write"}, + challenged: []string{"write", "admin"}, + want: []string{"read", "write", "admin"}, + }, + { + name: "identical scopes", + existing: []string{"read", "write"}, + challenged: []string{"read", "write"}, + want: []string{"read", "write"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := unionScopes(tt.existing, tt.challenged) + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Errorf("unionScopes() mismatch (-want +got):\n%s", diff) + } + }) + } +} + func TestSelectTokenAuthMethod(t *testing.T) { tests := []struct { name string