Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions auth/authorization_code.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"net/url"
"slices"
"strings"
"sync"

"github.com/modelcontextprotocol/go-sdk/internal/util"
"github.com/modelcontextprotocol/go-sdk/oauthex"
Expand Down Expand Up @@ -128,6 +129,9 @@ type AuthorizationCodeHandler struct {

// tokenSource is the token source to use for authorization.
tokenSource oauth2.TokenSource

mu sync.Mutex
requestedScopes []string
}

var _ OAuthHandler = (*AuthorizationCodeHandler)(nil)
Expand Down Expand Up @@ -288,6 +292,14 @@ func (h *AuthorizationCodeHandler) Authorize(ctx context.Context, req *http.Requ
scps = append(scps, "offline_access")
}

// Accumulate scopes: union previously requested scopes with the newly
// challenged scopes so that step-up authorization does not lose
// permissions granted in earlier rounds (SEP-2350).
h.mu.Lock()
scps = unionScopes(h.requestedScopes, scps)
h.requestedScopes = scps
h.mu.Unlock()

cfg := &oauth2.Config{
ClientID: resolvedClientConfig.clientID,
ClientSecret: resolvedClientConfig.clientSecret,
Expand Down Expand Up @@ -343,6 +355,25 @@ func errorFromChallenges(cs []oauthex.Challenge) string {
return ""
}

// unionScopes returns the union of existing and challenged scopes,
// preserving order (existing first, then new challenged scopes).
func unionScopes(existing, challenged []string) []string {
if len(existing) == 0 {
return challenged
}
if len(challenged) == 0 {
return existing
}
result := make([]string, len(existing), len(existing)+len(challenged))
copy(result, existing)
for _, s := range challenged {
if !slices.Contains(result, s) {
result = append(result, s)
}
}
return result
}

// getProtectedResourceMetadata returns the protected resource metadata.
// If no metadata was found or the fetched metadata fails security checks,
// it returns an error.
Expand Down
150 changes: 150 additions & 0 deletions auth/authorization_code_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,95 @@ func TestAuthorize(t *testing.T) {
}
}

func TestAuthorize_ScopeAccumulation(t *testing.T) {
authServer := oauthtest.NewFakeAuthorizationServer(oauthtest.Config{
RegistrationConfig: &oauthtest.RegistrationConfig{
PreregisteredClients: map[string]oauthtest.ClientInfo{
"test_client_id": {
Secret: "test_client_secret",
RedirectURIs: []string{"http://localhost:12345/callback"},
},
},
},
})
authServer.Start(t)

resourceMux := http.NewServeMux()
resourceServer := httptest.NewServer(resourceMux)
t.Cleanup(resourceServer.Close)
resourceURL := resourceServer.URL + "/resource"

resourceMux.Handle("/.well-known/oauth-protected-resource/resource", ProtectedResourceMetadataHandler(&oauthex.ProtectedResourceMetadata{
Resource: resourceURL,
AuthorizationServers: []string{authServer.URL()},
}))

var capturedAuthURLs []string

handler, err := NewAuthorizationCodeHandler(&AuthorizationCodeHandlerConfig{
RedirectURL: "http://localhost:12345/callback",
PreregisteredClient: &oauthex.ClientCredentials{
ClientID: "test_client_id",
ClientSecretAuth: &oauthex.ClientSecretAuth{
ClientSecret: "test_client_secret",
},
},
AuthorizationCodeFetcher: func(ctx context.Context, args *AuthorizationArgs) (*AuthorizationResult, error) {
capturedAuthURLs = append(capturedAuthURLs, args.URL)
return nil, fmt.Errorf("stop after capturing URL")
},
})
if err != nil {
t.Fatalf("NewAuthorizationCodeHandler failed: %v", err)
}

// First authorization: 401 with scope="read"
req := httptest.NewRequest(http.MethodGet, resourceURL, nil)
resp := &http.Response{
StatusCode: http.StatusUnauthorized,
Header: make(http.Header),
Body: http.NoBody,
}
resp.Header.Set("WWW-Authenticate",
fmt.Sprintf(`Bearer scope="read", resource_metadata="%s/.well-known/oauth-protected-resource/resource"`, resourceServer.URL))
err = handler.Authorize(context.Background(), req, resp)
if err == nil || !strings.Contains(err.Error(), "stop after capturing URL") {
t.Fatalf("First Authorize expected error containing 'stop after capturing URL', got: %v", err)
}

// Verify first auth URL requested only "read".
firstURL, err := url.Parse(capturedAuthURLs[0])
if err != nil {
t.Fatalf("Failed to parse first auth URL: %v", err)
}
if got := firstURL.Query().Get("scope"); got != "read" {
t.Errorf("First auth scope = %q, want %q", got, "read")
}

// Second authorization: 403 insufficient_scope with scope="write"
req2 := httptest.NewRequest(http.MethodGet, resourceURL, nil)
resp2 := &http.Response{
StatusCode: http.StatusForbidden,
Header: make(http.Header),
Body: http.NoBody,
}
resp2.Header.Set("WWW-Authenticate",
fmt.Sprintf(`Bearer error="insufficient_scope", scope="write", resource_metadata="%s/.well-known/oauth-protected-resource/resource"`, resourceServer.URL))
err = handler.Authorize(context.Background(), req2, resp2)
if err == nil || !strings.Contains(err.Error(), "stop after capturing URL") {
t.Fatalf("Second Authorize expected error containing 'stop after capturing URL', got: %v", err)
}

// Verify second auth URL accumulated both scopes.
secondURL, err := url.Parse(capturedAuthURLs[1])
if err != nil {
t.Fatalf("Failed to parse second auth URL: %v", err)
}
if got := secondURL.Query().Get("scope"); got != "read write" {
t.Errorf("Second auth scope = %q, want %q", got, "read write")
}
}

func TestAuthorize_ForbiddenUnhandledError(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "http://example.com/resource", nil)
resp := &http.Response{
Expand Down Expand Up @@ -740,6 +829,67 @@ func TestApplicationTypeInference(t *testing.T) {
}
}

func TestUnionScopes(t *testing.T) {
tests := []struct {
name string
existing []string
challenged []string
want []string
}{
{
name: "both empty",
existing: nil,
challenged: nil,
want: nil,
},
{
name: "existing only",
existing: []string{"read"},
challenged: nil,
want: []string{"read"},
},
{
name: "challenged only",
existing: nil,
challenged: []string{"write"},
want: []string{"write"},
},
{
name: "disjoint scopes",
existing: []string{"read"},
challenged: []string{"write"},
want: []string{"read", "write"},
},
{
name: "overlapping scopes",
existing: []string{"read", "write"},
challenged: []string{"write", "admin"},
want: []string{"read", "write", "admin"},
},
{
name: "identical scopes",
existing: []string{"read", "write"},
challenged: []string{"read", "write"},
want: []string{"read", "write"},
},
{
name: "preserves order",
existing: []string{"b", "a"},
challenged: []string{"c", "a"},
want: []string{"b", "a", "c"},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := unionScopes(tt.existing, tt.challenged)
if diff := cmp.Diff(tt.want, got); diff != "" {
t.Errorf("unionScopes() mismatch (-want +got):\n%s", diff)
}
})
}
}

func TestAuthorize_OfflineAccessScope(t *testing.T) {
tests := []struct {
name string
Expand Down
31 changes: 31 additions & 0 deletions auth/extauth/client_credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"net/url"
"slices"
"strings"
"sync"

"github.com/modelcontextprotocol/go-sdk/auth"
"github.com/modelcontextprotocol/go-sdk/oauthex"
Expand Down Expand Up @@ -47,6 +48,9 @@ type ClientCredentialsHandlerConfig struct {
type ClientCredentialsHandler struct {
config *ClientCredentialsHandlerConfig
tokenSource oauth2.TokenSource

mu sync.Mutex
requestedScopes []string
}

// Compile-time check that ClientCredentialsHandler implements auth.OAuthHandler.
Expand Down Expand Up @@ -127,6 +131,14 @@ func (h *ClientCredentialsHandler) Authorize(ctx context.Context, req *http.Requ
scopes = prm.ScopesSupported
}

// Accumulate scopes: union previously requested scopes with the newly
// challenged scopes so that step-up authorization does not lose
// permissions granted in earlier rounds (SEP-2350).
h.mu.Lock()
scopes = unionScopes(h.requestedScopes, scopes)
h.requestedScopes = scopes
h.mu.Unlock()

// Step 3: Exchange client credentials for an access token.
creds := h.config.Credentials
cfg := &clientcredentials.Config{
Expand Down Expand Up @@ -229,6 +241,25 @@ func scopesFromChallenges(cs []oauthex.Challenge) []string {
return nil
}

// unionScopes returns the union of existing and challenged scopes,
// preserving order (existing first, then new challenged scopes).
func unionScopes(existing, challenged []string) []string {
if len(existing) == 0 {
return challenged
}
if len(challenged) == 0 {
return existing
}
result := make([]string, len(existing), len(existing)+len(challenged))
copy(result, existing)
for _, s := range challenged {
if !slices.Contains(result, s) {
result = append(result, s)
}
}
return result
}

// selectTokenAuthMethod selects the preferred token endpoint auth method based on
// the authorization server's supported methods. Prefers client_secret_post over
// client_secret_basic per the OAuth 2.1 draft.
Expand Down
Loading
Loading