diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index aadfe091..6e1fab06 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -12,7 +12,7 @@ permissions: jobs: test: - timeout-minutes: 5 + timeout-minutes: 10 strategy: matrix: os: [macos-latest, ubuntu-latest] @@ -61,23 +61,19 @@ jobs: with: swift-version: 6.1.0 - - name: Setup Node.js - uses: actions/setup-node@v4 - with: - node-version: '20' - - name: Build Swift executables run: | swift build --product mcp-everything-client swift build --product mcp-everything-server - name: Run client conformance tests - uses: modelcontextprotocol/conformance@v0.1.11 + uses: modelcontextprotocol/conformance@v0.1.15 with: mode: client command: '.build/debug/mcp-everything-client' suite: 'core' expected-failures: './conformance-baseline.yml' + node-version: '22' - name: Start server for testing run: | @@ -86,12 +82,13 @@ jobs: sleep 3 - name: Run server conformance tests - uses: modelcontextprotocol/conformance@v0.1.11 + uses: modelcontextprotocol/conformance@v0.1.15 with: mode: server url: 'http://localhost:3001/mcp' suite: 'core' expected-failures: './conformance-baseline.yml' + node-version: '22' - name: Cleanup server if: always() @@ -100,6 +97,18 @@ jobs: kill $SERVER_PID 2>/dev/null || true fi + documentation: + name: Documentation + runs-on: macos-latest + timeout-minutes: 10 + steps: + - uses: actions/checkout@v4 + - uses: swift-actions/setup-swift@v2 + with: + swift-version: "6.1.0" + - name: Build Documentation + run: swift package generate-documentation --target MCP --warnings-as-errors + static-linux-sdk-build: name: Linux Static SDK Build (${{ matrix.swift-version }} - ${{ matrix.os }}) strategy: diff --git a/.spi.yml b/.spi.yml new file mode 100644 index 00000000..f8465813 --- /dev/null +++ b/.spi.yml @@ -0,0 +1,4 @@ +version: 1 +builder: + configs: + - documentation_targets: [MCP] \ No newline at end of file diff --git a/Package.resolved b/Package.resolved index b51f7109..d21ebacf 100644 --- a/Package.resolved +++ b/Package.resolved @@ -1,5 +1,5 @@ { - "originHash" : "2e844b5aa785005b516db94bbdf89be795d57d958f648d3bf655b75ad9ad9d68", + "originHash" : "67fa9b6a0d1f334fd836cabb7af8c36110a6b88c5b49fa5a42a50125b18a056b", "pins" : [ { "identity" : "eventsource", @@ -28,6 +28,24 @@ "version" : "1.3.0" } }, + { + "identity" : "swift-docc-plugin", + "kind" : "remoteSourceControl", + "location" : "https://github.com/swiftlang/swift-docc-plugin", + "state" : { + "branch" : "main", + "revision" : "e977f65879f82b375a044c8837597f690c067da6" + } + }, + { + "identity" : "swift-docc-symbolkit", + "kind" : "remoteSourceControl", + "location" : "https://github.com/swiftlang/swift-docc-symbolkit", + "state" : { + "revision" : "b45d1f2ed151d057b54504d653e0da5552844e34", + "version" : "1.0.0" + } + }, { "identity" : "swift-log", "kind" : "remoteSourceControl", diff --git a/Package.swift b/Package.swift index 2647e8bd..273715ab 100644 --- a/Package.swift +++ b/Package.swift @@ -26,6 +26,7 @@ let package = Package( targets: ["MCPConformanceClient"]) ], dependencies: [ + .package(url: "https://github.com/swiftlang/swift-docc-plugin", branch: "main"), .package(url: "https://github.com/apple/swift-system.git", from: "1.0.0"), .package(url: "https://github.com/apple/swift-log.git", from: "1.5.0"), .package(url: "https://github.com/mattt/eventsource.git", from: "1.1.0"), diff --git a/README.md b/README.md index ed6cc329..b4a82b1d 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,15 @@ of the MCP specification. - [Initialize Hook](#initialize-hook) - [Graceful Shutdown](#graceful-shutdown) - [Transports](#transports) +- [Authentication](#authentication) + - [Client: Client Credentials Flow](#client-client-credentials-flow) + - [Client: Authorization Code Flow](#client-authorization-code-flow) + - [Client: Custom Token Provider](#client-custom-token-provider) + - [Client: Custom Token Storage](#client-custom-token-storage) + - [Client: private\_key\_jwt Authentication](#client-private_key_jwt-authentication) + - [Client: Endpoint Overrides](#client-endpoint-overrides) + - [Server: Serving Protected Resource Metadata](#server-serving-protected-resource-metadata) + - [Server: Validating Bearer Tokens](#server-validating-bearer-tokens) - [Platform Availability](#platform-availability) - [Debugging and Logging](#debugging-and-logging) - [Additional Resources](#additional-resources) @@ -1341,6 +1350,195 @@ public actor MyCustomTransport: Transport { } ``` +## Authentication + +`HTTPClientTransport` supports OAuth 2.1 Bearer token authorization per the +[MCP authorization specification](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization). +When a server returns `401 Unauthorized` or `403 Forbidden`, the transport automatically: + +1. Discovers Protected Resource Metadata (RFC 9728) at `/.well-known/oauth-protected-resource` +2. Discovers Authorization Server Metadata (RFC 8414 / OIDC Discovery 1.0) +3. Registers the client dynamically (RFC 7591) if needed +4. Acquires a Bearer token using the configured grant flow (PKCE enforced) +5. Retries the original request with the token attached + +Authorization is opt-in and disabled by default. +Pass an `OAuthAuthorizer` to `HTTPClientTransport(authorizer:)` to enable it. + +### Client: Client Credentials Flow + +Machine-to-machine authentication using a pre-shared client secret: + +```swift +let config = OAuthConfiguration( + grantType: .clientCredentials, + authentication: .clientSecretBasic(clientID: "my-app", clientSecret: "s3cr3t") +) +let authorizer = OAuthAuthorizer(configuration: config) +let transport = HTTPClientTransport( + endpoint: URL(string: "https://api.example.com/mcp")!, + authorizer: authorizer +) +let client = Client(name: "MyClient", version: "1.0.0") +try await client.connect(transport: transport) +``` + +### Client: Authorization Code Flow + +Interactive, browser-based authentication with PKCE. +Implement `OAuthAuthorizationDelegate` to open the authorization URL and capture the redirect: + +```swift +struct MyAuthDelegate: OAuthAuthorizationDelegate { + func presentAuthorizationURL(_ url: URL) async throws -> URL { + // Open the URL in a browser/webview and wait for the callback redirect URI. + // The returned URL must include the authorization code and state parameters. + return try await openBrowserAndWaitForCallback(url) + } +} + +let config = OAuthConfiguration( + grantType: .authorizationCode, + authentication: .none(clientID: "my-app"), + authorizationDelegate: MyAuthDelegate() +) +let authorizer = OAuthAuthorizer(configuration: config) +let transport = HTTPClientTransport( + endpoint: URL(string: "https://api.example.com/mcp")!, + authorizer: authorizer +) +``` + +### Client: Custom Token Provider + +Supply an externally acquired token (e.g., from a system credential store) via `accessTokenProvider`. +The SDK calls this closure after discovery completes. Return `nil` to fall back to the configured grant flow: + +```swift +let config = OAuthConfiguration( + grantType: .clientCredentials, + authentication: .none(clientID: "my-app"), + accessTokenProvider: { context, session in + // context contains the discovered resource URI, token endpoint, scopes, etc. + return try await KeychainTokenStore.shared.loadToken(for: context.resource) + } +) +``` + +### Client: Custom Token Storage + +By default, tokens are stored in memory and lost when the process exits. +To persist tokens across sessions, implement `TokenStorage` and pass it to `OAuthAuthorizer`: + +```swift +final class KeychainTokenStorage: TokenStorage { + func save(_ token: OAuthAccessToken) { + // Encode and store token.value in the system Keychain + } + + func load() -> OAuthAccessToken? { + // Load and decode token from the Keychain + return nil + } + + func clear() { + // Delete from the Keychain + } +} + +let authorizer = OAuthAuthorizer( + configuration: config, + tokenStorage: KeychainTokenStorage() +) +``` + +### Client: `private_key_jwt` Authentication + +Authenticate to the token endpoint using an asymmetric key (RFC 7523). +The SDK provides a built-in ES256 helper for P-256 keys: + +```swift +let config = OAuthConfiguration( + grantType: .clientCredentials, + authentication: .privateKeyJWT( + clientID: "my-app", + assertionFactory: { tokenEndpoint, clientID in + try OAuthConfiguration.makePrivateKeyJWTAssertion( + clientID: clientID, + tokenEndpoint: tokenEndpoint, + privateKeyPEM: myEC256PrivateKeyPEM // PEM-encoded P-256 private key + ) + } + ) +) +``` + +### Client: Endpoint Overrides + +Skip automatic discovery by providing explicit endpoint URLs. +Useful when the server does not publish well-known metadata documents: + +```swift +let config = OAuthConfiguration( + grantType: .clientCredentials, + authentication: .clientSecretBasic(clientID: "app", clientSecret: "secret"), + endpointOverrides: OAuthConfiguration.EndpointOverrides( + tokenEndpoint: URL(string: "https://auth.example.com/oauth/token")! + ) +) +``` + +### Server: Serving Protected Resource Metadata + +Per the MCP authorization specification, servers **MUST** serve Protected Resource Metadata +at `/.well-known/oauth-protected-resource` so clients can discover authorization server endpoints. + +Use `ProtectedResourceMetadataValidator` as the first validator in your pipeline so that +unauthenticated discovery requests are handled before the bearer token check: + +```swift +let metadata = OAuthProtectedResourceServerMetadata( + resource: "https://api.example.com", + authorizationServers: [URL(string: "https://auth.example.com")!], + scopesSupported: ["read", "write"] +) +let metadataValidator = ProtectedResourceMetadataValidator(metadata: metadata) +``` + +### Server: Validating Bearer Tokens + +Use `BearerTokenValidator` to authenticate incoming requests. +Your `tokenValidator` closure **MUST** verify the token's `aud` claim to prevent +token substitution attacks where a token intended for another resource is replayed against your server: + +```swift +let resourceIdentifier = URL(string: "https://api.example.com")! + +let bearerValidator = BearerTokenValidator( + resourceMetadataURL: URL(string: "https://api.example.com/.well-known/oauth-protected-resource")!, + resourceIdentifier: resourceIdentifier, + tokenValidator: { token, request, context in + guard let claims = try? verifyAndDecodeJWT(token) else { + return .invalidToken(errorDescription: "Token verification failed") + } + // Pass audience and expiry to BearerTokenInfo; the SDK validates the + // audience claim against resourceIdentifier automatically. + return .valid(BearerTokenInfo( + audience: claims.audience, + expiresAt: claims.expiresAt + )) + } +) + +let pipeline = StandardValidationPipeline(validators: [ + metadataValidator, // serves /.well-known/oauth-protected-resource unauthenticated + bearerValidator, // validates Bearer tokens on all other requests + AcceptHeaderValidator(mode: .sseRequired), + ContentTypeValidator(), + SessionValidator(), +]) +``` + ## Platform Availability The Swift SDK has the following platform requirements: diff --git a/Sources/MCP/Base/Authorization/OAuthAuthorizationCodeFlow.swift b/Sources/MCP/Base/Authorization/OAuthAuthorizationCodeFlow.swift new file mode 100644 index 00000000..0c278f1c --- /dev/null +++ b/Sources/MCP/Base/Authorization/OAuthAuthorizationCodeFlow.swift @@ -0,0 +1,253 @@ +import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + +// MARK: - OAuthAuthorizationCodeFlowing Protocol + +/// Internal protocol for driving the OAuth 2.1 authorization code flow. +protocol OAuthAuthorizationCodeFlowing: Sendable { + func buildURL( + authorizationEndpoint: URL, + resource: URL, + redirectURI: URL, + clientID: String, + codeChallenge: String, + scopes: Set?, + state: String, + scopeSerializer: any OAuthScopeSelecting + ) throws -> URL + + func perform( + authorizationURL: URL, + redirectURI: URL, + state: String, + delegate: (any OAuthAuthorizationDelegate)?, + session: URLSession + ) async throws -> String +} + +// MARK: - No-Redirect Session Delegate + +final class OAuthNoRedirectSessionDelegate: NSObject, URLSessionTaskDelegate, + @unchecked Sendable +{ + func urlSession( + _ session: URLSession, + task: URLSessionTask, + willPerformHTTPRedirection response: HTTPURLResponse, + newRequest request: URLRequest, + completionHandler: @escaping (URLRequest?) -> Void + ) { + completionHandler(nil) + } +} + +// MARK: - Authorization Code Flow + +/// Handles the browser-facing steps of the OAuth 2.1 authorization_code flow. +/// +/// Builds the authorization request URL, drives the redirect (via delegate or direct HTTP), +/// and extracts the authorization code from the callback redirect URL. +public struct OAuthAuthorizationCodeFlow: Sendable { + + public init() {} + + /// Builds the authorization request URL. + /// + /// - Parameters: + /// - authorizationEndpoint: The AS authorization endpoint. + /// - resource: The RFC 8707 resource indicator. + /// - redirectURI: The redirect URI registered for this client. + /// - clientID: The OAuth client identifier. + /// - codeChallenge: The PKCE S256 code challenge. + /// - scopes: Optional scope set to request. + /// - state: The CSRF state nonce. + /// - scopeSerializer: Serializes the scope set to a space-separated string. + /// - Returns: The full authorization request URL with all query parameters. + public func buildURL( + authorizationEndpoint: URL, + resource: URL, + redirectURI: URL, + clientID: String, + codeChallenge: String, + scopes: Set?, + state: String, + scopeSerializer: any OAuthScopeSelecting + ) throws -> URL { + guard var components = URLComponents( + url: authorizationEndpoint, + resolvingAgainstBaseURL: false + ) else { + throw OAuthAuthorizationError.authorizationServerMetadataDiscoveryFailed + } + + var queryItems: [URLQueryItem] = [ + .init(name: OAuthParameterName.responseType, value: OAuthParameterName.code), + .init(name: OAuthParameterName.clientID, value: clientID), + .init(name: OAuthParameterName.redirectURI, value: redirectURI.absoluteString), + .init(name: OAuthParameterName.state, value: state), + .init(name: OAuthParameterName.resource, value: resource.absoluteString), + .init(name: OAuthParameterName.codeChallenge, value: codeChallenge), + .init( + name: OAuthParameterName.codeChallengeMethod, value: OAuthCodeChallengeMethod.s256), + ] + + if let scope = scopes.flatMap(scopeSerializer.serialize) { + queryItems.append(.init(name: OAuthParameterName.scope, value: scope)) + } + + components.queryItems = queryItems + guard let url = components.url else { + throw OAuthAuthorizationError.authorizationServerMetadataDiscoveryFailed + } + return url + } + + /// Drives the interactive authorization redirect and returns the authorization code. + /// + /// When a delegate is provided, presents the authorization URL and awaits the redirect. + /// Without a delegate, sends a GET request and captures the redirect `Location` header. + /// + /// - Parameters: + /// - authorizationURL: The full authorization request URL. + /// - redirectURI: The expected redirect URI base for validation. + /// - state: The CSRF state nonce to verify in the redirect. + /// - delegate: Optional user-facing delegate for browser-based flows. + /// - session: The `URLSession` used for the no-redirect path. + /// - Returns: The extracted authorization code. + public func perform( + authorizationURL: URL, + redirectURI: URL, + state: String, + delegate: (any OAuthAuthorizationDelegate)?, + session: URLSession + ) async throws -> String { + if let delegate { + let redirectURL = try await delegate.presentAuthorizationURL(authorizationURL) + return try extractCode( + from: redirectURL, + expectedRedirectURI: redirectURI, + expectedState: state + ) + } + + var request = URLRequest(url: authorizationURL) + request.httpMethod = "GET" + request.setValue( + "text/html, \(ContentType.json)", forHTTPHeaderField: HTTPHeaderName.accept) + + let noRedirectDelegate = OAuthNoRedirectSessionDelegate() + let noRedirectSession = URLSession( + configuration: session.configuration, + delegate: noRedirectDelegate, + delegateQueue: nil + ) + defer { noRedirectSession.invalidateAndCancel() } + + let (_, response) = try await noRedirectSession.data(for: request) + guard let httpResponse = response as? HTTPURLResponse else { + throw OAuthAuthorizationError.authorizationResponseMissingRedirectLocation + } + + guard (300..<400).contains(httpResponse.statusCode) else { + if httpResponse.statusCode >= 400 { + throw OAuthAuthorizationError.authorizationRequestFailed(statusCode: httpResponse.statusCode) + } + throw OAuthAuthorizationError.authorizationResponseMissingRedirectLocation + } + + guard let location = httpResponse.value(forHTTPHeaderField: HTTPHeaderName.location), + !location.isEmpty, + let redirectURL = URL(string: location) + else { + throw OAuthAuthorizationError.authorizationResponseMissingRedirectLocation + } + + return try extractCode( + from: redirectURL, + expectedRedirectURI: redirectURI, + expectedState: state + ) + } + + /// Extracts and validates the authorization code from the redirect URL. + /// + /// - Parameters: + /// - redirectURL: The redirect URL received from the authorization server. + /// - expectedRedirectURI: The redirect URI used in the authorization request. + /// - expectedState: The CSRF state nonce sent in the authorization request. + /// - Returns: The authorization code. + public func extractCode( + from redirectURL: URL, + expectedRedirectURI: URL, + expectedState: String + ) throws -> String { + guard + let redirectComponents = URLComponents(url: redirectURL, resolvingAgainstBaseURL: false), + let expectedComponents = URLComponents( + url: expectedRedirectURI, resolvingAgainstBaseURL: false) + else { + throw OAuthAuthorizationError.authorizationResponseRedirectMismatch( + expected: expectedRedirectURI.absoluteString, + actual: redirectURL.absoluteString + ) + } + + if normalizedRedirectBase(redirectComponents) != normalizedRedirectBase(expectedComponents) { + throw OAuthAuthorizationError.authorizationResponseRedirectMismatch( + expected: expectedRedirectURI.absoluteString, + actual: redirectURL.absoluteString + ) + } + + guard + let state = redirectComponents.queryItems?.first(where: { + $0.name == OAuthParameterName.state + })?.value, + !state.isEmpty + else { + throw OAuthAuthorizationError.authorizationResponseMissingState + } + + guard state == expectedState else { + throw OAuthAuthorizationError.authorizationResponseStateMismatch( + expected: expectedState, + actual: state + ) + } + + guard + let code = redirectComponents.queryItems?.first(where: { + $0.name == OAuthParameterName.code + })?.value, + !code.isEmpty + else { + throw OAuthAuthorizationError.authorizationResponseMissingCode + } + + return code + } + + // MARK: - Private Helpers + + private func normalizedRedirectBase(_ components: URLComponents) -> String { + let scheme = components.scheme?.lowercased() ?? "" + let host = components.host?.lowercased() ?? "" + let port: Int + if let explicitPort = components.port { + port = explicitPort + } else if scheme == OAuthURLScheme.https { + port = OAuthDefaultPort.https + } else if scheme == OAuthURLScheme.http { + port = OAuthDefaultPort.http + } else { + port = -1 + } + let path = components.path.isEmpty ? "/" : components.path + return "\(scheme)://\(host):\(port)\(path)" + } +} + +extension OAuthAuthorizationCodeFlow: OAuthAuthorizationCodeFlowing {} diff --git a/Sources/MCP/Base/Authorization/OAuthAuthorizer.swift b/Sources/MCP/Base/Authorization/OAuthAuthorizer.swift new file mode 100644 index 00000000..2cac3495 --- /dev/null +++ b/Sources/MCP/Base/Authorization/OAuthAuthorizer.swift @@ -0,0 +1,865 @@ +import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + +#if canImport(CryptoKit) + import CryptoKit +#endif + +// MARK: - HTTPClientAuthorizer Protocol + +/// Abstraction used by ``HTTPClientTransport`` to handle OAuth authorization challenges. +/// +/// Implement this protocol to provide custom token acquisition strategies, +/// or use the built-in ``OAuthAuthorizer`` for a full OAuth 2.1 implementation. +/// +/// ``HTTPClientTransport`` calls these methods automatically when it receives +/// `401 Unauthorized` or `403 Forbidden` responses from the server. +public protocol HTTPClientAuthorizer: AnyObject, Sendable { + /// The maximum number of authorization retries permitted for a single request. + /// + /// ``HTTPClientTransport`` will not call ``handleChallenge(statusCode:headers:endpoint:operationKey:session:)`` + /// more than this many times for a single outgoing request. + var maxAuthorizationAttempts: Int { get } + + /// Validates that the MCP endpoint URL satisfies the security requirements for OAuth. + /// + /// Called once before the first request is sent. Throw ``OAuthAuthorizationError/insecureOAuthEndpoint(context:url:)`` + /// if the URL does not meet the requirements (e.g., non-HTTPS non-loopback). + /// - Parameter endpoint: The MCP endpoint URL to validate. + func validateEndpointSecurity(for endpoint: URL) throws + + /// Returns the `Authorization` header value to attach to the next request, if a valid token is available. + /// + /// - Parameter endpoint: The MCP endpoint being requested. + /// - Returns: A `"Bearer "` string, or `nil` if no valid token is cached. + func authorizationHeader(for endpoint: URL) -> String? + + /// Handles an authorization challenge received from the server and attempts to acquire a new token. + /// + /// Called by ``HTTPClientTransport`` when a `401` or `403` response is received. + /// The implementation should attempt to obtain a valid access token and store it + /// so that a subsequent call to ``authorizationHeader(for:)`` returns the new value. + /// + /// - Parameters: + /// - statusCode: The HTTP status code (401 or 403). + /// - headers: All response headers from the challenge response. + /// - endpoint: The MCP endpoint that returned the challenge. + /// - operationKey: An optional identifier for the MCP operation (e.g., the JSON-RPC method), + /// used to track step-up attempts per operation. + /// - session: The `URLSession` to use for discovery and token requests. + /// - Returns: `true` if a new token was acquired and the original request should be retried; + /// `false` if the challenge cannot be handled. + func handleChallenge( + statusCode: Int, + headers: [String: String], + endpoint: URL, + operationKey: String?, + session: URLSession + ) async throws -> Bool + + /// Proactively refreshes the access token if it is close to expiry. + /// + /// Called by ``HTTPClientTransport`` before sending each request and before opening + /// an SSE stream, allowing the token to be silently renewed without a 401 round-trip. + /// Implementations should swallow refresh errors — if refresh fails, the normal + /// ``handleChallenge(statusCode:headers:endpoint:operationKey:session:)`` path recovers. + /// + /// - Parameters: + /// - endpoint: The MCP endpoint about to be contacted. + /// - session: The `URLSession` to use for token refresh requests. + func prepareAuthorization(for endpoint: URL, session: URLSession) async throws +} + +extension HTTPClientAuthorizer { + public func prepareAuthorization(for endpoint: URL, session: URLSession) async throws {} +} + +// MARK: - OAuthAuthorizer + +/// Full OAuth 2.1 implementation of ``HTTPClientAuthorizer``. +/// +/// `OAuthAuthorizer` orchestrates the complete MCP authorization flow on behalf of an HTTP client: +/// +/// 1. **Protected Resource Metadata discovery** (RFC 9728) — fetches +/// `/.well-known/oauth-protected-resource` to locate the authorization server. +/// 2. **Authorization Server Metadata discovery** (RFC 8414 / OIDC Discovery 1.0) — fetches +/// `/.well-known/oauth-authorization-server` or `/.well-known/openid-configuration`. +/// 3. **Dynamic Client Registration** (RFC 7591) — registers the client if no credentials are +/// pre-configured and the AS advertises a registration endpoint. +/// 4. **Token acquisition** — performs the configured grant flow (`authorization_code` with PKCE, +/// or `client_credentials`), binding tokens to the resource indicator (RFC 8707). +/// 5. **Token refresh** — attempts a `refresh_token` grant before a full re-authorization. +/// 6. **Scope step-up** — handles `403 insufficient_scope` challenges by re-requesting with +/// the union of existing and required scopes. +/// +/// Pass an instance to `HTTPClientTransport(authorizer:)` to enable automatic authorization: +/// +/// ```swift +/// let config = OAuthConfiguration( +/// grantType: .clientCredentials, +/// authentication: .clientSecretBasic(clientID: "my-app", clientSecret: "s3cr3t") +/// ) +/// let authorizer = OAuthAuthorizer(configuration: config) +/// let transport = HTTPClientTransport(endpoint: serverURL, authorizer: authorizer) +/// ``` +/// +/// - Important: This type is `@unchecked Sendable`. All mutable state is accessed +/// exclusively through the `HTTPClientTransport` actor, which serializes every call. +/// Do **not** share a single `OAuthAuthorizer` instance across multiple transports — +/// doing so would violate the isolation contract and risk concurrent mutation. +public final class OAuthAuthorizer: HTTPClientAuthorizer, @unchecked Sendable { + + // MARK: - Mutable State + + private var configuration: OAuthConfiguration + private let tokenStorage: TokenStorage + private var selectedAuthorizationServer: URL? + private var protectedResourceMetadata: OAuthProtectedResourceMetadata? + private var authorizationServerMetadata: OAuthAuthorizationServerMetadata? + private var cachedProtectedResourceMetadataURL: URL? + private var stepUpAttempts: [String: Int] = [:] + private var clientRegistrationAttempted = false + private var clientSecretExpiresAt: Date? + + // MARK: - Composable Dependencies + + private let scopeSelector: any OAuthScopeSelecting + private let challengeParser: any OAuthWWWAuthenticateParsing + private let urlValidator: any OAuthURLValidating + private let discoveryClient: any OAuthDiscoveryFetching + private let tokenEndpointClient: any OAuthTokenRequesting + private let clientRegistrar: any OAuthClientRegistering + private let authCodeFlow: any OAuthAuthorizationCodeFlowing + + /// Creates an `OAuthAuthorizer` with the given configuration and optional injectable dependencies. + /// + /// - Parameters: + /// - configuration: OAuth 2.1 configuration controlling the grant type, authentication method, + /// endpoint discovery overrides, and retry policy. + /// - tokenStorage: Stores acquired access tokens. Defaults to ``InMemoryTokenStorage``, + /// which loses tokens when the process exits. Supply a Keychain-backed implementation + /// to persist tokens across sessions. + /// - scopeSelector: Strategy for selecting OAuth scopes from challenge and metadata hints. + /// Defaults to ``DefaultOAuthScopeSelector``. + /// - challengeParser: Parses `WWW-Authenticate: Bearer` challenge headers. + /// Defaults to ``DefaultOAuthWWWAuthenticateParser``. + /// - metadataDiscovery: Constructs well-known discovery URLs and validates resource URI matching. + /// Defaults to ``DefaultOAuthMetadataDiscovery``. + public convenience init( + configuration: OAuthConfiguration, + tokenStorage: TokenStorage = InMemoryTokenStorage(), + scopeSelector: any OAuthScopeSelecting = DefaultOAuthScopeSelector(), + challengeParser: any OAuthWWWAuthenticateParsing = DefaultOAuthWWWAuthenticateParser(), + metadataDiscovery: any OAuthMetadataDiscovering = DefaultOAuthMetadataDiscovery() + ) { + let urlValidator = OAuthURLValidator( + allowLoopbackHTTPForAuthorizationServer: + configuration.allowLoopbackHTTPAuthorizationServerEndpoints + ) + self.init( + configuration: configuration, + tokenStorage: tokenStorage, + scopeSelector: scopeSelector, + challengeParser: challengeParser, + urlValidator: urlValidator, + discoveryClient: OAuthDiscoveryClient( + metadataDiscovery: metadataDiscovery, + urlValidator: urlValidator + ), + tokenEndpointClient: OAuthTokenEndpointClient(urlValidator: urlValidator), + clientRegistrar: OAuthClientRegistrar(urlValidator: urlValidator), + authCodeFlow: OAuthAuthorizationCodeFlow() + ) + } + + init( + configuration: OAuthConfiguration, + tokenStorage: TokenStorage = InMemoryTokenStorage(), + scopeSelector: any OAuthScopeSelecting = DefaultOAuthScopeSelector(), + challengeParser: any OAuthWWWAuthenticateParsing = DefaultOAuthWWWAuthenticateParser(), + urlValidator: any OAuthURLValidating, + discoveryClient: any OAuthDiscoveryFetching, + tokenEndpointClient: any OAuthTokenRequesting, + clientRegistrar: any OAuthClientRegistering, + authCodeFlow: any OAuthAuthorizationCodeFlowing + ) { + self.configuration = configuration + self.tokenStorage = tokenStorage + self.scopeSelector = scopeSelector + self.challengeParser = challengeParser + self.urlValidator = urlValidator + self.discoveryClient = discoveryClient + self.tokenEndpointClient = tokenEndpointClient + self.clientRegistrar = clientRegistrar + self.authCodeFlow = authCodeFlow + } + + // MARK: - HTTPClientAuthorizer + + public var maxAuthorizationAttempts: Int { + configuration.retryPolicy.maxAuthorizationAttempts + } + + public func validateEndpointSecurity(for endpoint: URL) throws { + try urlValidator.validateHTTPSOrLoopback(endpoint, context: "MCP endpoint") + } + + public func authorizationHeader(for endpoint: URL) -> String? { + guard let accessToken = tokenStorage.load() else { return nil } + if let tokenAuthorizationServer = accessToken.authorizationServer, + let selectedAuthorizationServer, + !authorizationServersMatch(tokenAuthorizationServer, selectedAuthorizationServer) + { + tokenStorage.clear() + return nil + } + if accessToken.isExpired() { + tokenStorage.clear() + return nil + } + return "\(OAuthTokenType.bearer) \(accessToken.value)" + } + + public func handleChallenge( + statusCode: Int, + headers: [String: String], + endpoint: URL, + operationKey: String? = nil, + session: URLSession + ) async throws -> Bool { + try validateEndpointSecurity(for: endpoint) + let challenge = challengeParser.parseBearer(from: headers) + + switch statusCode { + case 401: + if let refreshToken = tokenStorage.load()?.refreshToken { + tokenStorage.clear() + let metadata = try await discoverProtectedResourceMetadata( + endpoint: endpoint, + challenge: challenge, + session: session + ) + let asMetadata = try await resolveAuthorizationServerMetadata( + metadata: metadata, + session: session + ) + let resource = try canonicalResource(for: endpoint) + let requestedScopes = scopeSelector.selectScopes( + challengeScope: challenge?.scope, + scopesSupported: metadata.scopesSupported + ) + if try await refreshAccessToken( + refreshToken: refreshToken, + resource: resource, + requestedScopes: requestedScopes, + asMetadata: asMetadata, + session: session + ) { + return true + } + } else { + tokenStorage.clear() + } + + let metadata = try await discoverProtectedResourceMetadata( + endpoint: endpoint, + challenge: challenge, + session: session + ) + let requestedScopes = scopeSelector.selectScopes( + challengeScope: challenge?.scope, + scopesSupported: metadata.scopesSupported + ) + + let providerContext = try await makeAccessTokenProviderContext( + statusCode: statusCode, + endpoint: endpoint, + challenge: challenge, + metadata: metadata, + requestedScopes: requestedScopes, + session: session + ) + if let externalToken = try await fetchAccessTokenFromProvider( + context: providerContext, + session: session + ) { + storeExternalAccessToken( + externalToken, + requestedScopes: providerContext.requestedScopes, + authorizationServer: providerContext.authorizationServer + ) + return true + } + + try await acquireToken( + endpoint: endpoint, + metadata: metadata, + requestedScopes: requestedScopes, + session: session + ) + return true + + case 403: + guard challenge?.error?.lowercased() == "insufficient_scope" else { return false } + + let metadata = try await discoverProtectedResourceMetadata( + endpoint: endpoint, + challenge: challenge, + session: session + ) + let requiredScopes = + scopeSelector.selectScopes( + challengeScope: challenge?.scope, + scopesSupported: metadata.scopesSupported + ) ?? [] + + let existingScopes = tokenStorage.load()?.scopes ?? [] + let upgradedScopes = existingScopes.union(requiredScopes) + let resourceKey = try discoveryClient.metadataDiscovery.canonicalResourceURI( + from: endpoint + ).absoluteString + let operationAttemptKey = normalizedOperationKey(operationKey) + let attemptKey = + "\(resourceKey)|\(operationAttemptKey)|\(upgradedScopes.sorted().joined(separator: " "))" + let attempts = stepUpAttempts[attemptKey, default: 0] + guard attempts < configuration.retryPolicy.maxScopeUpgradeAttempts else { + return false + } + stepUpAttempts[attemptKey] = attempts + 1 + + let providerRequestedScopes = upgradedScopes.isEmpty ? nil : upgradedScopes + let providerContext = try await makeAccessTokenProviderContext( + statusCode: statusCode, + endpoint: endpoint, + challenge: challenge, + metadata: metadata, + requestedScopes: providerRequestedScopes, + session: session + ) + if let externalToken = try await fetchAccessTokenFromProvider( + context: providerContext, + session: session + ) { + storeExternalAccessToken( + externalToken, + requestedScopes: providerContext.requestedScopes, + authorizationServer: providerContext.authorizationServer + ) + return true + } + + try await acquireToken( + endpoint: endpoint, + metadata: metadata, + requestedScopes: upgradedScopes, + session: session + ) + return true + + default: + return false + } + } + + public func prepareAuthorization(for endpoint: URL, session: URLSession) async throws { + guard configuration.proactiveRefreshWindowSeconds > 0 else { return } + guard let token = tokenStorage.load() else { return } + guard !token.isExpired() else { return } + guard token.isExpired(skewSeconds: configuration.proactiveRefreshWindowSeconds) else { + return + } + guard let refreshToken = token.refreshToken else { return } + guard let asMeta = authorizationServerMetadata, asMeta.tokenEndpoint != nil else { return } + + let resource: URL + do { + resource = try canonicalResource(for: endpoint) + } catch { + return + } + + let requestedScopes = token.scopes.isEmpty ? nil : token.scopes + _ = try? await refreshAccessToken( + refreshToken: refreshToken, + resource: resource, + requestedScopes: requestedScopes, + asMetadata: asMeta, + session: session + ) + } + + // MARK: - Discovery + + private func discoverProtectedResourceMetadata( + endpoint: URL, + challenge: OAuthBearerChallenge?, + session: URLSession + ) async throws -> OAuthProtectedResourceMetadata { + if let protectedResourceMetadata { + let incomingURL = challenge?.resourceMetadataURL + if let incomingURL, incomingURL != cachedProtectedResourceMetadataURL { + self.protectedResourceMetadata = nil + self.authorizationServerMetadata = nil + self.selectedAuthorizationServer = nil + self.cachedProtectedResourceMetadataURL = nil + } else { + return protectedResourceMetadata + } + } + + var candidates: [URL] = [] + + if let challengeURL = challenge?.resourceMetadataURL { + try urlValidator.validateHTTPSOrLoopback( + challengeURL, context: "Protected resource metadata URL") + if let host = URLComponents(url: challengeURL, resolvingAgainstBaseURL: false)?.host? + .lowercased(), urlValidator.isPrivateIPHost(host) + { + throw OAuthAuthorizationError.privateIPAddressBlocked( + context: "Protected resource metadata URL", + url: challengeURL.absoluteString + ) + } + candidates.append(challengeURL) + } + if let configuredURL = configuration.endpointOverrides.protectedResourceMetadataURL, + !candidates.contains(configuredURL) + { + try urlValidator.validateHTTPSOrLoopback( + configuredURL, + context: "Configured protected resource metadata URL" + ) + candidates.append(configuredURL) + } + + for fallback in discoveryClient.metadataDiscovery.protectedResourceMetadataURLs( + for: endpoint) + where !candidates.contains(fallback) { + candidates.append(fallback) + } + + let metadata = try await discoveryClient.fetchProtectedResourceMetadata( + candidates: candidates, session: session) + try validateProtectedResource(metadata: metadata, endpoint: endpoint) + + self.protectedResourceMetadata = metadata + self.cachedProtectedResourceMetadataURL = candidates.first + return metadata + } + + private func validateProtectedResource( + metadata: OAuthProtectedResourceMetadata, endpoint: URL + ) throws { + guard let resource = metadata.resource?.trimmingCharacters(in: .whitespacesAndNewlines), + !resource.isEmpty + else { + return + } + + guard let resourceURL = URL(string: resource) else { + throw OAuthAuthorizationError.invalidResourceURI( + "Protected resource metadata contains an invalid resource URI: \(resource)" + ) + } + + let expected = try discoveryClient.metadataDiscovery.canonicalResourceURI(from: endpoint) + let actual = try discoveryClient.metadataDiscovery.canonicalResourceURI(from: resourceURL) + guard discoveryClient.metadataDiscovery.protectedResourceMatches( + resource: actual, endpoint: expected) + else { + throw OAuthAuthorizationError.protectedResourceMismatch( + expected: expected.absoluteString, + actual: actual.absoluteString + ) + } + } + + private func resolveAuthorizationServerMetadata( + metadata: OAuthProtectedResourceMetadata, + session: URLSession + ) async throws -> OAuthAuthorizationServerMetadata { + if let cached = authorizationServerMetadata { + return cached + } + + let candidates: [URL] + if let override = configuration.endpointOverrides.authorizationServerURL { + try urlValidator.validateAuthorizationServer( + override, context: "Authorization server issuer") + candidates = [override] + } else if let selected = selectedAuthorizationServer { + candidates = [selected] + } else { + guard !metadata.authorizationServers.isEmpty else { + throw OAuthAuthorizationError.missingAuthorizationServer + } + candidates = metadata.authorizationServers + } + + let (server, asMetadata) = try await discoveryClient.fetchAuthorizationServerMetadata( + candidates: candidates, session: session) + self.selectedAuthorizationServer = server + self.authorizationServerMetadata = asMetadata + return asMetadata + } + + // MARK: - Token Acquisition + + private func acquireToken( + endpoint: URL, + metadata: OAuthProtectedResourceMetadata, + requestedScopes: Set?, + session: URLSession + ) async throws { + let asMetadata = try await resolveAuthorizationServerMetadata( + metadata: metadata, session: session) + try await maybeRegisterClient(asMetadata: asMetadata, session: session) + let resource = try canonicalResource(for: endpoint) + + switch configuration.grantType { + case .clientCredentials: + try await acquireTokenViaClientCredentials( + resource: resource, + requestedScopes: requestedScopes, + asMetadata: asMetadata, + session: session + ) + case .authorizationCode: + try await acquireTokenViaAuthorizationCode( + resource: resource, + requestedScopes: requestedScopes, + asMetadata: asMetadata, + session: session + ) + } + } + + private func acquireTokenViaClientCredentials( + resource: URL, + requestedScopes: Set?, + asMetadata: OAuthAuthorizationServerMetadata, + session: URLSession + ) async throws { + let tokenEndpoint = try resolveTokenEndpoint(asMetadata: asMetadata) + var bodyParameters: [String: String] = configuration.additionalTokenRequestParameters + bodyParameters[OAuthParameterName.grantType] = OAuthGrantTypeValue.clientCredentials + bodyParameters[OAuthParameterName.resource] = resource.absoluteString + if let scope = requestedScopes.flatMap(scopeSelector.serialize) { + bodyParameters[OAuthParameterName.scope] = scope + } + let decoded = try await tokenEndpointClient.request( + parameters: &bodyParameters, + endpoint: tokenEndpoint, + authentication: configuration.authentication, + session: session + ) + storeTokenResponse(decoded, requestedScopes: requestedScopes) + } + + private func acquireTokenViaAuthorizationCode( + resource: URL, + requestedScopes: Set?, + asMetadata: OAuthAuthorizationServerMetadata, + session: URLSession + ) async throws { + guard let authorizationEndpoint = asMetadata.authorizationEndpoint else { + throw OAuthAuthorizationError.tokenEndpointMissing + } + try urlValidator.validateAuthorizationServer( + authorizationEndpoint, context: "Authorization endpoint") + if let host = URLComponents(url: authorizationEndpoint, resolvingAgainstBaseURL: false)? + .host?.lowercased(), urlValidator.isPrivateIPHost(host) + { + throw OAuthAuthorizationError.privateIPAddressBlocked( + context: "Authorization endpoint", + url: authorizationEndpoint.absoluteString + ) + } + try urlValidator.validateRedirectURI(configuration.authorizationRedirectURI) + try PKCE.checkSupport(in: asMetadata) + + let verifier = PKCE.makeVerifier() + let challenge = try PKCE.makeChallenge(from: verifier) + let state = UUID().uuidString + + let authorizationURL = try authCodeFlow.buildURL( + authorizationEndpoint: authorizationEndpoint, + resource: resource, + redirectURI: configuration.authorizationRedirectURI, + clientID: configuration.authentication.clientID, + codeChallenge: challenge, + scopes: requestedScopes, + state: state, + scopeSerializer: scopeSelector + ) + + let authorizationCode = try await authCodeFlow.perform( + authorizationURL: authorizationURL, + redirectURI: configuration.authorizationRedirectURI, + state: state, + delegate: configuration.authorizationDelegate, + session: session + ) + + let tokenEndpoint = try resolveTokenEndpoint(asMetadata: asMetadata) + var bodyParameters: [String: String] = configuration.additionalTokenRequestParameters + bodyParameters[OAuthParameterName.grantType] = OAuthGrantTypeValue.authorizationCode + bodyParameters[OAuthParameterName.code] = authorizationCode + bodyParameters[OAuthParameterName.codeVerifier] = verifier + bodyParameters[OAuthParameterName.redirectURI] = + configuration.authorizationRedirectURI.absoluteString + bodyParameters[OAuthParameterName.resource] = resource.absoluteString + if let scope = requestedScopes.flatMap(scopeSelector.serialize) { + bodyParameters[OAuthParameterName.scope] = scope + } + + let decoded = try await tokenEndpointClient.request( + parameters: &bodyParameters, + endpoint: tokenEndpoint, + authentication: configuration.authentication, + session: session + ) + storeTokenResponse(decoded, requestedScopes: requestedScopes) + } + + // MARK: - Token Refresh + + private func refreshAccessToken( + refreshToken: String, + resource: URL, + requestedScopes: Set?, + asMetadata: OAuthAuthorizationServerMetadata, + session: URLSession + ) async throws -> Bool { + let tokenEndpoint: URL + do { + tokenEndpoint = try resolveTokenEndpoint(asMetadata: asMetadata) + } catch { + return false + } + + var bodyParameters: [String: String] = configuration.additionalTokenRequestParameters + bodyParameters[OAuthParameterName.grantType] = OAuthGrantTypeValue.refreshToken + bodyParameters[OAuthParameterName.refreshToken] = refreshToken + bodyParameters[OAuthParameterName.resource] = resource.absoluteString + if let scope = requestedScopes.flatMap(scopeSelector.serialize) { + bodyParameters[OAuthParameterName.scope] = scope + } + + do { + let decoded = try await tokenEndpointClient.request( + parameters: &bodyParameters, + endpoint: tokenEndpoint, + authentication: configuration.authentication, + session: session + ) + storeTokenResponse(decoded, requestedScopes: requestedScopes) + return true + } catch let error as OAuthAuthorizationError { + if case .tokenRequestFailed(let statusCode, _) = error, + (400..<500).contains(statusCode) + { + return false + } + throw error + } + } + + // MARK: - Client Registration + + private func maybeRegisterClient( + asMetadata: OAuthAuthorizationServerMetadata, + session: URLSession + ) async throws { + if let expiry = clientSecretExpiresAt, Date() >= expiry { + clientSecretExpiresAt = nil + clientRegistrationAttempted = false + configuration.authentication = .none(clientID: configuration.authentication.clientID) + } + + guard !clientRegistrationAttempted else { return } + guard case .none = configuration.authentication else { return } + + clientRegistrationAttempted = true + + if let (registration, updatedAuth) = try await clientRegistrar.register( + configuration: configuration, + asMetadata: asMetadata, + session: session + ) { + configuration.authentication = updatedAuth + if let expiresAt = registration.clientSecretExpiresAt, expiresAt > 0 { + clientSecretExpiresAt = Date(timeIntervalSince1970: Double(expiresAt)) + } + } + } + + // MARK: - State Helpers + + private func storeTokenResponse( + _ decoded: OAuthTokenResponse, + requestedScopes: Set? + ) { + let scopeSet: Set + if let scope = decoded.scope { + scopeSet = scopeSelector.parseScopeString(scope) + } else { + scopeSet = requestedScopes ?? [] + } + let expiresAt = decoded.expiresIn.map { Date().addingTimeInterval(TimeInterval($0)) } + tokenStorage.save(OAuthAccessToken( + value: decoded.accessToken, + tokenType: OAuthTokenType.bearer, + expiresAt: expiresAt, + scopes: scopeSet, + authorizationServer: selectedAuthorizationServer, + refreshToken: decoded.refreshToken + )) + } + + private func resolveTokenEndpoint( + asMetadata: OAuthAuthorizationServerMetadata + ) throws -> URL { + if let configuredEndpoint = configuration.endpointOverrides.tokenEndpoint { + try urlValidator.validateAuthorizationServer( + configuredEndpoint, context: "Configured token endpoint") + return configuredEndpoint + } + + guard let tokenEndpoint = asMetadata.tokenEndpoint else { + throw OAuthAuthorizationError.tokenEndpointMissing + } + try urlValidator.validateAuthorizationServer(tokenEndpoint, context: "Token endpoint") + if let host = URLComponents(url: tokenEndpoint, resolvingAgainstBaseURL: false)?.host? + .lowercased(), urlValidator.isPrivateIPHost(host) + { + throw OAuthAuthorizationError.privateIPAddressBlocked( + context: "Token endpoint", + url: tokenEndpoint.absoluteString + ) + } + return tokenEndpoint + } + + private func canonicalResource(for endpoint: URL) throws -> URL { + let endpointCanonical = try discoveryClient.metadataDiscovery.canonicalResourceURI( + from: endpoint) + + if let configuredResource = configuration.endpointOverrides.resource { + let configuredCanonical = try discoveryClient.metadataDiscovery.canonicalResourceURI( + from: configuredResource) + guard discoveryClient.metadataDiscovery.protectedResourceMatches( + resource: configuredCanonical, endpoint: endpointCanonical) + else { + throw OAuthAuthorizationError.protectedResourceMismatch( + expected: endpointCanonical.absoluteString, + actual: configuredCanonical.absoluteString + ) + } + return configuredCanonical + } + + if let prmResourceString = protectedResourceMetadata?.resource, + let prmResourceURL = URL(string: prmResourceString) + { + let prmCanonical = try discoveryClient.metadataDiscovery.canonicalResourceURI( + from: prmResourceURL) + guard discoveryClient.metadataDiscovery.protectedResourceMatches( + resource: prmCanonical, endpoint: endpointCanonical) + else { + throw OAuthAuthorizationError.protectedResourceMismatch( + expected: endpointCanonical.absoluteString, + actual: prmCanonical.absoluteString + ) + } + return prmCanonical + } + + return endpointCanonical + } + + private func authorizationServersMatch(_ lhs: URL, _ rhs: URL) -> Bool { + normalizedAuthorizationServer(lhs) == normalizedAuthorizationServer(rhs) + } + + private func normalizedAuthorizationServer(_ url: URL) -> URL? { + guard var components = URLComponents(url: url, resolvingAgainstBaseURL: false), + let scheme = components.scheme?.lowercased(), + let host = components.host?.lowercased(), + scheme == OAuthURLScheme.http || scheme == OAuthURLScheme.https + else { + return nil + } + components.scheme = scheme + components.host = host + components.query = nil + components.fragment = nil + if components.path == "/" { components.path = "" } + return components.url + } + + // MARK: - External Token Provider + + private func fetchAccessTokenFromProvider( + context: OAuthConfiguration.AccessTokenProviderContext, + session: URLSession + ) async throws -> String? { + guard let provider = configuration.accessTokenProvider else { return nil } + guard let token = try await provider(context, session), !token.isEmpty else { return nil } + return token + } + + private func storeExternalAccessToken( + _ token: String, + requestedScopes: Set?, + authorizationServer: URL? + ) { + tokenStorage.save(OAuthAccessToken( + value: token, + tokenType: OAuthTokenType.bearer, + expiresAt: nil, + scopes: requestedScopes ?? [], + authorizationServer: authorizationServer, + refreshToken: nil + )) + } + + private func makeAccessTokenProviderContext( + statusCode: Int, + endpoint: URL, + challenge: OAuthBearerChallenge?, + metadata: OAuthProtectedResourceMetadata, + requestedScopes: Set?, + session: URLSession + ) async throws -> OAuthConfiguration.AccessTokenProviderContext { + let asMetadata = try await resolveAuthorizationServerMetadata( + metadata: metadata, session: session) + let resource = try canonicalResource(for: endpoint) + let authorizationServer = configuration.endpointOverrides.authorizationServerURL + ?? selectedAuthorizationServer + ?? metadata.authorizationServers.first + let tokenEndpoint = configuration.endpointOverrides.tokenEndpoint ?? asMetadata.tokenEndpoint + + return OAuthConfiguration.AccessTokenProviderContext( + statusCode: statusCode, + endpoint: endpoint, + resource: resource, + authorizationServer: authorizationServer, + authorizationEndpoint: asMetadata.authorizationEndpoint, + tokenEndpoint: tokenEndpoint, + registrationEndpoint: asMetadata.registrationEndpoint, + challengedScope: challenge?.scope, + scopesSupported: metadata.scopesSupported, + requestedScopes: requestedScopes + ) + } + + private func normalizedOperationKey(_ operationKey: String?) -> String { + guard let operationKey else { return "" } + let normalized = operationKey.trimmingCharacters(in: .whitespacesAndNewlines) + return normalized.isEmpty ? "" : normalized + } +} diff --git a/Sources/MCP/Base/Authorization/OAuthClientRegistrar.swift b/Sources/MCP/Base/Authorization/OAuthClientRegistrar.swift new file mode 100644 index 00000000..2bbdf133 --- /dev/null +++ b/Sources/MCP/Base/Authorization/OAuthClientRegistrar.swift @@ -0,0 +1,186 @@ +import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + +/// Internal protocol for OAuth dynamic client registration. +protocol OAuthClientRegistering: Sendable { + func register( + configuration: OAuthConfiguration, + asMetadata: OAuthAuthorizationServerMetadata, + session: URLSession + ) async throws -> ( + response: OAuthClientRegistrationResponse, + updatedAuthentication: OAuthConfiguration.TokenEndpointAuthentication + )? +} + +/// Stateless OAuth dynamic client registration logic. +/// +/// Handles Client ID Metadata Document (CIMD) detection and RFC 7591 dynamic registration. +/// State tracking (`clientRegistrationAttempted`, `clientSecretExpiresAt`) is the caller's responsibility. +struct OAuthClientRegistrar: Sendable { + let urlValidator: OAuthURLValidator + + init(urlValidator: OAuthURLValidator) { + self.urlValidator = urlValidator + } + + /// Attempts to register the client, if applicable. + /// + /// Returns `nil` if registration is not needed: + /// - Credentials are already configured (not `.none`) + /// - CIMD is in use and the server supports it (pre-registered) + /// - No registration endpoint is available and no CIMD mismatch error + /// + /// Throws if registration was attempted but failed (4xx, 5xx, or unexpected response). + func register( + configuration: OAuthConfiguration, + asMetadata: OAuthAuthorizationServerMetadata, + session: URLSession + ) async throws -> ( + response: OAuthClientRegistrationResponse, + updatedAuthentication: OAuthConfiguration.TokenEndpointAuthentication + )? { + guard case .none(let clientID) = configuration.authentication else { + return nil + } + + let hasClientIDMetadataDocument = isHTTPSURLWithPath(clientID) + let supportsClientIDMetadataDocument = asMetadata.clientIDMetadataDocumentSupported == true + + if supportsClientIDMetadataDocument, + clientIDLooksLikeURL(clientID), + !hasClientIDMetadataDocument + { + throw OAuthAuthorizationError.invalidClientIDMetadataURL(clientID) + } + + if hasClientIDMetadataDocument, supportsClientIDMetadataDocument { + return nil + } + + guard let registrationEndpoint = asMetadata.registrationEndpoint else { + if hasClientIDMetadataDocument && !supportsClientIDMetadataDocument { + throw OAuthAuthorizationError.cimdNotSupported(clientID: clientID) + } + return nil + } + try urlValidator.validateAuthorizationServer( + registrationEndpoint, context: "Client registration endpoint") + + var request = URLRequest(url: registrationEndpoint) + request.httpMethod = "POST" + request.setValue(ContentType.json, forHTTPHeaderField: HTTPHeaderName.contentType) + request.setValue(ContentType.json, forHTTPHeaderField: HTTPHeaderName.accept) + + let grantTypes: [String] + let responseTypes: [String] + switch configuration.grantType { + case .authorizationCode: + grantTypes = [OAuthGrantTypeValue.authorizationCode] + responseTypes = [OAuthParameterName.code] + case .clientCredentials: + grantTypes = [OAuthGrantTypeValue.clientCredentials] + responseTypes = [] + } + + var registrationPayload: [String: Any] = [ + "client_name": configuration.clientName, + "grant_types": grantTypes, + "token_endpoint_auth_method": configuration.authentication.methodName, + ] + if !responseTypes.isEmpty { + registrationPayload["response_types"] = responseTypes + } + if configuration.grantType == .authorizationCode { + registrationPayload["redirect_uris"] = [ + configuration.authorizationRedirectURI.absoluteString + ] + } + + let httpBody = try JSONSerialization.data(withJSONObject: registrationPayload) + + request.httpBody = httpBody + + let (data, response) = try await session.data(for: request) + if let httpResponse = response as? HTTPURLResponse, + (400..<500).contains(httpResponse.statusCode) + { + let oauthError = + (try? JSONDecoder().decode(OAuthTokenErrorResponse.self, from: data))?.error + throw OAuthAuthorizationError.tokenRequestFailed( + statusCode: httpResponse.statusCode, + oauthError: oauthError + ) + } + guard let httpResponse = response as? HTTPURLResponse, + (200..<300).contains(httpResponse.statusCode) + else { + let statusCode = (response as? HTTPURLResponse)?.statusCode ?? 0 + let oauthError = + (try? JSONDecoder().decode(OAuthTokenErrorResponse.self, from: data))?.error + throw OAuthAuthorizationError.tokenRequestFailed( + statusCode: statusCode, + oauthError: oauthError + ) + } + + let registration = try JSONDecoder().decode(OAuthClientRegistrationResponse.self, from: data) + + let updatedAuth = OAuthClientRegistrar.updatedAuthentication( + from: registration, current: configuration.authentication) + return (response: registration, updatedAuthentication: updatedAuth) + } + + /// Derives the updated token endpoint authentication from a registration response. + /// + /// Updates the client ID (and secret, if issued) while preserving the authentication method. + static func updatedAuthentication( + from registration: OAuthClientRegistrationResponse, + current: OAuthConfiguration.TokenEndpointAuthentication + ) -> OAuthConfiguration.TokenEndpointAuthentication { + switch current { + case .none: + return .none(clientID: registration.clientID) + case .clientSecretBasic(_, let currentSecret): + return .clientSecretBasic( + clientID: registration.clientID, + clientSecret: registration.clientSecret ?? currentSecret + ) + case .clientSecretPost(_, let currentSecret): + return .clientSecretPost( + clientID: registration.clientID, + clientSecret: registration.clientSecret ?? currentSecret + ) + case .privateKeyJWT(_, let factory): + return .privateKeyJWT(clientID: registration.clientID, assertionFactory: factory) + } + } + + // MARK: - CIMD Helpers + + private func isHTTPSURLWithPath(_ value: String) -> Bool { + guard let url = URL(string: value), + let components = URLComponents(url: url, resolvingAgainstBaseURL: false) + else { + return false + } + return components.scheme?.lowercased() == OAuthURLScheme.https + && !components.path.isEmpty + && components.path != "/" + } + + private func clientIDLooksLikeURL(_ value: String) -> Bool { + guard let url = URL(string: value), + let scheme = url.scheme, + !scheme.isEmpty + else { + return false + } + return true + } +} + +extension OAuthClientRegistrar: OAuthClientRegistering {} diff --git a/Sources/MCP/Base/Authorization/OAuthConfiguration.swift b/Sources/MCP/Base/Authorization/OAuthConfiguration.swift new file mode 100644 index 00000000..a1c015f6 --- /dev/null +++ b/Sources/MCP/Base/Authorization/OAuthConfiguration.swift @@ -0,0 +1,415 @@ +import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + +#if canImport(CryptoKit) + import CryptoKit +#endif + +/// Configuration for OAuth 2.1 authorization used by ``HTTPClientTransport``. +/// +/// Authorization is optional and disabled by default. Configure this type and pass it +/// to `HTTPClientTransport(oauth:)` to enable automatic Bearer token acquisition for HTTP +/// transports. +/// +/// Supports both `authorization_code` (interactive, browser-based) and `client_credentials` +/// (machine-to-machine) grant types via the ``grantType`` property. +public struct OAuthConfiguration: Sendable { + /// The OAuth 2.1 grant type to use for token acquisition. + public enum GrantType: Sendable { + /// OAuth 2.1 authorization_code flow with PKCE. + case authorizationCode + + /// OAuth 2.1 client_credentials flow. + case clientCredentials + } + + /// How the client authenticates to the OAuth token endpoint. + public enum TokenEndpointAuthentication: Sendable, Equatable { + /// `client_secret_basic` authentication using the Authorization header. + case clientSecretBasic(clientID: String, clientSecret: String) + + /// `client_secret_post` authentication using form parameters. + case clientSecretPost(clientID: String, clientSecret: String) + + /// Public client authentication (`token_endpoint_auth_method=none`). + case none(clientID: String) + + /// `private_key_jwt` authentication. + /// + /// The built-in ``OAuthConfiguration/makePrivateKeyJWTAssertion(clientID:tokenEndpoint:privateKeyPEM:signingAlgorithm:audience:issuedAt:expiresIn:)`` + /// helper generates ES256 (P-256 ECDSA) assertions only. To use other algorithms + /// (e.g., RS256, ES384), provide a custom ``OAuthConfiguration/JWTAssertionFactory`` closure. + case privateKeyJWT(clientID: String, assertionFactory: JWTAssertionFactory) + + var clientID: String { + switch self { + case .clientSecretBasic(let clientID, _), .clientSecretPost(let clientID, _), + .none(let clientID), .privateKeyJWT(let clientID, _): + return clientID + } + } + + /// The token endpoint auth method name per RFC 7591. + var methodName: String { + switch self { + case .clientSecretBasic: return OAuthTokenEndpointAuthMethod.clientSecretBasic + case .clientSecretPost: return OAuthTokenEndpointAuthMethod.clientSecretPost + case .none: return OAuthTokenEndpointAuthMethod.none + case .privateKeyJWT: return OAuthTokenEndpointAuthMethod.privateKeyJWT + } + } + + func apply( + to request: inout URLRequest, + bodyParameters: inout [String: String], + tokenEndpoint: URL + ) async throws { + switch self { + case .clientSecretBasic(let clientID, let clientSecret): + let basic = Data("\(clientID):\(clientSecret)".utf8).base64EncodedString() + request.setValue("Basic \(basic)", forHTTPHeaderField: HTTPHeaderName.authorization) + case .clientSecretPost(let clientID, let clientSecret): + bodyParameters[OAuthParameterName.clientID] = clientID + bodyParameters[OAuthParameterName.clientSecret] = clientSecret + case .none(let clientID): + bodyParameters[OAuthParameterName.clientID] = clientID + case .privateKeyJWT(let clientID, let assertionFactory): + bodyParameters[OAuthParameterName.clientID] = clientID + bodyParameters[OAuthParameterName.clientAssertionType] = + OAuthClientAssertionType.jwtBearer + bodyParameters[OAuthParameterName.clientAssertion] = + try await assertionFactory(tokenEndpoint, clientID) + } + } + } + + /// Closure used to generate a `private_key_jwt` assertion. + public typealias JWTAssertionFactory = @Sendable (_ tokenEndpoint: URL, _ clientID: String) + async throws -> String + + /// Supported signing algorithms for SDK-generated `private_key_jwt` assertions. + public enum PrivateKeyJWTSigningAlgorithm: String, Sendable { + case ES256 + } + + /// Errors thrown while creating SDK-generated `private_key_jwt` assertions. + public enum PrivateKeyJWTAssertionError: LocalizedError, Sendable { + case invalidLifetime(TimeInterval) + case cryptographyUnavailable + + public var errorDescription: String? { + switch self { + case .invalidLifetime(let lifetime): + return "private_key_jwt assertion lifetime must be greater than zero seconds, got \(lifetime)" + case .cryptographyUnavailable: + return "private_key_jwt assertion signing requires CryptoKit support" + } + } + } + + /// Creates a signed `private_key_jwt` client assertion (RFC 7523). + /// + /// Use this helper to build a `JWTAssertionFactory` closure when configuring + /// ``TokenEndpointAuthentication/privateKeyJWT(clientID:assertionFactory:)``: + /// + /// ```swift + /// let factory: OAuthConfiguration.JWTAssertionFactory = { tokenEndpoint, clientID in + /// try OAuthConfiguration.makePrivateKeyJWTAssertion( + /// clientID: clientID, + /// tokenEndpoint: tokenEndpoint, + /// privateKeyPEM: myPEMString + /// ) + /// } + /// ``` + /// + /// The assertion is a compact-serialized JWS signed with the specified private key. + /// Only ES256 (P-256 ECDSA) is supported by this helper; the algorithm requires CryptoKit. + /// To use other algorithms, supply a custom ``JWTAssertionFactory`` closure directly to + /// ``TokenEndpointAuthentication/privateKeyJWT(clientID:assertionFactory:)``. + /// + /// - Parameters: + /// - clientID: The OAuth client identifier, used as both `iss` and `sub` claims. + /// - tokenEndpoint: Token endpoint URL, used as the default `aud` claim. + /// - privateKeyPEM: PEM-encoded EC private key for signing. + /// - signingAlgorithm: The JWS signing algorithm. Currently only `.ES256` is supported. + /// - audience: Explicit `aud` claim override. Defaults to `tokenEndpoint.absoluteString`. + /// - issuedAt: The `iat` claim. Defaults to `Date()`. + /// - expiresIn: Lifetime of the assertion in seconds. Defaults to 300 (5 minutes). Must be > 0. + /// - Returns: A compact-serialized JWT string (`header.payload.signature`). + /// - Throws: ``PrivateKeyJWTAssertionError/invalidLifetime(_:)`` if `expiresIn` ≤ 0, + /// or ``PrivateKeyJWTAssertionError/cryptographyUnavailable`` on platforms without CryptoKit. + public static func makePrivateKeyJWTAssertion( + clientID: String, + tokenEndpoint: URL, + privateKeyPEM: String, + signingAlgorithm: PrivateKeyJWTSigningAlgorithm = .ES256, + audience: String? = nil, + issuedAt: Date = Date(), + expiresIn: TimeInterval = 300 + ) throws -> String { + guard expiresIn > 0 else { + throw PrivateKeyJWTAssertionError.invalidLifetime(expiresIn) + } + + let header = try JSONSerialization.data(withJSONObject: [ + JWTClaimName.algorithm: signingAlgorithm.rawValue, + JWTClaimName.type: JWTClaimName.typeValue, + ]) + + let issuedAtUnix = Int(issuedAt.timeIntervalSince1970) + let lifetimeSeconds = max(1, Int(expiresIn.rounded(.down))) + let payload = try JSONSerialization.data(withJSONObject: [ + JWTClaimName.issuer: clientID, + JWTClaimName.subject: clientID, + JWTClaimName.audience: audience ?? tokenEndpoint.absoluteString, + JWTClaimName.issuedAt: issuedAtUnix, + JWTClaimName.expiration: issuedAtUnix + lifetimeSeconds, + JWTClaimName.jwtID: UUID().uuidString, + ]) + + let signingInput = "\(header.base64URLEncodedString()).\(payload.base64URLEncodedString())" + + #if canImport(CryptoKit) + switch signingAlgorithm { + case .ES256: + let privateKey = try P256.Signing.PrivateKey(pemRepresentation: privateKeyPEM) + let signature = try privateKey.signature(for: Data(signingInput.utf8)).rawRepresentation + return "\(signingInput).\(signature.base64URLEncodedString())" + } + #else + throw PrivateKeyJWTAssertionError.cryptographyUnavailable + #endif + } + + static func defaultAuthorizationRedirectURI() -> URL { + let port = Int.random(in: 49152...65535) + return URL(string: "http://\(OAuthLoopbackHost.ipv4):\(port)/callback")! + } + + // MARK: - Retry Policy + + /// Controls retry behavior for authorization challenges. + public struct RetryPolicy: Sendable { + /// Maximum number of authentication retries for a single MCP request. + public let maxAuthorizationAttempts: Int + + /// Maximum number of scope step-up attempts for a resource and operation. + public let maxScopeUpgradeAttempts: Int + + /// Creates a retry policy. + /// + /// Both values are clamped to a minimum of `1` to prevent infinite loops + /// or zero-retry configurations. + /// + /// - Parameters: + /// - maxAuthorizationAttempts: Maximum authorization retries per request. Defaults to 3. + /// - maxScopeUpgradeAttempts: Maximum scope step-up attempts per resource+operation. Defaults to 2. + public init( + maxAuthorizationAttempts: Int = 3, + maxScopeUpgradeAttempts: Int = 2 + ) { + self.maxAuthorizationAttempts = max(1, maxAuthorizationAttempts) + self.maxScopeUpgradeAttempts = max(1, maxScopeUpgradeAttempts) + } + + public static let `default` = RetryPolicy() + } + + // MARK: - Endpoint Overrides + + /// Optional endpoint overrides for discovery. + public struct EndpointOverrides: Sendable { + /// Optional override for the protected resource metadata URL. + public let protectedResourceMetadataURL: URL? + + /// Optional override for the authorization server issuer URL. + public let authorizationServerURL: URL? + + /// Optional override for the token endpoint. + public let tokenEndpoint: URL? + + /// Optional override for the resource indicator used in token requests. + public let resource: URL? + + public init( + protectedResourceMetadataURL: URL? = nil, + authorizationServerURL: URL? = nil, + tokenEndpoint: URL? = nil, + resource: URL? = nil + ) { + self.protectedResourceMetadataURL = protectedResourceMetadataURL + self.authorizationServerURL = authorizationServerURL + self.tokenEndpoint = tokenEndpoint + self.resource = resource + } + + public static let none = EndpointOverrides() + } + + // MARK: - Access Token Provider + + /// Context supplied to ``AccessTokenProvider`` after SDK discovery is complete. + public struct AccessTokenProviderContext: Sendable { + /// HTTP status code that triggered authorization handling (typically 401 or 403). + public let statusCode: Int + /// Target MCP endpoint URL. + public let endpoint: URL + /// Canonical RFC8707 resource URI for token audience binding. + public let resource: URL + /// Selected authorization server issuer URL, if discovered. + public let authorizationServer: URL? + /// Authorization endpoint from AS metadata, if available. + public let authorizationEndpoint: URL? + /// Token endpoint from AS metadata (or configuration override), if available. + public let tokenEndpoint: URL? + /// Dynamic registration endpoint from AS metadata, if available. + public let registrationEndpoint: URL? + /// `scope` value from the latest challenge header, when present. + public let challengedScope: String? + /// `scopes_supported` from protected resource metadata, when present. + public let scopesSupported: [String]? + /// Scope set selected by SDK for the pending authorization attempt. + public let requestedScopes: Set? + + public init( + statusCode: Int, + endpoint: URL, + resource: URL, + authorizationServer: URL?, + authorizationEndpoint: URL?, + tokenEndpoint: URL?, + registrationEndpoint: URL?, + challengedScope: String?, + scopesSupported: [String]?, + requestedScopes: Set? + ) { + self.statusCode = statusCode + self.endpoint = endpoint + self.resource = resource + self.authorizationServer = authorizationServer + self.authorizationEndpoint = authorizationEndpoint + self.tokenEndpoint = tokenEndpoint + self.registrationEndpoint = registrationEndpoint + self.challengedScope = challengedScope + self.scopesSupported = scopesSupported + self.requestedScopes = requestedScopes + } + } + + /// Optional provider for externally acquired access tokens. + public typealias AccessTokenProvider = @Sendable ( + _ context: AccessTokenProviderContext, + _ session: URLSession + ) async throws -> String? + + // MARK: - Properties + + /// The grant type used for token acquisition. + public let grantType: GrantType + + /// The configured token endpoint authentication method. + public var authentication: TokenEndpointAuthentication + + /// Controls retry behavior for authorization challenges. + public let retryPolicy: RetryPolicy + + /// Optional endpoint overrides for discovery. + public let endpointOverrides: EndpointOverrides + + /// Package-scoped compatibility option for local environments. + package var allowLoopbackHTTPAuthorizationServerEndpoints: Bool + + /// Redirect URI used for authorization requests. + public let authorizationRedirectURI: URL + + /// Additional form fields to include in token requests. + public let additionalTokenRequestParameters: [String: String] + + /// The `client_name` sent during dynamic client registration (RFC 7591). + /// Defaults to `"mcp-swift-sdk"`. Override with your application's name. + public let clientName: String + + /// Optional provider for externally acquired access tokens. + public let accessTokenProvider: AccessTokenProvider? + + /// Optional delegate for browser-based authorization code flows. + public let authorizationDelegate: (any OAuthAuthorizationDelegate)? + + /// How many seconds before token expiry ``OAuthAuthorizer`` proactively refreshes a token + /// when `prepareAuthorization(for:session:)` is called. + /// + /// Set to `0` to disable proactive refresh. Defaults to 60 seconds. + /// Must be greater than ``OAuthTokenExpirySkew/defaultSeconds`` (30 s) to have any effect, + /// since tokens within the default skew window are already treated as expired. + public let proactiveRefreshWindowSeconds: TimeInterval + + /// Creates an OAuth configuration. + /// + /// - Parameters: + /// - grantType: The OAuth 2.1 grant type. Defaults to `.clientCredentials`. + /// - authentication: How the client authenticates to the token endpoint. **Required.** + /// - retryPolicy: Controls how many retries are allowed for authorization challenges. + /// - endpointOverrides: Optional URL overrides that bypass automatic discovery. + /// - authorizationRedirectURI: Redirect URI for the `authorization_code` flow. + /// Defaults to a random loopback URI (`http://127.0.0.1:/callback`). + /// - clientName: The `client_name` sent during dynamic client registration. Defaults to `"mcp-swift-sdk"`. + /// - additionalTokenRequestParameters: Extra form fields appended to every token request. + /// - accessTokenProvider: Optional closure invoked after discovery, allowing the host app + /// to supply an externally acquired token (e.g., from a system credential store). + /// - authorizationDelegate: Optional delegate that presents the authorization URL to the + /// user for interactive `authorization_code` flows. + /// - proactiveRefreshWindowSeconds: Seconds before expiry at which a token is proactively + /// refreshed. Defaults to 60. Set to 0 to disable proactive refresh. + public init( + grantType: GrantType = .clientCredentials, + authentication: TokenEndpointAuthentication, + retryPolicy: RetryPolicy = .default, + endpointOverrides: EndpointOverrides = .none, + authorizationRedirectURI: URL? = nil, + clientName: String = "mcp-swift-sdk", + additionalTokenRequestParameters: [String: String] = [:], + accessTokenProvider: AccessTokenProvider? = nil, + authorizationDelegate: (any OAuthAuthorizationDelegate)? = nil, + proactiveRefreshWindowSeconds: TimeInterval = 60 + ) { + self.grantType = grantType + self.authentication = authentication + self.retryPolicy = retryPolicy + self.endpointOverrides = endpointOverrides + self.allowLoopbackHTTPAuthorizationServerEndpoints = false + self.authorizationRedirectURI = + authorizationRedirectURI ?? Self.defaultAuthorizationRedirectURI() + self.clientName = clientName + self.additionalTokenRequestParameters = additionalTokenRequestParameters + self.accessTokenProvider = accessTokenProvider + self.authorizationDelegate = authorizationDelegate + self.proactiveRefreshWindowSeconds = proactiveRefreshWindowSeconds + } +} + +extension OAuthConfiguration.TokenEndpointAuthentication { + public static func == (lhs: Self, rhs: Self) -> Bool { + switch (lhs, rhs) { + case (.clientSecretBasic(let li, let ls), .clientSecretBasic(let ri, let rs)): + return li == ri && ls == rs + case (.clientSecretPost(let li, let ls), .clientSecretPost(let ri, let rs)): + return li == ri && ls == rs + case (.none(let li), .none(let ri)): + return li == ri + case (.privateKeyJWT(let li, _), .privateKeyJWT(let ri, _)): + return li == ri + default: + return false + } + } +} + +/// Delegate that handles user-facing authorization steps for the authorization_code flow. +public protocol OAuthAuthorizationDelegate: Sendable { + /// Presents the authorization URL to the user and returns the redirect URL containing + /// the authorization code. + func presentAuthorizationURL(_ url: URL) async throws -> URL +} diff --git a/Sources/MCP/Base/Authorization/OAuthConstants.swift b/Sources/MCP/Base/Authorization/OAuthConstants.swift new file mode 100644 index 00000000..73cde0f1 --- /dev/null +++ b/Sources/MCP/Base/Authorization/OAuthConstants.swift @@ -0,0 +1,139 @@ +import Foundation + +// MARK: - OAuth Grant Types + +enum OAuthGrantTypeValue { + static let clientCredentials = "client_credentials" + static let authorizationCode = "authorization_code" + static let refreshToken = "refresh_token" +} + +// MARK: - OAuth Parameter Names + +enum OAuthParameterName { + static let grantType = "grant_type" + static let resource = "resource" + static let scope = "scope" + static let code = "code" + static let codeVerifier = "code_verifier" + static let codeChallenge = "code_challenge" + static let codeChallengeMethod = "code_challenge_method" + static let redirectURI = "redirect_uri" + static let responseType = "response_type" + static let clientID = "client_id" + static let clientSecret = "client_secret" + static let clientAssertion = "client_assertion" + static let clientAssertionType = "client_assertion_type" + static let refreshToken = "refresh_token" + static let state = "state" +} + +// MARK: - OAuth Well-Known Paths + +enum OAuthWellKnownPath { + static let protectedResource = "/.well-known/oauth-protected-resource" + static let authorizationServer = "/.well-known/oauth-authorization-server" + static let openIDConfiguration = "/.well-known/openid-configuration" +} + +// MARK: - OAuth Token Type + +enum OAuthTokenType { + static let bearer = "Bearer" +} + +// MARK: - OAuth Code Challenge Method + +enum OAuthCodeChallengeMethod { + static let s256 = "S256" +} + +// MARK: - OAuth Token Endpoint Auth Method + +enum OAuthTokenEndpointAuthMethod { + static let clientSecretBasic = "client_secret_basic" + static let clientSecretPost = "client_secret_post" + static let none = "none" + static let privateKeyJWT = "private_key_jwt" +} + +// MARK: - URL Scheme + +enum OAuthURLScheme { + static let http = "http" + static let https = "https" +} + +// MARK: - Default Ports + +enum OAuthDefaultPort { + static let http = 80 + static let https = 443 +} + +// MARK: - Loopback Hosts + +enum OAuthLoopbackHost { + static let localhost = "localhost" + static let ipv4 = "127.0.0.1" + static let ipv6 = "::1" + + static func isLoopback(_ host: String) -> Bool { + host == localhost || host == ipv4 || host == ipv6 + } +} + +// MARK: - Token Expiry Skew + +/// Clock skew tolerance applied when checking token expiry. +/// +/// ``OAuthAccessToken/isExpired(now:skewSeconds:)`` treats a token as expired +/// when `now + skewSeconds >= expiresAt`, giving the client a safety margin to +/// refresh tokens before they actually expire on the server. +public enum OAuthTokenExpirySkew { + /// Default clock skew buffer: 30 seconds. + public static let defaultSeconds: TimeInterval = 30 +} + +// MARK: - JWT Claim Names + +enum JWTClaimName { + static let algorithm = "alg" + static let type = "typ" + static let typeValue = "JWT" + static let issuer = "iss" + static let subject = "sub" + static let audience = "aud" + static let issuedAt = "iat" + static let expiration = "exp" + static let jwtID = "jti" +} + +// MARK: - Client Assertion Type + +enum OAuthClientAssertionType { + static let jwtBearer = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" +} + +// MARK: - HTTPHeaderName Extensions + +extension HTTPHeaderName { + static let location = "Location" +} + +// MARK: - ContentType Extensions + +extension ContentType { + static let formURLEncoded = "application/x-www-form-urlencoded" +} + +// MARK: - Data Base64URL Extension + +extension Data { + func base64URLEncodedString() -> String { + base64EncodedString() + .replacingOccurrences(of: "+", with: "-") + .replacingOccurrences(of: "/", with: "_") + .replacingOccurrences(of: "=", with: "") + } +} diff --git a/Sources/MCP/Base/Authorization/OAuthDiscovery.swift b/Sources/MCP/Base/Authorization/OAuthDiscovery.swift new file mode 100644 index 00000000..332a898e --- /dev/null +++ b/Sources/MCP/Base/Authorization/OAuthDiscovery.swift @@ -0,0 +1,361 @@ +import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + +// MARK: - Scope Selection Protocol + +/// Determines which OAuth scopes to request during token acquisition. +/// +/// ``OAuthAuthorizer`` uses this protocol to translate raw scope strings from +/// `WWW-Authenticate` challenges and Protected Resource Metadata into the +/// scope set passed to the token endpoint. +/// +/// Override the default ``DefaultOAuthScopeSelector`` to apply custom scope +/// filtering or transformation logic. +public protocol OAuthScopeSelecting { + /// Selects the scope set to request for a token. + /// + /// Priority order (highest first): + /// 1. `challengeScope` — the `scope` parameter from the `WWW-Authenticate` header. + /// 2. `scopesSupported` — the `scopes_supported` array from Protected Resource Metadata. + /// 3. `nil` — no scope restriction. + /// + /// - Parameters: + /// - challengeScope: Space-separated scope string from the Bearer challenge, or `nil`. + /// - scopesSupported: Array of supported scopes from the resource metadata, or `nil`. + /// - Returns: The set of scopes to request, or `nil` to omit the `scope` parameter entirely. + func selectScopes(challengeScope: String?, scopesSupported: [String]?) -> Set? + + /// Parses a space-separated OAuth scope string into individual scope tokens. + /// - Parameter scope: A scope string such as `"read write"`. + /// - Returns: A set of individual scope strings with whitespace-only tokens removed. + func parseScopeString(_ scope: String) -> Set + + /// Serializes a set of scopes into a space-separated string suitable for the `scope` parameter. + /// - Parameter scopes: The scope set to serialize. + /// - Returns: A sorted, space-separated string, or `nil` if the set is empty. + func serialize(_ scopes: Set) -> String? +} + +// MARK: - Default Scope Selector + +/// Default ``OAuthScopeSelecting`` implementation. +/// +/// Selects scopes in priority order: challenge scope > `scopes_supported` > `nil`. +/// Serializes scopes sorted alphabetically to produce deterministic `scope` parameters. +public struct DefaultOAuthScopeSelector: OAuthScopeSelecting { + public init() {} + + public func selectScopes(challengeScope: String?, scopesSupported: [String]?) -> Set? { + if let challengeScope { + let parsed = parseScopeString(challengeScope) + return parsed.isEmpty ? nil : parsed + } + + if let scopesSupported { + let parsed = Set(scopesSupported.filter { !$0.trimmingCharacters(in: .whitespaces).isEmpty }) + return parsed.isEmpty ? nil : parsed + } + + return nil + } + + public func parseScopeString(_ scope: String) -> Set { + Set( + scope + .split(whereSeparator: { $0.isWhitespace }) + .map(String.init) + .filter { !$0.isEmpty } + ) + } + + public func serialize(_ scopes: Set) -> String? { + guard !scopes.isEmpty else { return nil } + return scopes.sorted().joined(separator: " ") + } +} + +// MARK: - Metadata Discovery Protocol + +/// Builds discovery URLs and normalizes resource identifiers for OAuth metadata discovery. +/// +/// ``OAuthAuthorizer`` uses this protocol to construct the candidate URL list for both +/// Protected Resource Metadata (RFC 9728) and Authorization Server Metadata (RFC 8414 / OIDC) +/// discovery, and to normalize resource URIs for RFC 8707 audience binding. +/// +/// Override the default ``DefaultOAuthMetadataDiscovery`` to customise discovery URL +/// construction or resource-matching logic. +public protocol OAuthMetadataDiscovering: Sendable { + /// Returns candidate URLs for Protected Resource Metadata, ordered by priority. + /// + /// ``OAuthAuthorizer`` tries each URL in order and uses the first successful response. + /// + /// - Parameter endpoint: The MCP endpoint URL. + /// - Returns: An ordered list of discovery URLs (typically `/.well-known/oauth-protected-resource` + /// variants). + func protectedResourceMetadataURLs(for endpoint: URL) -> [URL] + + /// Returns candidate URLs for Authorization Server Metadata, ordered by priority. + /// + /// Covers RFC 8414 (`/.well-known/oauth-authorization-server`) and + /// OIDC Discovery 1.0 (`/.well-known/openid-configuration`), including + /// path-inserted variants for issuers with non-root paths. + /// + /// - Parameter issuer: The authorization server issuer URL. + /// - Returns: An ordered list of metadata discovery URLs. + func authorizationServerMetadataURLs(for issuer: URL) -> [URL] + + /// Derives the canonical RFC 8707 resource URI from an endpoint URL. + /// + /// The canonical form strips the query string, fragment, and trailing slash + /// while preserving the scheme, host, port, and path. + /// + /// - Parameter endpoint: The MCP endpoint URL. + /// - Returns: The canonical resource URI. + /// - Throws: ``OAuthAuthorizationError/invalidResourceURI(_:)`` if the URL does not + /// satisfy the HTTPS-or-loopback-HTTP requirement. + func canonicalResourceURI(from endpoint: URL) throws -> URL + + /// Derives a fallback authorization server issuer URL from an endpoint URL. + /// + /// Used when no authorization server is listed in Protected Resource Metadata. + /// Typically returns the scheme+host+port of the endpoint with an empty path. + /// + /// - Parameter endpoint: The MCP endpoint URL. + /// - Returns: A candidate issuer URL. + /// - Throws: ``OAuthAuthorizationError/invalidResourceURI(_:)`` if the URL is invalid. + func authorizationServerFallbackIssuer(from endpoint: URL) throws -> URL + + /// Returns `true` if `resource` is a prefix of `endpoint` in the URL hierarchy. + /// + /// A resource matches an endpoint when the two share the same scheme, host, and port, + /// and the endpoint path starts with the resource path. + /// + /// - Parameters: + /// - resource: The canonical resource URI from Protected Resource Metadata. + /// - endpoint: The canonical endpoint URI being requested. + /// - Returns: `true` if the endpoint falls within the resource's scope. + func protectedResourceMatches(resource: URL, endpoint: URL) -> Bool +} + +// MARK: - Default Metadata Discovery + +/// Default ``OAuthMetadataDiscovering`` implementation following RFC 9728 and RFC 8414. +/// +/// - Builds `/.well-known/oauth-protected-resource` URLs with and without the endpoint path suffix. +/// - Builds RFC 8414 and OIDC discovery URLs with path-inserted and path-appended variants. +/// - Canonicalises resource URIs by stripping query, fragment, and trailing slash. +/// - Matches resources using scheme/host/port equality and path prefix rules. +public struct DefaultOAuthMetadataDiscovery: OAuthMetadataDiscovering { + public init() {} + + public func protectedResourceMetadataURLs(for endpoint: URL) -> [URL] { + guard var components = URLComponents(url: endpoint, resolvingAgainstBaseURL: false), + let scheme = components.scheme?.lowercased(), + let host = components.host?.lowercased(), + Self.isSecureOAuthScheme(scheme: scheme, host: host) + else { + return [] + } + + components.query = nil + components.fragment = nil + + let endpointPath = components.path + let normalizedPath = endpointPath == "/" ? "" : endpointPath + + var urls: [URL] = [] + + var pathSpecific = components + pathSpecific.path = "\(OAuthWellKnownPath.protectedResource)\(normalizedPath)" + if let url = pathSpecific.url { + urls.append(url) + } + + var root = components + root.path = OAuthWellKnownPath.protectedResource + if let url = root.url { + urls.append(url) + } + + return urls + } + + public func authorizationServerMetadataURLs(for issuer: URL) -> [URL] { + guard var components = URLComponents(url: issuer, resolvingAgainstBaseURL: false), + let scheme = components.scheme?.lowercased(), + let host = components.host?.lowercased(), + Self.isSecureOAuthScheme(scheme: scheme, host: host) + else { + return [] + } + + components.query = nil + components.fragment = nil + + let path = components.path.trimmingCharacters(in: CharacterSet(charactersIn: "/")) + let hasPath = !path.isEmpty + + var urls: [URL] = [] + + if hasPath { + var oauthInserted = components + oauthInserted.path = "\(OAuthWellKnownPath.authorizationServer)/\(path)" + if let url = oauthInserted.url { + urls.append(url) + } + + var oidcInserted = components + oidcInserted.path = "\(OAuthWellKnownPath.openIDConfiguration)/\(path)" + if let url = oidcInserted.url { + urls.append(url) + } + + var oidcAppended = components + oidcAppended.path = "/\(path)\(OAuthWellKnownPath.openIDConfiguration)" + if let url = oidcAppended.url { + urls.append(url) + } + } else { + var oauth = components + oauth.path = OAuthWellKnownPath.authorizationServer + if let url = oauth.url { + urls.append(url) + } + + var oidc = components + oidc.path = OAuthWellKnownPath.openIDConfiguration + if let url = oidc.url { + urls.append(url) + } + } + + return urls + } + + public func canonicalResourceURI(from endpoint: URL) throws -> URL { + guard var components = URLComponents(url: endpoint, resolvingAgainstBaseURL: false), + let scheme = components.scheme?.lowercased(), + let host = components.host?.lowercased(), + Self.isSecureOAuthScheme(scheme: scheme, host: host) + else { + throw OAuthAuthorizationError.invalidResourceURI( + "Resource URI must use https or loopback http" + ) + } + + if components.fragment != nil { + throw OAuthAuthorizationError.invalidResourceURI("Resource URI must not contain a fragment") + } + + components.scheme = scheme + components.host = host + components.query = nil + components.fragment = nil + + if components.path == "/" { + components.path = "" + } + + guard let url = components.url else { + throw OAuthAuthorizationError.invalidResourceURI("Failed to normalize resource URI") + } + + return url + } + + public func authorizationServerFallbackIssuer(from endpoint: URL) throws -> URL { + guard var components = URLComponents(url: endpoint, resolvingAgainstBaseURL: false), + let scheme = components.scheme?.lowercased(), + let host = components.host?.lowercased(), + Self.isSecureOAuthScheme(scheme: scheme, host: host) + else { + throw OAuthAuthorizationError.invalidResourceURI( + "Resource URI must use https or loopback http" + ) + } + + components.scheme = scheme + components.host = host + components.path = "" + components.query = nil + components.fragment = nil + + guard let url = components.url else { + throw OAuthAuthorizationError.invalidResourceURI("Failed to derive issuer URI") + } + return url + } + + public func protectedResourceMatches(resource: URL, endpoint: URL) -> Bool { + guard let resourceComponents = URLComponents( + url: resource, + resolvingAgainstBaseURL: false + ), + let endpointComponents = URLComponents( + url: endpoint, + resolvingAgainstBaseURL: false + ) + else { + return false + } + + let resourceScheme = resourceComponents.scheme?.lowercased() + let endpointScheme = endpointComponents.scheme?.lowercased() + let resourceHost = resourceComponents.host?.lowercased() + let endpointHost = endpointComponents.host?.lowercased() + let resourcePort = resourceComponents.port ?? Self.defaultPort(for: resourceScheme) + let endpointPort = endpointComponents.port ?? Self.defaultPort(for: endpointScheme) + + guard resourceScheme == endpointScheme, + resourceHost == endpointHost, + resourcePort == endpointPort + else { + return false + } + + let resourcePath = Self.normalizedResourcePath(resourceComponents.path) + let endpointPath = Self.normalizedResourcePath(endpointComponents.path) + if resourcePath.isEmpty { + return true + } + if endpointPath == resourcePath { + return true + } + return endpointPath.hasPrefix(resourcePath + "/") + } + + private static func normalizedResourcePath(_ rawPath: String) -> String { + if rawPath.isEmpty || rawPath == "/" { + return "" + } + if rawPath.count > 1 && rawPath.hasSuffix("/") { + return String(rawPath.dropLast()) + } + return rawPath + } + + private static func defaultPort(for scheme: String?) -> Int? { + switch scheme?.lowercased() { + case OAuthURLScheme.http: + return OAuthDefaultPort.http + case OAuthURLScheme.https: + return OAuthDefaultPort.https + default: + return nil + } + } + + static func isSecureOAuthScheme(scheme: String, host: String) -> Bool { + if scheme == OAuthURLScheme.https { + return true + } + if scheme == OAuthURLScheme.http { + return OAuthLoopbackHost.isLoopback(host) + } + return false + } +} + diff --git a/Sources/MCP/Base/Authorization/OAuthDiscoveryClient.swift b/Sources/MCP/Base/Authorization/OAuthDiscoveryClient.swift new file mode 100644 index 00000000..55bdb09d --- /dev/null +++ b/Sources/MCP/Base/Authorization/OAuthDiscoveryClient.swift @@ -0,0 +1,118 @@ +import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + +/// Internal protocol for fetching OAuth discovery metadata. +protocol OAuthDiscoveryFetching: Sendable { + var metadataDiscovery: any OAuthMetadataDiscovering { get } + func fetchProtectedResourceMetadata(candidates: [URL], session: URLSession) async throws -> OAuthProtectedResourceMetadata + func fetchAuthorizationServerMetadata(candidates: [URL], session: URLSession) async throws -> (server: URL, metadata: OAuthAuthorizationServerMetadata) +} + +/// Stateless OAuth metadata fetcher. +/// +/// Fetches Protected Resource Metadata (RFC 9728) and Authorization Server Metadata +/// (RFC 8414 / OIDC Discovery 1.0) from ordered candidate URL lists. +/// Cache management is the caller's responsibility. +struct OAuthDiscoveryClient: Sendable { + let metadataDiscovery: any OAuthMetadataDiscovering + let urlValidator: OAuthURLValidator + + init( + metadataDiscovery: any OAuthMetadataDiscovering, + urlValidator: OAuthURLValidator + ) { + self.metadataDiscovery = metadataDiscovery + self.urlValidator = urlValidator + } + + /// Fetches Protected Resource Metadata from the first candidate that returns a valid response. + func fetchProtectedResourceMetadata( + candidates: [URL], + session: URLSession + ) async throws -> OAuthProtectedResourceMetadata { + let decoder = JSONDecoder() + for url in candidates { + var request = URLRequest(url: url) + request.httpMethod = "GET" + request.setValue(ContentType.json, forHTTPHeaderField: HTTPHeaderName.accept) + + do { + let (data, response) = try await session.data(for: request) + guard let httpResponse = response as? HTTPURLResponse, + (200..<300).contains(httpResponse.statusCode) + else { + continue + } + + let metadata = try decoder.decode(OAuthProtectedResourceMetadata.self, from: data) + guard !metadata.authorizationServers.isEmpty else { continue } + return metadata + } catch let error as OAuthAuthorizationError { + throw error + } catch { + continue + } + } + throw OAuthAuthorizationError.metadataDiscoveryFailed + } + + /// Fetches Authorization Server Metadata from the first candidate that returns a valid response. + func fetchAuthorizationServerMetadata( + candidates: [URL], + session: URLSession + ) async throws -> (server: URL, metadata: OAuthAuthorizationServerMetadata) { + let decoder = JSONDecoder() + for candidateServer in candidates { + guard (try? urlValidator.validateAuthorizationServer( + candidateServer, context: "Authorization server issuer")) != nil + else { + continue + } + if let host = URLComponents(url: candidateServer, resolvingAgainstBaseURL: false)? + .host?.lowercased(), urlValidator.isPrivateIPHost(host) + { + continue + } + + for metadataURL in metadataDiscovery.authorizationServerMetadataURLs( + for: candidateServer) + { + var request = URLRequest(url: metadataURL) + request.httpMethod = "GET" + request.setValue(ContentType.json, forHTTPHeaderField: HTTPHeaderName.accept) + + do { + let (data, response) = try await session.data(for: request) + guard let httpResponse = response as? HTTPURLResponse, + (200..<300).contains(httpResponse.statusCode) + else { + continue + } + + let asMetadata = try decoder.decode( + OAuthAuthorizationServerMetadata.self, from: data) + + // RFC 8414 §3: issuer field must match the candidate server URL. + // Absent issuer is tolerated (some servers omit it). + if let metadataIssuer = asMetadata.issuer { + guard metadataIssuer.absoluteString.lowercased() + == candidateServer.absoluteString.lowercased() + else { + continue + } + } + + return (server: candidateServer, metadata: asMetadata) + } catch { + continue + } + } + } + throw OAuthAuthorizationError.authorizationServerMetadataDiscoveryFailed + } +} + +extension OAuthDiscoveryClient: OAuthDiscoveryFetching {} diff --git a/Sources/MCP/Base/Authorization/OAuthErrors.swift b/Sources/MCP/Base/Authorization/OAuthErrors.swift new file mode 100644 index 00000000..2f8249ce --- /dev/null +++ b/Sources/MCP/Base/Authorization/OAuthErrors.swift @@ -0,0 +1,203 @@ +import Foundation + +/// Errors thrown during OAuth 2.1 authorization by ``OAuthAuthorizer``. +/// +/// These errors surface when the authorizer is unable to complete the authorization flow, +/// either due to discovery failures, token exchange problems, security policy violations, +/// or authorization code flow issues. +public enum OAuthAuthorizationError: LocalizedError { + /// No authorization server URL could be found in the Protected Resource Metadata. + case missingAuthorizationServer + + /// All Protected Resource Metadata discovery candidates returned errors or invalid documents. + case metadataDiscoveryFailed + + /// All Authorization Server Metadata discovery candidates (RFC 8414 / OIDC) returned errors. + case authorizationServerMetadataDiscoveryFailed + + /// The authorization server metadata does not include a `token_endpoint`. + case tokenEndpointMissing + + /// The token endpoint returned a non-2xx HTTP response. + /// + /// - Parameters: + /// - statusCode: HTTP status code from the token endpoint. + /// - oauthError: The `error` field from the OAuth error response body, if present. + case tokenRequestFailed(statusCode: Int, oauthError: String?) + + /// The token response body is missing `access_token`, has an empty token, or specifies + /// a non-Bearer token type. + case tokenResponseInvalid + + /// A URL that must be a valid resource identifier (RFC 8707) failed validation. + case invalidResourceURI(String) + + /// The configured client ID looks like a URL but is not a valid HTTPS URL with a path, + /// which is required for Client ID Metadata Documents. + case invalidClientIDMetadataURL(String) + + /// The `resource` field in the Protected Resource Metadata does not match the requested endpoint. + /// + /// - Parameters: + /// - expected: The canonical URI derived from the endpoint. + /// - actual: The canonical URI derived from the metadata's `resource` field. + case protectedResourceMismatch(expected: String, actual: String) + + /// The authorization server does not support dynamic registration and no pre-registered + /// credentials were supplied. + case registrationInformationRequired + + /// The authorization server does not support Client ID Metadata Documents and no dynamic + /// registration endpoint is available. + /// + /// - Parameter clientID: The client ID URL that was provided as a CIMD URL. + case cimdNotSupported(clientID: String) + + /// A URL received from a discovery response resolves to a private or reserved IP address, + /// which is blocked to prevent SSRF attacks. + /// + /// - Parameters: + /// - context: Human-readable label identifying which URL was blocked. + /// - url: The blocked URL string. + case privateIPAddressBlocked(context: String, url: String) + + /// An endpoint URL used during the OAuth flow does not satisfy the HTTPS-or-loopback requirement. + /// + /// - Parameters: + /// - context: Human-readable label identifying which endpoint failed (e.g., `"Token endpoint"`). + /// - url: The offending URL string. + case insecureOAuthEndpoint(context: String, url: String) + + /// An authorization server endpoint does not satisfy the HTTPS-only requirement. + /// + /// - Parameters: + /// - context: Human-readable label identifying which endpoint failed. + /// - url: The offending URL string. + case insecureAuthorizationServerEndpoint(context: String, url: String) + + /// The redirect URI supplied for the `authorization_code` flow is not a valid HTTPS or + /// loopback HTTP URI, or it contains a fragment. + case invalidRedirectURI(String) + + /// The authorization endpoint returned an HTTP error response during the authorization code flow. + /// + /// - Parameter statusCode: HTTP status code from the authorization endpoint. + case authorizationRequestFailed(statusCode: Int) + + /// The authorization response did not include a `Location` redirect header. + case authorizationResponseMissingRedirectLocation + + /// The redirect URI in the authorization response does not match the expected redirect URI. + /// + /// - Parameters: + /// - expected: The redirect URI supplied in the authorization request. + /// - actual: The redirect URI received in the authorization response. + case authorizationResponseRedirectMismatch(expected: String, actual: String) + + /// The authorization response redirect URL is missing the `state` parameter. + case authorizationResponseMissingState + + /// The `state` in the authorization response does not match the one sent in the request. + /// + /// This may indicate a CSRF attack. + /// + /// - Parameters: + /// - expected: The `state` value sent in the authorization request. + /// - actual: The `state` value received in the authorization response. + case authorizationResponseStateMismatch(expected: String, actual: String) + + /// The authorization response redirect URL is missing the `code` parameter. + case authorizationResponseMissingCode + + /// The authorization server metadata does not include `code_challenge_methods_supported`, + /// which is required for PKCE (RFC 7636). + case pkceCodeChallengeMethodsMissing + + /// The authorization server does not advertise `S256` in `code_challenge_methods_supported`. + /// + /// The MCP specification mandates S256; plain PKCE is not accepted. + /// + /// - Parameter advertisedMethods: The methods listed in the server metadata. + case pkceS256NotSupported(advertisedMethods: [String]) + + /// PKCE S256 challenge generation is unavailable because CryptoKit is not present on this platform. + case pkceS256Unavailable + + /// The `issuer` field in the Authorization Server Metadata does not match the URL used to + /// resolve the metadata document, as required by RFC 8414 §3. + /// + /// - Parameters: + /// - expected: The issuer URL derived from the discovery candidate. + /// - actual: The `issuer` field value found in the metadata document. + case authorizationServerIssuerMismatch(expected: String, actual: String) + + public var errorDescription: String? { + switch self { + case .missingAuthorizationServer: + return "No authorization server was found in protected resource metadata" + case .metadataDiscoveryFailed: + return "Failed to discover protected resource metadata" + case .authorizationServerMetadataDiscoveryFailed: + return "Failed to discover authorization server metadata" + case .tokenEndpointMissing: + return "Authorization server metadata is missing token_endpoint" + case .tokenRequestFailed(let statusCode, let oauthError): + if let oauthError, !oauthError.isEmpty { + return "Token request failed with status \(statusCode) (oauth_error: \(oauthError))" + } + return "Token request failed with status \(statusCode)" + case .tokenResponseInvalid: + return "Token response is invalid" + case .invalidResourceURI(let detail): + return "Invalid resource URI: \(detail)" + case .invalidClientIDMetadataURL(let value): + return + "Client ID metadata document URL must use https and include a path: \(value)" + case .protectedResourceMismatch(let expected, let actual): + return + "Protected resource metadata resource mismatch. Expected \(expected), got \(actual)" + case .registrationInformationRequired: + return + "No supported client registration mechanism was available; provide pre-registered client credentials" + case .cimdNotSupported(let clientID): + return + "Authorization server does not support Client ID Metadata Documents; configure pre-registered credentials or ensure the server advertises client_id_metadata_document_supported: \(clientID)" + case .privateIPAddressBlocked(let context, let url): + return + "\(context) resolves to a private or reserved IP address which is blocked for SSRF protection: \(url)" + case .insecureOAuthEndpoint(let context, let url): + return "\(context) must use https or loopback http: \(url)" + case .insecureAuthorizationServerEndpoint(let context, let url): + return "\(context) must use https: \(url)" + case .invalidRedirectURI(let url): + return + "Redirect URI must use https or loopback http and must not include fragments: \(url)" + case .authorizationRequestFailed(let statusCode): + return "Authorization request failed with status \(statusCode)" + case .authorizationResponseMissingRedirectLocation: + return "Authorization response is missing redirect location" + case .authorizationResponseRedirectMismatch(let expected, let actual): + return + "Authorization response redirect URI mismatch. Expected \(expected), got \(actual)" + case .authorizationResponseMissingState: + return "Authorization response is missing state" + case .authorizationResponseStateMismatch(let expected, let actual): + return "Authorization response state mismatch. Expected \(expected), got \(actual)" + case .authorizationResponseMissingCode: + return "Authorization response is missing the authorization code" + case .pkceCodeChallengeMethodsMissing: + return + "Authorization server metadata must include code_challenge_methods_supported for PKCE" + case .pkceS256NotSupported(let advertisedMethods): + let methods = advertisedMethods.joined(separator: ", ") + return + "Authorization server metadata must support PKCE S256 (advertised: \(methods))" + case .pkceS256Unavailable: + return + "PKCE S256 code challenge generation is unavailable on this platform" + case .authorizationServerIssuerMismatch(let expected, let actual): + return + "Authorization server issuer mismatch. Expected \(expected), got \(actual)" + } + } +} diff --git a/Sources/MCP/Base/Authorization/OAuthModels.swift b/Sources/MCP/Base/Authorization/OAuthModels.swift new file mode 100644 index 00000000..d5451c3d --- /dev/null +++ b/Sources/MCP/Base/Authorization/OAuthModels.swift @@ -0,0 +1,247 @@ +import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + +struct HTTPAuthenticationChallengeError: Error { + let statusCode: Int + let headers: [String: String] +} + +/// Parsed representation of a `WWW-Authenticate: Bearer` challenge header. +/// +/// Servers return this challenge in `401 Unauthorized` and `403 Forbidden` responses +/// to indicate that a Bearer token is required or that the presented token lacks +/// sufficient scope. +/// +/// The ``OAuthWWWAuthenticateParsing`` protocol produces instances of this type. +public struct OAuthBearerChallenge: Sendable { + /// Raw key-value parameters extracted from the `Bearer` challenge. + public let parameters: [String: String] + + /// Creates a challenge from raw parsed parameters. + /// - Parameter parameters: Key-value pairs from the `WWW-Authenticate` header. + public init(parameters: [String: String]) { + self.parameters = parameters + } + + /// The `resource_metadata` parameter, parsed as a URL. + /// + /// Points to the server's RFC 9728 Protected Resource Metadata document. + /// When present, ``OAuthAuthorizer`` uses this URL as the highest-priority + /// discovery candidate. + public var resourceMetadataURL: URL? { + guard let value = parameters["resource_metadata"] else { return nil } + return URL(string: value) + } + + /// The `scope` parameter from the Bearer challenge. + /// + /// Specifies the scopes required or recommended for this resource. + /// ``OAuthAuthorizer`` uses this value as the highest-priority scope hint. + public var scope: String? { + parameters["scope"] + } + + /// The `error` parameter from the Bearer challenge (e.g., `"invalid_token"`, `"insufficient_scope"`). + public var error: String? { + parameters["error"] + } + + /// The `error_description` parameter from the Bearer challenge. + public var errorDescription: String? { + parameters["error_description"] + } +} + +/// RFC9728 OAuth Protected Resource metadata (client-side, decode-only). +struct OAuthProtectedResourceMetadata: Decodable, Sendable, Equatable { + let resource: String? + let authorizationServers: [URL] + let scopesSupported: [String]? + + enum CodingKeys: String, CodingKey { + case resource + case authorizationServers = "authorization_servers" + case scopesSupported = "scopes_supported" + } +} + +/// RFC8414/OIDC authorization server metadata. +struct OAuthAuthorizationServerMetadata: Decodable, Sendable, Equatable { + let issuer: URL? + let authorizationEndpoint: URL? + let tokenEndpoint: URL? + let registrationEndpoint: URL? + let codeChallengeMethodsSupported: [String]? + let tokenEndpointAuthMethodsSupported: [String]? + let clientIDMetadataDocumentSupported: Bool? + + enum CodingKeys: String, CodingKey { + case issuer + case authorizationEndpoint = "authorization_endpoint" + case tokenEndpoint = "token_endpoint" + case registrationEndpoint = "registration_endpoint" + case codeChallengeMethodsSupported = "code_challenge_methods_supported" + case tokenEndpointAuthMethodsSupported = "token_endpoint_auth_methods_supported" + case clientIDMetadataDocumentSupported = "client_id_metadata_document_supported" + } +} + +struct OAuthTokenResponse: Decodable, Sendable, Equatable { + let accessToken: String + let tokenType: String + let expiresIn: Int? + let scope: String? + let refreshToken: String? + + enum CodingKeys: String, CodingKey { + case accessToken = "access_token" + case tokenType = "token_type" + case expiresIn = "expires_in" + case scope + case refreshToken = "refresh_token" + } +} + +struct OAuthTokenErrorResponse: Decodable { + let error: String +} + +/// An OAuth 2.1 access token and its associated metadata. +/// +/// Stored by ``TokenStorage`` and produced by ``OAuthAuthorizer`` after a successful +/// token request. Use ``isExpired(now:skewSeconds:)`` to check validity before use. +public struct OAuthAccessToken: Sendable { + /// The raw bearer token string for use in the `Authorization` header. + public let value: String + + /// The token type returned by the authorization server (should be `"Bearer"`). + public let tokenType: String + + /// The UTC date after which the token is considered expired, or `nil` if no expiry was specified. + public let expiresAt: Date? + + /// The set of OAuth scopes granted with this token. + public let scopes: Set + + /// The issuer URL of the authorization server that issued this token. + /// + /// Used to detect when the active authorization server changes between requests, + /// triggering a token invalidation. + public let authorizationServer: URL? + + /// The refresh token, if the authorization server issued one alongside the access token. + public let refreshToken: String? + + /// Creates a new access token record. + public init( + value: String, + tokenType: String, + expiresAt: Date?, + scopes: Set, + authorizationServer: URL?, + refreshToken: String? + ) { + self.value = value + self.tokenType = tokenType + self.expiresAt = expiresAt + self.scopes = scopes + self.authorizationServer = authorizationServer + self.refreshToken = refreshToken + } + + /// Returns `true` if the token has expired or will expire within the skew window. + /// + /// - Parameters: + /// - now: The reference time to compare against. Defaults to `Date()`. + /// - skewSeconds: Clock skew buffer in seconds. Defaults to ``OAuthTokenExpirySkew/defaultSeconds`` (30 s). + /// Tokens are considered expired when `now + skewSeconds >= expiresAt`. + /// - Returns: `false` if `expiresAt` is `nil` (no expiry). + public func isExpired(now: Date = Date(), skewSeconds: TimeInterval = OAuthTokenExpirySkew.defaultSeconds) -> Bool { + guard let expiresAt else { return false } + return now.addingTimeInterval(skewSeconds) >= expiresAt + } +} + +/// Token introspection result passed from the caller's token validator to ``BearerTokenValidator``. +/// +/// The validator uses this to enforce expiry and audience checks before allowing a request through. +/// Produce an instance in your ``BearerTokenValidator/TokenValidator`` closure after verifying +/// the token's signature and extracting its claims. +public struct BearerTokenInfo: Sendable, Equatable { + /// Audience values from the token (`aud` JWT claim or introspection response). + /// + /// `nil` indicates an opaque token whose audience cannot be inspected; + /// ``BearerTokenValidator`` skips the audience check in that case. + public let audience: [String]? + + /// Scopes granted by the token. + public let scopes: Set? + + /// UTC date after which the token is considered expired. + /// + /// `nil` means no expiry information is available and the expiry check is skipped. + public let expiresAt: Date? + + public init( + audience: [String]? = nil, + scopes: Set? = nil, + expiresAt: Date? = nil + ) { + self.audience = audience + self.scopes = scopes + self.expiresAt = expiresAt + } +} + +struct OAuthClientRegistrationResponse: Decodable { + let clientID: String + let clientSecret: String? + let tokenEndpointAuthMethod: String? + /// Unix timestamp after which `clientSecret` is no longer valid, per RFC 7591 §3.2. + /// A value of `0` means the secret does not expire. + let clientSecretExpiresAt: Int? + + enum CodingKeys: String, CodingKey { + case clientID = "client_id" + case clientSecret = "client_secret" + case tokenEndpointAuthMethod = "token_endpoint_auth_method" + case clientSecretExpiresAt = "client_secret_expires_at" + } +} + +/// Server-side encodable RFC 9728 Protected Resource Metadata. +/// +/// Use this type to construct the metadata document that MCP servers **MUST** serve +/// at `/.well-known/oauth-protected-resource` per the MCP authorization specification. +/// +/// Pair with ``ProtectedResourceMetadataValidator`` to automatically serve this +/// document in the server's validation pipeline. +public struct OAuthProtectedResourceServerMetadata: Codable, Sendable { + /// The canonical resource identifier (RFC 8707). + public let resource: String + + /// One or more authorization server URLs that protect this resource. + public let authorizationServers: [URL] + + /// The scopes supported by this resource server. + public let scopesSupported: [String]? + + enum CodingKeys: String, CodingKey { + case resource + case authorizationServers = "authorization_servers" + case scopesSupported = "scopes_supported" + } + + public init( + resource: String, + authorizationServers: [URL], + scopesSupported: [String]? = nil + ) { + self.resource = resource + self.authorizationServers = authorizationServers + self.scopesSupported = scopesSupported + } +} diff --git a/Sources/MCP/Base/Authorization/OAuthTokenEndpointClient.swift b/Sources/MCP/Base/Authorization/OAuthTokenEndpointClient.swift new file mode 100644 index 00000000..3c616002 --- /dev/null +++ b/Sources/MCP/Base/Authorization/OAuthTokenEndpointClient.swift @@ -0,0 +1,90 @@ +import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + +/// Internal protocol for making OAuth token endpoint requests. +protocol OAuthTokenRequesting: Sendable { + func request( + parameters: inout [String: String], + endpoint: URL, + authentication: OAuthConfiguration.TokenEndpointAuthentication, + session: URLSession + ) async throws -> OAuthTokenResponse +} + +/// Stateless OAuth token endpoint HTTP client. +/// +/// Handles the low-level HTTP mechanics of making token requests. +struct OAuthTokenEndpointClient: Sendable { + let urlValidator: OAuthURLValidator + + init(urlValidator: OAuthURLValidator) { + self.urlValidator = urlValidator + } + + /// Makes a token request to the given endpoint. + func request( + parameters: inout [String: String], + endpoint: URL, + authentication: OAuthConfiguration.TokenEndpointAuthentication, + session: URLSession + ) async throws -> OAuthTokenResponse { + var urlRequest = URLRequest(url: endpoint) + urlRequest.httpMethod = "POST" + urlRequest.setValue(ContentType.formURLEncoded, forHTTPHeaderField: HTTPHeaderName.contentType) + urlRequest.setValue(ContentType.json, forHTTPHeaderField: HTTPHeaderName.accept) + + try await authentication.apply( + to: &urlRequest, + bodyParameters: ¶meters, + tokenEndpoint: endpoint + ) + urlRequest.httpBody = encodeForm(parameters) + + let (data, response) = try await session.data(for: urlRequest) + guard let httpResponse = response as? HTTPURLResponse else { + throw OAuthAuthorizationError.tokenRequestFailed(statusCode: -1, oauthError: nil) + } + + guard (200..<300).contains(httpResponse.statusCode) else { + let oauthError = + (try? JSONDecoder().decode(OAuthTokenErrorResponse.self, from: data))?.error + throw OAuthAuthorizationError.tokenRequestFailed( + statusCode: httpResponse.statusCode, + oauthError: oauthError + ) + } + + let decoded = try JSONDecoder().decode(OAuthTokenResponse.self, from: data) + guard !decoded.accessToken.isEmpty else { + throw OAuthAuthorizationError.tokenResponseInvalid + } + let tokenType = decoded.tokenType.trimmingCharacters(in: .whitespacesAndNewlines) + guard !tokenType.isEmpty, + tokenType.caseInsensitiveCompare(OAuthTokenType.bearer) == .orderedSame + else { + throw OAuthAuthorizationError.tokenResponseInvalid + } + return decoded + } + + private func encodeForm(_ params: [String: String]) -> Data { + let body = params + .sorted { $0.key < $1.key } + .map { key, value in "\(percentEncode(key))=\(percentEncode(value))" } + .joined(separator: "&") + return Data(body.utf8) + } + + private func percentEncode(_ string: String) -> String { + let allowed = CharacterSet( + charactersIn: + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~" + ) + return string.addingPercentEncoding(withAllowedCharacters: allowed) ?? string + } +} + +extension OAuthTokenEndpointClient: OAuthTokenRequesting {} diff --git a/Sources/MCP/Base/Authorization/OAuthURLValidator.swift b/Sources/MCP/Base/Authorization/OAuthURLValidator.swift new file mode 100644 index 00000000..ddc180cf --- /dev/null +++ b/Sources/MCP/Base/Authorization/OAuthURLValidator.swift @@ -0,0 +1,135 @@ +import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + +/// Internal protocol for URL security validation in OAuth flows. +protocol OAuthURLValidating: Sendable { + func validateHTTPSOrLoopback(_ url: URL, context: String) throws + func validateAuthorizationServer(_ url: URL, context: String) throws + func validateRedirectURI(_ url: URL) throws + func isPrivateIPHost(_ host: String) -> Bool +} + +/// URL security rules for OAuth endpoints. +/// +/// Validates that URLs used in the OAuth flow satisfy HTTPS requirements and SSRF protections. +/// Configured once and shared across the discovery, token, and registration components. +public struct OAuthURLValidator: Sendable { + + /// When `true`, loopback HTTP URLs are accepted for authorization server endpoints. + /// + /// This is a package-level compatibility option for local test environments. + public let allowLoopbackHTTPForAuthorizationServer: Bool + + public init(allowLoopbackHTTPForAuthorizationServer: Bool = false) { + self.allowLoopbackHTTPForAuthorizationServer = allowLoopbackHTTPForAuthorizationServer + } + + /// Validates that the URL uses HTTPS or loopback HTTP, and has no fragment. + /// + /// Used for MCP endpoints and protected resource metadata URLs. + public func validateHTTPSOrLoopback(_ url: URL, context: String) throws { + guard let components = URLComponents(url: url, resolvingAgainstBaseURL: false), + let scheme = components.scheme?.lowercased(), + let host = components.host?.lowercased(), + !host.isEmpty, + components.fragment == nil + else { + throw OAuthAuthorizationError.invalidResourceURI( + "Invalid \(context): \(url.absoluteString)" + ) + } + + guard scheme == OAuthURLScheme.https + || (scheme == OAuthURLScheme.http && OAuthLoopbackHost.isLoopback(host)) + else { + throw OAuthAuthorizationError.insecureOAuthEndpoint( + context: context, + url: url.absoluteString + ) + } + } + + /// Validates that the URL is an HTTPS authorization server endpoint. + /// + /// Loopback HTTP is permitted when `allowLoopbackHTTPForAuthorizationServer` is `true`. + public func validateAuthorizationServer(_ url: URL, context: String) throws { + guard let components = URLComponents(url: url, resolvingAgainstBaseURL: false), + let scheme = components.scheme?.lowercased(), + let host = components.host?.lowercased(), + !host.isEmpty, + components.fragment == nil + else { + throw OAuthAuthorizationError.invalidResourceURI( + "Invalid \(context): \(url.absoluteString)" + ) + } + + if allowLoopbackHTTPForAuthorizationServer, + scheme == OAuthURLScheme.http, + OAuthLoopbackHost.isLoopback(host) + { + return + } + + guard scheme == OAuthURLScheme.https else { + throw OAuthAuthorizationError.insecureAuthorizationServerEndpoint( + context: context, + url: url.absoluteString + ) + } + } + + /// Validates the redirect URI for authorization_code flows. + /// + /// Accepts HTTPS or loopback HTTP; rejects any URI containing a fragment. + public func validateRedirectURI(_ url: URL) throws { + guard let components = URLComponents(url: url, resolvingAgainstBaseURL: false), + let scheme = components.scheme?.lowercased(), + components.fragment == nil + else { + throw OAuthAuthorizationError.invalidRedirectURI(url.absoluteString) + } + + if scheme == OAuthURLScheme.https { return } + + if scheme == OAuthURLScheme.http, + let host = components.host?.lowercased(), + OAuthLoopbackHost.isLoopback(host) + { + return + } + + throw OAuthAuthorizationError.invalidRedirectURI(url.absoluteString) + } + + /// Returns `true` if `host` is a literal IPv4 or IPv6 address in a private or reserved range. + /// + /// Blocked ranges: + /// - IPv4: `10.0.0.0/8`, `172.16.0.0/12`, `192.168.0.0/16`, `169.254.0.0/16`, `100.64.0.0/10` + /// - IPv6: `fc00::/7` (ULA), `fe80::/10` (link-local) + /// + /// **Limitation**: only literal IP addresses are checked. DNS rebinding is not prevented here. + public func isPrivateIPHost(_ host: String) -> Bool { + let octets = host.split(separator: ".").compactMap { UInt8($0) } + if octets.count == 4 && !host.contains(":") { + let (a, b) = (octets[0], octets[1]) + return a == 10 // 10.0.0.0/8 + || (a == 172 && (16...31).contains(b)) // 172.16.0.0/12 + || (a == 192 && b == 168) // 192.168.0.0/16 + || (a == 169 && b == 254) // 169.254.0.0/16 (link-local / cloud metadata) + || (a == 100 && (64...127).contains(b)) // 100.64.0.0/10 (CGNAT) + } + let lower = host.lowercased() + if lower.hasPrefix("fc") || lower.hasPrefix("fd") { return true } // fc00::/7 (ULA) + if lower.hasPrefix("fe") && lower.count > 2 { + let idx = lower.index(lower.startIndex, offsetBy: 2) + if "89ab".contains(lower[idx]) { return true } // fe80::/10 (link-local) + } + return false + } +} + +extension OAuthURLValidator: OAuthURLValidating {} diff --git a/Sources/MCP/Base/Authorization/OAuthWWWAuthenticateParser.swift b/Sources/MCP/Base/Authorization/OAuthWWWAuthenticateParser.swift new file mode 100644 index 00000000..51f3c6fa --- /dev/null +++ b/Sources/MCP/Base/Authorization/OAuthWWWAuthenticateParser.swift @@ -0,0 +1,189 @@ +import Foundation + +// MARK: - WWW-Authenticate Parsing Protocol + +/// Parses the `WWW-Authenticate` response header to extract Bearer challenge parameters. +/// +/// ``OAuthAuthorizer`` uses this protocol to inspect 401 and 403 responses from the server. +/// Override the default ``DefaultOAuthWWWAuthenticateParser`` to handle custom challenge formats. +public protocol OAuthWWWAuthenticateParsing { + /// Parses the `WWW-Authenticate` header from a response header dictionary. + /// + /// - Parameter headers: The full set of HTTP response headers (case-insensitive lookup is expected). + /// - Returns: An ``OAuthBearerChallenge`` if a `Bearer` scheme is found, or `nil` otherwise. + func parseBearer(from headers: [String: String]) -> OAuthBearerChallenge? +} + +// MARK: - Default Implementation + +/// Default ``OAuthWWWAuthenticateParsing`` implementation. +/// +/// Parses `WWW-Authenticate` headers following RFC 6750 §3, handling: +/// - Multiple challenge schemes in a single header value (e.g., `Bearer …, Basic …`) +/// - Quoted-string parameter values with backslash escaping +/// - Case-insensitive scheme and parameter key matching +public struct DefaultOAuthWWWAuthenticateParser: OAuthWWWAuthenticateParsing { + public init() {} + private let tokenCharacters = CharacterSet( + charactersIn: "!#$%&'*+-.^_`|~0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + ) + + public func parseBearer(from headers: [String: String]) -> OAuthBearerChallenge? { + guard let value = headers.first(where: { + $0.key.caseInsensitiveCompare(HTTPHeaderName.wwwAuthenticate) == .orderedSame + })?.value else { + return nil + } + return parseBearerHeader(value) + } + + private func parseBearerHeader(_ header: String) -> OAuthBearerChallenge? { + let trimmed = header.trimmingCharacters(in: .whitespacesAndNewlines) + guard !trimmed.isEmpty else { return nil } + guard let parametersPart = extractBearerParameters(from: trimmed) else { return nil } + + if parametersPart.isEmpty { + return OAuthBearerChallenge(parameters: [:]) + } + + let components = splitParameters(parametersPart) + var parameters: [String: String] = [:] + + for component in components { + let pair = component.split(separator: "=", maxSplits: 1) + guard pair.count == 2 else { continue } + let key = pair[0].trimmingCharacters(in: .whitespacesAndNewlines).lowercased() + var value = pair[1].trimmingCharacters(in: .whitespacesAndNewlines) + if value.hasPrefix("\"") && value.hasSuffix("\"") && value.count >= 2 { + value.removeFirst() + value.removeLast() + value = value.replacingOccurrences(of: "\\\"", with: "\"") + } + parameters[key] = value + } + + return OAuthBearerChallenge(parameters: parameters) + } + + private func extractBearerParameters(from header: String) -> String? { + let segments = splitParameters(header) + + for index in segments.indices { + let segment = segments[index].trimmingCharacters(in: .whitespacesAndNewlines) + guard isBearerChallengeStart(segment) else { continue } + + var parameters: [String] = [] + let initial = stripBearerScheme(from: segment) + if !initial.isEmpty { + parameters.append(initial) + } + + var nextIndex = segments.index(after: index) + while nextIndex < segments.endIndex { + let next = segments[nextIndex].trimmingCharacters(in: .whitespacesAndNewlines) + if next.isEmpty { + nextIndex = segments.index(after: nextIndex) + continue + } + if startsNewChallenge(next) { + break + } + parameters.append(next) + nextIndex = segments.index(after: nextIndex) + } + + return parameters.joined(separator: ",") + } + + return nil + } + + private func isBearerChallengeStart(_ segment: String) -> Bool { + guard segment.count >= OAuthTokenType.bearer.count else { return false } + let schemeEnd = segment.index(segment.startIndex, offsetBy: OAuthTokenType.bearer.count) + let scheme = segment[.. String { + guard segment.count > OAuthTokenType.bearer.count else { return "" } + let index = segment.index(segment.startIndex, offsetBy: OAuthTokenType.bearer.count) + return String(segment[index...]).trimmingCharacters(in: .whitespacesAndNewlines) + } + + private func startsNewChallenge(_ segment: String) -> Bool { + if isBearerChallengeStart(segment) { + return true + } + + if let equalsIndex = segment.firstIndex(of: "=") { + let parameterName = String(segment[.. Bool { + guard !value.isEmpty else { return false } + return value.rangeOfCharacter(from: tokenCharacters.inverted) == nil + } + + private func splitParameters(_ value: String) -> [String] { + var components: [String] = [] + var current = "" + var inQuotes = false + var escaping = false + + for character in value { + if escaping { + current.append(character) + escaping = false + continue + } + if character == "\\" { + current.append(character) + escaping = true + continue + } + if character == "\"" { + inQuotes.toggle() + current.append(character) + continue + } + if character == "," && !inQuotes { + components.append(current.trimmingCharacters(in: .whitespacesAndNewlines)) + current = "" + continue + } + current.append(character) + } + + if !current.isEmpty { + components.append(current.trimmingCharacters(in: .whitespacesAndNewlines)) + } + + return components.filter { !$0.isEmpty } + } +} + diff --git a/Sources/MCP/Base/Authorization/PKCE.swift b/Sources/MCP/Base/Authorization/PKCE.swift new file mode 100644 index 00000000..45154542 --- /dev/null +++ b/Sources/MCP/Base/Authorization/PKCE.swift @@ -0,0 +1,62 @@ +import Foundation + +#if canImport(CryptoKit) + import CryptoKit +#endif + +/// Pure PKCE (RFC 7636) helpers required by the authorization_code flow. +public enum PKCE { + + /// Generates a cryptographically random PKCE code verifier. + /// + /// - Parameter length: Number of characters in the verifier. Defaults to 64. + /// RFC 7636 requires 43–128 characters. + /// - Returns: A URL-safe random string suitable for use as a PKCE code verifier. + public static func makeVerifier(length: Int = 64) -> String { + let charset = Array("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~") + // 66 characters; 256 % 66 == 52, so reject bytes > 252 to eliminate modulo bias. + let limit = UInt8(255 - (255 % charset.count)) // 252 + var rng = SystemRandomNumberGenerator() + var result = "" + result.reserveCapacity(length) + while result.count < length { + let byte = UInt8.random(in: 0...255, using: &rng) + if byte <= limit { + result.append(charset[Int(byte % UInt8(charset.count))]) + } + } + return result + } + + /// Derives the PKCE S256 code challenge from a verifier. + /// + /// - Parameter verifier: A code verifier produced by ``makeVerifier(length:)``. + /// - Returns: The base64url-encoded SHA-256 hash of the verifier. + /// - Throws: ``OAuthAuthorizationError/pkceS256Unavailable`` on platforms without CryptoKit. + public static func makeChallenge(from verifier: String) throws -> String { + #if canImport(CryptoKit) + let hash = SHA256.hash(data: Data(verifier.utf8)) + return Data(hash).base64URLEncodedString() + #else + throw OAuthAuthorizationError.pkceS256Unavailable + #endif + } + + /// Verifies that the authorization server metadata advertises S256 PKCE support. + /// + /// - Parameter metadata: Authorization server metadata to inspect. + /// - Throws: ``OAuthAuthorizationError/pkceCodeChallengeMethodsMissing`` if + /// `code_challenge_methods_supported` is absent or empty, or + /// ``OAuthAuthorizationError/pkceS256NotSupported(advertisedMethods:)`` if S256 is not listed. + static func checkSupport(in metadata: OAuthAuthorizationServerMetadata) throws { + guard let methods = metadata.codeChallengeMethodsSupported, !methods.isEmpty else { + throw OAuthAuthorizationError.pkceCodeChallengeMethodsMissing + } + let supportsS256 = methods.contains { + $0.caseInsensitiveCompare(OAuthCodeChallengeMethod.s256) == .orderedSame + } + guard supportsS256 else { + throw OAuthAuthorizationError.pkceS256NotSupported(advertisedMethods: methods) + } + } +} diff --git a/Sources/MCP/Base/Authorization/TokenStorage.swift b/Sources/MCP/Base/Authorization/TokenStorage.swift new file mode 100644 index 00000000..07a5e3c9 --- /dev/null +++ b/Sources/MCP/Base/Authorization/TokenStorage.swift @@ -0,0 +1,37 @@ +import Foundation + +// MARK: - Token Storage Protocol + +/// Abstraction for persisting OAuth access tokens. +/// +/// Implement this protocol to provide custom token storage (e.g., Keychain-backed). +/// The default ``InMemoryTokenStorage`` stores tokens in memory only. +public protocol TokenStorage: AnyObject, Sendable { + func save(_ token: OAuthAccessToken) + func load() -> OAuthAccessToken? + func clear() +} + +// MARK: - In-Memory Implementation + +/// Default ``TokenStorage`` that stores the access token in memory only. +/// +/// The token is lost when the process exits. For persistent storage +/// (e.g., system Keychain), implement ``TokenStorage`` directly. +public final class InMemoryTokenStorage: TokenStorage, @unchecked Sendable { + private var token: OAuthAccessToken? + + public init() {} + + public func save(_ token: OAuthAccessToken) { + self.token = token + } + + public func load() -> OAuthAccessToken? { + token + } + + public func clear() { + token = nil + } +} diff --git a/Sources/MCP/Base/Transports/HTTPClientTransport.swift b/Sources/MCP/Base/Transports/HTTPClientTransport.swift index 9f4908ec..060cd133 100644 --- a/Sources/MCP/Base/Transports/HTTPClientTransport.swift +++ b/Sources/MCP/Base/Transports/HTTPClientTransport.swift @@ -76,6 +76,9 @@ public actor HTTPClientTransport: Transport { /// Closure to modify requests before they are sent private let requestModifier: (URLRequest) -> URLRequest + /// Optional OAuth 2.1 authorizer. + private let authorizer: (any HTTPClientAuthorizer)? + private var isConnected = false private let messageStream: AsyncThrowingStream private let messageContinuation: AsyncThrowingStream.Continuation @@ -101,6 +104,7 @@ public actor HTTPClientTransport: Transport { /// - streaming: Whether to enable SSE streaming mode (default: true) /// - sseInitializationTimeout: Maximum time to wait for session ID before proceeding with SSE (default: 10 seconds) /// - protocolVersion: The MCP protocol version to use (default: "2025-11-25") + /// - authorizer: Optional ``HTTPClientAuthorizer`` for automatic Bearer token acquisition and retries. /// - requestModifier: Optional closure to customize requests before they are sent (default: no modification) /// - logger: Optional logger instance for transport events public init( @@ -109,6 +113,7 @@ public actor HTTPClientTransport: Transport { streaming: Bool = true, sseInitializationTimeout: TimeInterval = 10, protocolVersion: String = Version.latest, + authorizer: (any HTTPClientAuthorizer)? = nil, requestModifier: @escaping (URLRequest) -> URLRequest = { $0 }, logger: Logger? = nil ) { @@ -118,6 +123,7 @@ public actor HTTPClientTransport: Transport { streaming: streaming, sseInitializationTimeout: sseInitializationTimeout, protocolVersion: protocolVersion, + authorizer: authorizer, requestModifier: requestModifier, logger: logger ) @@ -129,6 +135,7 @@ public actor HTTPClientTransport: Transport { streaming: Bool = false, sseInitializationTimeout: TimeInterval = 10, protocolVersion: String = Version.latest, + authorizer: (any HTTPClientAuthorizer)? = nil, requestModifier: @escaping (URLRequest) -> URLRequest = { $0 }, logger: Logger? = nil ) { @@ -138,6 +145,7 @@ public actor HTTPClientTransport: Transport { self.sseInitializationTimeout = sseInitializationTimeout self.protocolVersion = protocolVersion self.requestModifier = requestModifier + self.authorizer = authorizer // Create message stream var continuation: AsyncThrowingStream.Continuation! @@ -157,7 +165,6 @@ public actor HTTPClientTransport: Transport { self.initialSessionIDSignalTask = Task { await withCheckedContinuation { continuation in self.initialSessionIDContinuation = continuation - // This task will suspend here until continuation.resume() is called } } } @@ -166,7 +173,7 @@ public actor HTTPClientTransport: Transport { private func triggerInitialSessionIDSignal() { if let continuation = self.initialSessionIDContinuation { continuation.resume() - self.initialSessionIDContinuation = nil // Consume the continuation + self.initialSessionIDContinuation = nil logger.debug("✓ Initial session ID signal triggered for SSE task") } else { logger.debug("✗ No continuation to trigger - signal already consumed or SSE task not waiting") @@ -174,19 +181,13 @@ public actor HTTPClientTransport: Transport { } /// Establishes connection with the transport - /// - /// This prepares the transport for communication and sets up SSE streaming - /// if streaming mode is enabled. The actual HTTP connection happens with the - /// first message sent. public func connect() async throws { guard !isConnected else { return } isConnected = true - // Setup initial session ID signal setupInitialSessionIDSignal() if streaming { - // Start listening to server events streamingTask = Task { await startListeningForServerEvents() } } @@ -194,27 +195,18 @@ public actor HTTPClientTransport: Transport { } /// Disconnects from the transport - /// - /// This terminates any active connections, cancels the streaming task, - /// and releases any resources being used by the transport. public func disconnect() async { guard isConnected else { return } isConnected = false - // Cancel streaming task if active streamingTask?.cancel() streamingTask = nil - // Cancel any in-progress requests session.invalidateAndCancel() - - // Clean up message stream messageContinuation.finish() - // Cancel the initial session ID signal task if active initialSessionIDSignalTask?.cancel() initialSessionIDSignalTask = nil - // Resume the continuation if it's still pending to avoid leaks initialSessionIDContinuation?.resume() initialSessionIDContinuation = nil @@ -222,77 +214,108 @@ public actor HTTPClientTransport: Transport { } /// Updates the protocol version used for `MCP-Protocol-Version` headers on subsequent requests. - /// - /// Call this after lifecycle negotiation if the server responds with a different version. - /// - Parameter version: The negotiated protocol version. public func updateNegotiatedProtocolVersion(_ version: String) { self.protocolVersion = version } /// Sends data through an HTTP POST request - /// - /// This sends a JSON-RPC message to the server via HTTP POST and processes - /// the response according to the MCP Streamable HTTP specification. It handles: - /// - /// - Adding appropriate Accept headers for both JSON and SSE - /// - Including the MCP-Protocol-Version header as required by the specification - /// - Including the session ID in requests if one has been established - /// - Processing different response types (JSON vs SSE) - /// - Handling HTTP error codes according to the specification - /// - /// - Parameter data: The JSON-RPC message to send - /// - Throws: MCPError for transport failures or server errors public func send(_ data: Data) async throws { guard isConnected else { throw MCPError.internalError("Transport not connected") } - var request = URLRequest(url: endpoint) - request.httpMethod = "POST" - request.addValue("application/json, text/event-stream", forHTTPHeaderField: "Accept") - request.addValue("application/json", forHTTPHeaderField: "Content-Type") - request.httpBody = data - - // Add protocol version header (required by MCP specification 2025-11-25) - if let protocolVersion = protocolVersion { - request.addValue(protocolVersion, forHTTPHeaderField: "MCP-Protocol-Version") + if let authorizer { + do { + try authorizer.validateEndpointSecurity(for: endpoint) + } catch { + throw MCPError.internalError( + "Authorization flow failed: \(error.localizedDescription)" + ) + } } - // Add session ID if available - if let sessionID = sessionID { - request.addValue(sessionID, forHTTPHeaderField: "MCP-Session-Id") + if let authorizer { + try? await authorizer.prepareAuthorization(for: endpoint, session: session) } - // Apply request modifier - request = requestModifier(request) + var attempts = 0 + let operationKey = jsonRPCOperationKey(from: data) - #if os(Linux) - // Linux implementation using data(for:) instead of bytes(for:) - let (responseData, response) = try await session.data(for: request) - try await processResponse(response: response, data: responseData) - #else - // macOS and other platforms with bytes(for:) support - let (responseStream, response) = try await session.bytes(for: request) - try await processResponse(response: response, stream: responseStream) - #endif + while true { + var request = URLRequest(url: endpoint) + request.httpMethod = "POST" + request.addValue( + "\(ContentType.json), \(ContentType.sse)", + forHTTPHeaderField: HTTPHeaderName.accept + ) + request.addValue(ContentType.json, forHTTPHeaderField: HTTPHeaderName.contentType) + request.httpBody = data + + if let protocolVersion = protocolVersion { + request.addValue(protocolVersion, forHTTPHeaderField: HTTPHeaderName.protocolVersion) + } + + if let sessionID = sessionID { + request.addValue(sessionID, forHTTPHeaderField: HTTPHeaderName.sessionID) + } + + if let authValue = authorizer?.authorizationHeader(for: endpoint) { + request.setValue(authValue, forHTTPHeaderField: HTTPHeaderName.authorization) + } + + request = requestModifier(request) + + do { + #if os(Linux) + let (responseData, response) = try await session.data(for: request) + try await processResponse(response: response, data: responseData) + #else + let (responseStream, response) = try await session.bytes(for: request) + try await processResponse(response: response, stream: responseStream) + #endif + return + } catch let authError as HTTPAuthenticationChallengeError { + guard let authorizer else { + throw mapAuthenticationChallengeError(authError) + } + + let handled: Bool + do { + handled = try await authorizer.handleChallenge( + statusCode: authError.statusCode, + headers: authError.headers, + endpoint: endpoint, + operationKey: operationKey, + session: session + ) + } catch { + throw MCPError.internalError( + "Authorization flow failed: \(error.localizedDescription)") + } + + attempts += 1 + + if handled, attempts < authorizer.maxAuthorizationAttempts { + continue + } + + throw mapAuthenticationChallengeError(authError) + } + } } #if os(Linux) - // Process response with data payload (Linux) private func processResponse(response: URLResponse, data: Data) async throws { guard let httpResponse = response as? HTTPURLResponse else { throw MCPError.internalError("Invalid HTTP response") } - // Process the response based on content type and status code - let contentType = httpResponse.value(forHTTPHeaderField: "Content-Type") ?? "" + let contentType = httpResponse.value(forHTTPHeaderField: HTTPHeaderName.contentType) ?? "" - // Extract session ID if present - if let newSessionID = httpResponse.value(forHTTPHeaderField: "MCP-Session-Id") { + if let newSessionID = httpResponse.value(forHTTPHeaderField: HTTPHeaderName.sessionID) { let wasSessionIDNil = (self.sessionID == nil) self.sessionID = newSessionID if wasSessionIDNil { - // Trigger signal on first session ID triggerInitialSessionIDSignal() } logger.debug("Session ID received", metadata: ["sessionID": "\(newSessionID)"]) @@ -301,11 +324,10 @@ public actor HTTPClientTransport: Transport { try processHTTPResponse(httpResponse, contentType: contentType) guard case 200..<300 = httpResponse.statusCode else { return } - // For JSON responses, yield the data - if contentType.contains("text/event-stream") { + if contentType.contains(ContentType.sse) { logger.warning("SSE responses aren't fully supported on Linux") messageContinuation.yield(data) - } else if contentType.contains("application/json") { + } else if contentType.contains(ContentType.json) { logger.trace("Received JSON response", metadata: ["size": "\(data.count)"]) messageContinuation.yield(data) } else { @@ -313,7 +335,6 @@ public actor HTTPClientTransport: Transport { } } #else - // Process response with byte stream (macOS, iOS, etc.) private func processResponse(response: URLResponse, stream: URLSession.AsyncBytes) async throws { @@ -321,15 +342,12 @@ public actor HTTPClientTransport: Transport { throw MCPError.internalError("Invalid HTTP response") } - // Process the response based on content type and status code - let contentType = httpResponse.value(forHTTPHeaderField: "Content-Type") ?? "" + let contentType = httpResponse.value(forHTTPHeaderField: HTTPHeaderName.contentType) ?? "" - // Extract session ID if present - if let newSessionID = httpResponse.value(forHTTPHeaderField: "MCP-Session-Id") { + if let newSessionID = httpResponse.value(forHTTPHeaderField: HTTPHeaderName.sessionID) { let wasSessionIDNil = (self.sessionID == nil) self.sessionID = newSessionID if wasSessionIDNil { - // Trigger signal on first session ID triggerInitialSessionIDSignal() } logger.debug("Session ID received", metadata: ["sessionID": "\(newSessionID)"]) @@ -338,20 +356,15 @@ public actor HTTPClientTransport: Transport { try processHTTPResponse(httpResponse, contentType: contentType) guard case 200..<300 = httpResponse.statusCode else { return } - if contentType.contains("text/event-stream") { - // For SSE, processing happens via the stream + if contentType.contains(ContentType.sse) { logger.trace("Received SSE response, processing in streaming task") let hadData = try await self.processSSE(stream) - // If the POST SSE stream closed without delivering a JSON-RPC response, - // trigger GET reconnection so the server can deliver it there. - // This implements standard SSE reconnection behavior per the spec. if !hadData { logger.debug("POST SSE stream closed without data, triggering GET reconnection") self.activeGETSessionTask?.cancel() } - } else if contentType.contains("application/json") { - // For JSON responses, collect and deliver the data + } else if contentType.contains(ContentType.json) { var buffer = Data() for try await byte in stream { buffer.append(byte) @@ -364,25 +377,27 @@ public actor HTTPClientTransport: Transport { } #endif - // Common HTTP response handling for all platforms private func processHTTPResponse(_ response: HTTPURLResponse, contentType: String) throws { - // Handle status codes according to HTTP semantics switch response.statusCode { case 200..<300: - // Success range - these are handled by the platform-specific code return case 400: throw MCPError.internalError("Bad request") case 401: - throw MCPError.internalError("Authentication required") + throw HTTPAuthenticationChallengeError( + statusCode: response.statusCode, + headers: responseHeaders(from: response) + ) case 403: - throw MCPError.internalError("Access forbidden") + throw HTTPAuthenticationChallengeError( + statusCode: response.statusCode, + headers: responseHeaders(from: response) + ) case 404: - // If we get a 404 with a session ID, it means our session is invalid if sessionID != nil { logger.warning("Session has expired") sessionID = nil @@ -391,8 +406,6 @@ public actor HTTPClientTransport: Transport { throw MCPError.internalError("Endpoint not found") case 405: - // If we get a 405, it means the server does not support the requested method - // If streaming was requested, we should cancel the streaming task if streaming { self.streamingTask?.cancel() throw MCPError.internalError("Server does not support streaming") @@ -406,7 +419,6 @@ public actor HTTPClientTransport: Transport { throw MCPError.internalError("Too many requests") case 500..<600: - // Server error range throw MCPError.internalError("Server error: \(response.statusCode)") default: @@ -415,43 +427,55 @@ public actor HTTPClientTransport: Transport { } } + private func mapAuthenticationChallengeError(_ error: HTTPAuthenticationChallengeError) -> MCPError { + switch error.statusCode { + case 401: + return MCPError.internalError("Authentication required") + case 403: + return MCPError.internalError("Access forbidden") + default: + return MCPError.internalError("HTTP authorization error: \(error.statusCode)") + } + } + + private func responseHeaders(from response: HTTPURLResponse) -> [String: String] { + var headers: [String: String] = [:] + for (key, value) in response.allHeaderFields { + guard let key = key as? String, let value = value as? String else { continue } + headers[key] = value + } + return headers + } + + private func jsonRPCOperationKey(from data: Data) -> String? { + guard + let jsonObject = try? JSONSerialization.jsonObject(with: data) as? [String: Any], + let method = jsonObject["method"] as? String + else { + return nil + } + + let normalized = method.trimmingCharacters(in: .whitespacesAndNewlines) + return normalized.isEmpty ? nil : normalized + } + /// Receives data in an async sequence - /// - /// This returns an AsyncThrowingStream that emits Data objects representing - /// each JSON-RPC message received from the server. This includes: - /// - /// - Direct responses to client requests - /// - Server-initiated messages delivered via SSE streams - /// - /// - Returns: An AsyncThrowingStream of Data objects public func receive() -> AsyncThrowingStream { return messageStream } // MARK: - SSE - /// Starts listening for server events using SSE - /// - /// This establishes a long-lived HTTP connection using Server-Sent Events (SSE) - /// to enable server-to-client push messaging. It handles: - /// - /// - Waiting for session ID if needed - /// - Opening the SSE connection - /// - Automatic reconnection on connection drops - /// - Processing received events private func startListeningForServerEvents() async { #if os(Linux) - // SSE is not fully supported on Linux if streaming { logger.warning( "SSE streaming was requested but is not fully supported on Linux. SSE connection will not be attempted." ) } #else - // This is the original code for platforms that support SSE guard isConnected else { return } - // Wait for session ID to be available before opening SSE stream if self.sessionID == nil, let signalTask = self.initialSessionIDSignalTask { logger.debug("⏳ Waiting for session ID to be set (timeout: \(self.sseInitializationTimeout)s)...") @@ -484,7 +508,6 @@ public actor HTTPClientTransport: Transport { logger.debug("✓ Session ID already available, proceeding with SSE connection immediately") } - // Retry loop for connection drops var isFirstAttempt = true var attemptCount = 0 @@ -501,7 +524,6 @@ public actor HTTPClientTransport: Transport { ]) do { - // Wait for retry interval before reconnecting (except first attempt) if !isFirstAttempt { let delayMs = self.retryInterval logger.debug("⏳ Waiting before SSE reconnection", metadata: ["retryMs": "\(delayMs)"]) @@ -514,8 +536,6 @@ public actor HTTPClientTransport: Transport { try await self.connectToEventStream() - // If connectToEventStream() returns without error, - // it means the stream closed gracefully - reconnect with retry interval logger.info("🔌 SSE stream closed gracefully, will reconnect", metadata: [ "attempt": "\(attemptCount)", "willRetryAfter": "\(self.retryInterval)ms" @@ -523,7 +543,6 @@ public actor HTTPClientTransport: Transport { } catch { if !Task.isCancelled { logger.error("❌ SSE connection error (attempt #\(attemptCount)): \(error)") - // Error case - will also use retry interval on next iteration } else { logger.debug("⏹️ SSE task cancelled") } @@ -545,50 +564,56 @@ public actor HTTPClientTransport: Transport { } #if !os(Linux) - /// Establishes an SSE connection to the server - /// - /// This initiates a GET request to the server endpoint with appropriate - /// headers to establish an SSE stream according to the MCP specification. - /// Supports stream resumability via Last-Event-ID header. - /// - /// - Throws: MCPError for connection failures or server errors private func connectToEventStream() async throws { guard isConnected else { logger.debug("⚠️ Skipping connectToEventStream - transport not connected") return } + if let authorizer { + do { + try authorizer.validateEndpointSecurity(for: endpoint) + } catch { + throw MCPError.internalError( + "Authorization flow failed: \(error.localizedDescription)" + ) + } + } + + if let authorizer { + try? await authorizer.prepareAuthorization(for: endpoint, session: session) + } + logger.debug("🔌 Preparing SSE connection request") var request = URLRequest(url: endpoint) request.httpMethod = "GET" - request.addValue("text/event-stream", forHTTPHeaderField: "Accept") - request.addValue("no-cache", forHTTPHeaderField: "Cache-Control") + request.addValue(ContentType.sse, forHTTPHeaderField: HTTPHeaderName.accept) + request.addValue("no-cache", forHTTPHeaderField: HTTPHeaderName.cacheControl) - // Add protocol version header (required by MCP specification 2025-11-25) if let protocolVersion = protocolVersion { - request.addValue(protocolVersion, forHTTPHeaderField: "MCP-Protocol-Version") + request.addValue(protocolVersion, forHTTPHeaderField: HTTPHeaderName.protocolVersion) } - // Add session ID if available if let sessionID = sessionID { - request.addValue(sessionID, forHTTPHeaderField: "MCP-Session-Id") + request.addValue(sessionID, forHTTPHeaderField: HTTPHeaderName.sessionID) } - // Add last event ID for resumability (if available) if let lastEventID = lastEventID { - request.addValue(lastEventID, forHTTPHeaderField: "Last-Event-ID") + request.addValue(lastEventID, forHTTPHeaderField: HTTPHeaderName.lastEventID) logger.info("→ Resuming SSE stream with Last-Event-ID", metadata: ["lastEventID": "\(lastEventID)"]) } else { logger.info("→ Connecting to SSE stream (no last event ID to resume from)") } - // Apply request modifier + if let authValue = authorizer?.authorizationHeader(for: endpoint) { + request.setValue(authValue, forHTTPHeaderField: HTTPHeaderName.authorization) + } + request = requestModifier(request) logger.debug("Starting SSE connection") - // Create URLSession task for SSE let (stream, response) = try await session.bytes(for: request) self.activeGETSessionTask = stream.task @@ -596,24 +621,17 @@ public actor HTTPClientTransport: Transport { throw MCPError.internalError("Invalid HTTP response") } - // Check response status guard httpResponse.statusCode == 200 else { - // If the server returns 405 Method Not Allowed, - // it indicates that the server doesn't support SSE streaming. - // We should cancel the task instead of retrying the connection. if httpResponse.statusCode == 405 { self.streamingTask?.cancel() } throw MCPError.internalError("HTTP error: \(httpResponse.statusCode)") } - // Extract session ID if present - if let newSessionID = httpResponse.value(forHTTPHeaderField: "MCP-Session-Id") { + if let newSessionID = httpResponse.value(forHTTPHeaderField: HTTPHeaderName.sessionID) { let wasSessionIDNil = (self.sessionID == nil) self.sessionID = newSessionID if wasSessionIDNil { - // Trigger signal on first session ID, though this is unlikely to happen here - // as GET usually follows a POST that would have already set the session ID triggerInitialSessionIDSignal() } logger.debug("Session ID received", metadata: ["sessionID": "\(newSessionID)"]) @@ -623,14 +641,6 @@ public actor HTTPClientTransport: Transport { try await self.processSSE(stream) } - /// Processes an SSE byte stream, extracting events and delivering them. - /// - /// This method processes Server-Sent Events according to the MCP specification, - /// including support for event IDs for resumability. - /// - /// - Parameter stream: The URLSession.AsyncBytes stream to process - /// - Returns: `true` if any data events were received, `false` otherwise. - /// - Throws: Error for stream processing failures @discardableResult private func processSSE(_ stream: URLSession.AsyncBytes) async throws -> Bool { logger.debug("📥 Starting SSE event processing") @@ -640,7 +650,6 @@ public actor HTTPClientTransport: Transport { for try await event in stream.events { eventCount += 1 - // Check if task has been cancelled if Task.isCancelled { logger.debug("⏹️ SSE processing cancelled", metadata: ["eventsProcessed": "\(eventCount)"]) break @@ -654,19 +663,16 @@ public actor HTTPClientTransport: Transport { ] ) - // Store event ID for resumability support if let eventID = event.id, !eventID.isEmpty { self.lastEventID = eventID logger.debug("Stored event ID for resumability", metadata: ["eventID": "\(eventID)"]) } - // Store retry interval if provided by server if let retry = event.retry { self.retryInterval = retry logger.debug("SSE retry interval updated", metadata: ["retryMs": "\(retry)"]) } - // Convert the event data to Data and yield it to the message stream if !event.data.isEmpty, let data = event.data.data(using: .utf8) { hadDataEvent = true messageContinuation.yield(data) diff --git a/Sources/MCP/Base/Transports/HTTPServer/HTTPRequestValidation.swift b/Sources/MCP/Base/Transports/HTTPServer/HTTPRequestValidation.swift index 748dda33..51c7aaed 100644 --- a/Sources/MCP/Base/Transports/HTTPServer/HTTPRequestValidation.swift +++ b/Sources/MCP/Base/Transports/HTTPServer/HTTPRequestValidation.swift @@ -340,6 +340,353 @@ public struct OriginValidator: HTTPRequestValidator { } } +// MARK: - Protected Resource Metadata Validator + +/// Serves the RFC 9728 Protected Resource Metadata document for discovery. +/// +/// Per the MCP authorization specification, servers **MUST** serve Protected Resource +/// Metadata at `/.well-known/oauth-protected-resource` so that clients can discover +/// authorization server endpoints automatically. +/// +/// Place this validator **before** ``BearerTokenValidator`` in the pipeline so that +/// unauthenticated metadata discovery requests succeed. +/// +/// ```swift +/// let prmValidator = ProtectedResourceMetadataValidator( +/// metadata: OAuthProtectedResourceServerMetadata( +/// resource: "https://api.example.com", +/// authorizationServers: [URL(string: "https://auth.example.com")!] +/// ) +/// ) +/// let pipeline = StandardValidationPipeline(validators: [ +/// prmValidator, +/// bearerTokenValidator, +/// // ... +/// ]) +/// ``` +public struct ProtectedResourceMetadataValidator: HTTPRequestValidator { + private let encodedMetadata: Data + + public init(metadata: OAuthProtectedResourceServerMetadata) { + self.encodedMetadata = (try? JSONEncoder().encode(metadata)) ?? Data() + } + + public func validate(_ request: HTTPRequest, context: HTTPValidationContext) -> HTTPResponse? { + guard context.httpMethod == "GET", + let path = request.path, + path == OAuthWellKnownPath.protectedResource + || path.hasPrefix("\(OAuthWellKnownPath.protectedResource)/") + else { + return nil + } + return .data(encodedMetadata, headers: [HTTPHeaderName.contentType: ContentType.json]) + } +} + +// MARK: - OAuth Bearer Validator + +/// Result produced by ``BearerTokenValidator`` when validating an access token. +public enum BearerTokenValidationResult: Sendable, Equatable { + /// Access token is valid for this request, with its extracted claims. + /// + /// Supply a ``BearerTokenInfo`` with `audience` and `expiresAt` populated so that + /// ``BearerTokenValidator`` can enforce expiry and audience checks automatically. + /// Pass `BearerTokenInfo()` (all `nil`) to delegate all enforcement to the caller. + case valid(BearerTokenInfo) + + /// Access token is missing required privileges, and new scopes are required. + case insufficientScope(requiredScopes: Set, errorDescription: String? = nil) + + /// Access token is invalid or expired. + case invalidToken(errorDescription: String? = nil) + + /// Authorization request is malformed. + case malformedRequest(errorDescription: String? = nil) +} + +/// Validates OAuth 2.1 Bearer authorization for protected MCP HTTP endpoints. +/// +/// This validator implements resource-server error semantics aligned with the MCP auth spec: +/// - `401` with `WWW-Authenticate: Bearer ...` for missing/invalid tokens +/// - `403` with `error="insufficient_scope"` for insufficient permissions +/// - `400` for malformed authorization requests +/// +/// Include this validator early in your pipeline, before `SessionValidator`, so unauthenticated +/// initialization requests can return a challenge. +/// +/// ## Audience Validation (MUST) +/// +/// Per the MCP authorization specification, **the resource server MUST validate the audience +/// (`aud` claim) of the access token** to ensure it matches the resource server's own identifier. +/// Failure to validate the audience allows token substitution attacks where a token intended +/// for a different resource is replayed against your server. +/// +/// Your ``TokenValidator`` closure **MUST** verify the audience. Example: +/// +/// ```swift +/// let validator = BearerTokenValidator( +/// resourceMetadataURL: metadataURL, +/// tokenValidator: { token, request, context in +/// guard let claims = verifyAndDecode(token) else { +/// return .invalidToken(errorDescription: "Token verification failed") +/// } +/// // MUST: Verify token audience matches this resource server +/// guard claims.audience.contains("https://api.example.com") else { +/// return .invalidToken(errorDescription: "Token audience mismatch") +/// } +/// return .valid +/// } +/// ) +/// ``` +public struct BearerTokenValidator: HTTPRequestValidator { + /// Validates a bearer token and returns token info for audience and expiry enforcement. + public typealias TokenValidator = @Sendable ( + _ token: String, + _ request: HTTPRequest, + _ context: HTTPValidationContext + ) -> BearerTokenValidationResult + + /// Closure that returns the scopes to advertise in `WWW-Authenticate` challenge headers. + /// + /// Return `nil` to omit the `scope` parameter from the challenge. + public typealias ChallengeScopeProvider = @Sendable ( + _ request: HTTPRequest, + _ context: HTTPValidationContext + ) -> Set? + + /// Closure that decides whether a request requires Bearer authentication. + /// + /// Return `false` to allow a request through unauthenticated (e.g., public health-check endpoints). + /// Defaults to requiring authentication on all requests. + public typealias RequirementPredicate = @Sendable ( + _ request: HTTPRequest, + _ context: HTTPValidationContext + ) -> Bool + + public let resourceMetadataURL: URL + public let resourceIdentifier: URL + private let tokenValidator: TokenValidator + private let challengeScopeProvider: ChallengeScopeProvider? + private let requiresAuthentication: RequirementPredicate + private let metadataDiscovery: any OAuthMetadataDiscovering + + /// Creates a `BearerTokenValidator`. + /// + /// - Parameters: + /// - resourceMetadataURL: Included in `WWW-Authenticate` challenge headers as the + /// `resource_metadata` parameter, pointing to the RFC 9728 Protected Resource Metadata document. + /// - resourceIdentifier: The canonical URI of this resource server. Used to validate the + /// `aud` claim in tokens that supply audience information via ``BearerTokenInfo``. + /// - tokenValidator: Validates the Bearer token and returns ``BearerTokenInfo`` with + /// claims for SDK-side expiry and audience enforcement. + /// - challengeScopeProvider: Optional closure supplying scopes to include in challenge headers. + /// - requiresAuthentication: Predicate controlling which requests require a Bearer token. + /// Defaults to requiring authentication on all requests. + /// - metadataDiscovery: Used for audience URL matching. Defaults to ``DefaultOAuthMetadataDiscovery``. + public init( + resourceMetadataURL: URL, + resourceIdentifier: URL, + tokenValidator: @escaping TokenValidator, + challengeScopeProvider: ChallengeScopeProvider? = nil, + requiresAuthentication: @escaping RequirementPredicate = { _, _ in true }, + metadataDiscovery: any OAuthMetadataDiscovering = DefaultOAuthMetadataDiscovery() + ) { + self.resourceMetadataURL = resourceMetadataURL + self.resourceIdentifier = resourceIdentifier + self.tokenValidator = tokenValidator + self.challengeScopeProvider = challengeScopeProvider + self.requiresAuthentication = requiresAuthentication + self.metadataDiscovery = metadataDiscovery + } + + public func validate(_ request: HTTPRequest, context: HTTPValidationContext) -> HTTPResponse? { + guard requiresAuthentication(request, context) else { return nil } + + guard let authorizationHeader = request.header(HTTPHeaderName.authorization) else { + return unauthorizedResponse( + challengeScope: challengeScopeProvider?(request, context), + error: nil, + errorDescription: nil, + sessionID: context.sessionID + ) + } + + let parsedToken: String + switch parseBearerToken(from: authorizationHeader) { + case .success(let token): + parsedToken = token + case .failure(let error): + return .error( + statusCode: 400, + .invalidRequest("Bad Request: \(error.message)"), + sessionID: context.sessionID + ) + } + + switch tokenValidator(parsedToken, request, context) { + case .valid(let info): + // Expiry check + if let exp = info.expiresAt, exp <= Date() { + return unauthorizedResponse( + challengeScope: challengeScopeProvider?(request, context), + error: "invalid_token", + errorDescription: "Token has expired", + sessionID: context.sessionID + ) + } + // Audience check — skipped for opaque tokens (audience == nil) + if let audience = info.audience { + let matches = audience.contains { audString in + guard let audURL = URL(string: audString) else { return false } + return metadataDiscovery.protectedResourceMatches( + resource: audURL, endpoint: resourceIdentifier) + } + if !matches { + return unauthorizedResponse( + challengeScope: challengeScopeProvider?(request, context), + error: "invalid_token", + errorDescription: "Token audience mismatch", + sessionID: context.sessionID + ) + } + } + return nil + + case .invalidToken(let errorDescription): + return unauthorizedResponse( + challengeScope: challengeScopeProvider?(request, context), + error: "invalid_token", + errorDescription: errorDescription, + sessionID: context.sessionID + ) + + case .insufficientScope(let requiredScopes, let errorDescription): + return forbiddenInsufficientScopeResponse( + requiredScopes: requiredScopes, + errorDescription: errorDescription, + sessionID: context.sessionID + ) + + case .malformedRequest(let errorDescription): + let message = errorDescription ?? "Malformed authorization request" + return .error( + statusCode: 400, + .invalidRequest("Bad Request: \(message)"), + sessionID: context.sessionID + ) + } + } + + private func unauthorizedResponse( + challengeScope: Set?, + error: String?, + errorDescription: String?, + sessionID: String? + ) -> HTTPResponse { + let challenge = makeBearerChallenge( + resourceMetadataURL: resourceMetadataURL, + scope: challengeScope, + error: error, + errorDescription: errorDescription + ) + return .error( + statusCode: 401, + .invalidRequest("Unauthorized"), + sessionID: sessionID, + extraHeaders: [HTTPHeaderName.wwwAuthenticate: challenge] + ) + } + + private func forbiddenInsufficientScopeResponse( + requiredScopes: Set, + errorDescription: String?, + sessionID: String? + ) -> HTTPResponse { + let challenge = makeBearerChallenge( + resourceMetadataURL: resourceMetadataURL, + scope: requiredScopes, + error: "insufficient_scope", + errorDescription: errorDescription + ) + return .error( + statusCode: 403, + .invalidRequest("Forbidden: Insufficient scope"), + sessionID: sessionID, + extraHeaders: [HTTPHeaderName.wwwAuthenticate: challenge] + ) + } + + private func makeBearerChallenge( + resourceMetadataURL: URL, + scope: Set?, + error: String?, + errorDescription: String? + ) -> String { + var parameters: [String] = [] + parameters.append("resource_metadata=\"\(escapeAuthParameter(resourceMetadataURL.absoluteString))\"") + + if let scope, !scope.isEmpty { + let serializedScope = scope.sorted().joined(separator: " ") + parameters.append("scope=\"\(escapeAuthParameter(serializedScope))\"") + } + + if let error { + parameters.append("error=\"\(escapeAuthParameter(error))\"") + } + + if let errorDescription, !errorDescription.isEmpty { + parameters.append("error_description=\"\(escapeAuthParameter(errorDescription))\"") + } + + return "\(OAuthTokenType.bearer) " + parameters.joined(separator: ", ") + } + + private func escapeAuthParameter(_ value: String) -> String { + value + .replacingOccurrences(of: "\\", with: "\\\\") + .replacingOccurrences(of: "\"", with: "\\\"") + } + + private struct BearerTokenParseError: Swift.Error { + let message: String + } + + private func parseBearerToken( + from authorizationHeader: String + ) -> Result { + let trimmed = authorizationHeader.trimmingCharacters(in: .whitespacesAndNewlines) + guard !trimmed.isEmpty else { + return .failure(.init(message: "Authorization header is empty")) + } + + let parts = trimmed.split( + maxSplits: 1, + whereSeparator: { $0.isWhitespace } + ) + + guard parts.count == 2 else { + return .failure( + .init(message: "Authorization header must be in the form: Bearer ") + ) + } + + guard String(parts[0]).caseInsensitiveCompare(OAuthTokenType.bearer) == .orderedSame else { + return .failure(.init(message: "Authorization scheme must be Bearer")) + } + + let token = String(parts[1]).trimmingCharacters(in: .whitespacesAndNewlines) + guard !token.isEmpty else { + return .failure(.init(message: "Bearer token is empty")) + } + + if token.contains(where: \.isWhitespace) { + return .failure(.init(message: "Bearer token must not contain whitespace")) + } + + return .success(token) + } +} + // MARK: - Validation Pipeline Protocol /// Runs a validation pipeline against an HTTP request. diff --git a/Sources/MCP/Base/Transports/HTTPServer/HTTPServerTypes.swift b/Sources/MCP/Base/Transports/HTTPServer/HTTPServerTypes.swift index 7dc369e9..36c517e1 100644 --- a/Sources/MCP/Base/Transports/HTTPServer/HTTPServerTypes.swift +++ b/Sources/MCP/Base/Transports/HTTPServer/HTTPServerTypes.swift @@ -42,10 +42,15 @@ public struct HTTPRequest: Sendable { /// The request body data, if any. public let body: Data? - public init(method: String, headers: [String: String] = [:], body: Data? = nil) { + /// The request path (e.g., "/mcp", "/.well-known/oauth-protected-resource"). + /// Used by validators that need to match on specific paths. + public let path: String? + + public init(method: String, headers: [String: String] = [:], body: Data? = nil, path: String? = nil) { self.method = method self.headers = headers self.body = body + self.path = path } /// Case-insensitive header lookup. @@ -139,9 +144,11 @@ public enum HTTPResponse: Sendable { public enum HTTPHeaderName { public static let sessionID = "MCP-Session-Id" public static let protocolVersion = "MCP-Protocol-Version" - public static let lastEventID = "Last-Event-Id" + public static let lastEventID = "Last-Event-ID" public static let accept = "Accept" public static let contentType = "Content-Type" + public static let authorization = "Authorization" + public static let wwwAuthenticate = "WWW-Authenticate" public static let origin = "Origin" public static let host = "Host" public static let cacheControl = "Cache-Control" diff --git a/Sources/MCP/Client/Client.swift b/Sources/MCP/Client/Client.swift index 57485ade..3eea1240 100644 --- a/Sources/MCP/Client/Client.swift +++ b/Sources/MCP/Client/Client.swift @@ -483,7 +483,7 @@ public actor Client { /// A batch of requests. /// /// Objects of this type are passed as an argument to the closure - /// of the ``Client/withBatch(_:)`` method. + /// of the ``Client/withBatch(body:)`` method. public actor Batch { unowned let client: Client var requests: [AnyRequest] = [] diff --git a/Sources/MCPConformance/Client/main.swift b/Sources/MCPConformance/Client/main.swift index 5a090a7b..60b4baef 100644 --- a/Sources/MCPConformance/Client/main.swift +++ b/Sources/MCPConformance/Client/main.swift @@ -14,10 +14,359 @@ import Foundation import Logging import MCP +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + // MARK: - Scenario Handlers typealias ScenarioHandler = ([String]) async throws -> Void +// MARK: - Authorization Scenarios + +private func loadConformanceContext() -> [String: String] { + let env = ProcessInfo.processInfo.environment + + if let raw = env["MCP_CONFORMANCE_CONTEXT"], + let data = raw.data(using: .utf8), + let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any] + { + var parsed: [String: String] = [:] + for (key, value) in json { + if let value = value as? String { + parsed[key] = value + } + } + return parsed + } + + var parsed: [String: String] = [:] + if let clientID = env["MCP_CONFORMANCE_CLIENT_ID"] { + parsed["client_id"] = clientID + } + if let clientSecret = env["MCP_CONFORMANCE_CLIENT_SECRET"] { + parsed["client_secret"] = clientSecret + } + return parsed +} + +private func percentEncodeFormValue(_ value: String) -> String { + let allowed = CharacterSet(charactersIn: "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~") + return value.addingPercentEncoding(withAllowedCharacters: allowed) ?? value +} + +private func formURLEncodedBody(_ parameters: [String: String]) -> Data { + let encoded = parameters + .sorted { $0.key < $1.key } + .map { key, value in + "\(percentEncodeFormValue(key))=\(percentEncodeFormValue(value))" + } + .joined(separator: "&") + return Data(encoded.utf8) +} + +private func clientAssertionAudience(from tokenEndpoint: URL) -> String { + guard var components = URLComponents(url: tokenEndpoint, resolvingAgainstBaseURL: false) else { + return tokenEndpoint.absoluteString + } + + components.query = nil + components.fragment = nil + + var path = components.path + if path.hasSuffix("/token") { + path = String(path.dropLast("/token".count)) + } else { + let parts = path.split(separator: "/") + if !parts.isEmpty { + let parent = parts.dropLast() + path = parent.isEmpty ? "" : "/" + parent.joined(separator: "/") + } + } + if path == "/" { + path = "" + } + components.path = path + + return components.url?.absoluteString ?? tokenEndpoint.absoluteString +} + +private func parsePrivateKeyJWTSigningAlgorithm( + _ signingAlgorithm: String +) throws -> OAuthConfiguration.PrivateKeyJWTSigningAlgorithm { + switch signingAlgorithm.uppercased() { + case OAuthConfiguration.PrivateKeyJWTSigningAlgorithm.ES256.rawValue: + return .ES256 + default: + throw ConformanceError.invalidArguments( + "Unsupported signing algorithm: \(signingAlgorithm)" + ) + } +} + +private func makeOAuthConfiguration( + for scenario: String, + context: [String: String] +) -> OAuthConfiguration { + let clientID = context["client_id"] ?? "test-client" + let clientSecret = context["client_secret"] ?? "test-secret" + + var configuration: OAuthConfiguration + switch scenario { + case "auth/pre-registration": + configuration = .init( + grantType: .authorizationCode, + authentication: .clientSecretBasic( + clientID: clientID, + clientSecret: clientSecret + ) + ) + + case "auth/token-endpoint-auth-basic": + configuration = .init( + grantType: .authorizationCode, + authentication: .clientSecretBasic( + clientID: clientID, + clientSecret: clientSecret + ) + ) + + case "auth/token-endpoint-auth-post": + configuration = .init( + grantType: .authorizationCode, + authentication: .clientSecretPost( + clientID: clientID, + clientSecret: clientSecret + ) + ) + + case "auth/client-credentials-basic": + configuration = .init( + authentication: .clientSecretBasic( + clientID: clientID, + clientSecret: clientSecret + ) + ) + + case "auth/client-credentials-jwt": + let privateKeyPEM = context["private_key_pem"] ?? "" + let signingAlgorithm = context["signing_algorithm"] ?? "ES256" + configuration = .init( + authentication: .privateKeyJWT( + clientID: clientID, + assertionFactory: { tokenEndpoint, clientID in + try OAuthConfiguration.makePrivateKeyJWTAssertion( + clientID: clientID, + tokenEndpoint: tokenEndpoint, + privateKeyPEM: privateKeyPEM, + signingAlgorithm: try parsePrivateKeyJWTSigningAlgorithm(signingAlgorithm), + audience: clientAssertionAudience(from: tokenEndpoint) + ) + } + ) + ) + + case "auth/basic-cimd": + configuration = .init( + grantType: .authorizationCode, + authentication: .none( + clientID: context["client_id"] + ?? "https://conformance-test.local/client-metadata.json") + ) + + case "auth/cross-app-access-complete-flow": + configuration = .init( + authentication: .clientSecretBasic( + clientID: clientID, + clientSecret: clientSecret + ), + accessTokenProvider: makeCrossAppAccessTokenProvider(context: context) + ) + + case let s where s.hasPrefix("auth/client-credentials"): + configuration = .init( + authentication: .none(clientID: clientID) + ) + + default: + configuration = .init( + grantType: .authorizationCode, + authentication: .none(clientID: clientID) + ) + } + + // Conformance harness currently uses loopback http AS endpoints. + configuration.allowLoopbackHTTPAuthorizationServerEndpoints = true + return configuration +} + +private struct ConformanceTokenResponse: Decodable { + let accessToken: String + + enum CodingKeys: String, CodingKey { + case accessToken = "access_token" + } +} + +private func requestOAuthToken( + url: URL, + parameters: [String: String], + authorizationHeader: String?, + session: URLSession +) async throws -> String { + var request = URLRequest(url: url) + request.httpMethod = "POST" + request.setValue("application/x-www-form-urlencoded", forHTTPHeaderField: "Content-Type") + request.setValue("application/json", forHTTPHeaderField: "Accept") + if let authorizationHeader { + request.setValue(authorizationHeader, forHTTPHeaderField: "Authorization") + } + request.httpBody = formURLEncodedBody(parameters) + + let (data, response) = try await session.data(for: request) + guard let httpResponse = response as? HTTPURLResponse else { + throw ConformanceError.invalidArguments("Token endpoint returned an invalid response") + } + guard (200..<300).contains(httpResponse.statusCode) else { + let body = String(data: data, encoding: .utf8) ?? "" + throw ConformanceError.invalidArguments( + "Token endpoint error (\(httpResponse.statusCode)): \(body)" + ) + } + + let token = try JSONDecoder().decode(ConformanceTokenResponse.self, from: data) + guard !token.accessToken.isEmpty else { + throw ConformanceError.invalidArguments("Token endpoint returned an empty access token") + } + return token.accessToken +} + +private func makeCrossAppAccessTokenProvider( + context: [String: String] +) -> OAuthConfiguration.AccessTokenProvider { + return { discovery, session in + guard let clientID = context["client_id"], + let clientSecret = context["client_secret"], + let idpClientID = context["idp_client_id"], + let idpIDToken = context["idp_id_token"], + let idpTokenEndpointValue = context["idp_token_endpoint"], + let idpTokenEndpoint = URL(string: idpTokenEndpointValue) + else { + throw ConformanceError.invalidArguments( + "Cross-app scenario requires client_id, client_secret, idp_client_id, idp_id_token, and idp_token_endpoint" + ) + } + + guard let authorizationServer = discovery.authorizationServer else { + throw ConformanceError.invalidArguments( + "SDK did not provide authorization server discovery context" + ) + } + guard let tokenEndpoint = discovery.tokenEndpoint else { + throw ConformanceError.invalidArguments( + "SDK did not provide token endpoint discovery context" + ) + } + let resource = discovery.resource.absoluteString + + let idJag = try await requestOAuthToken( + url: idpTokenEndpoint, + parameters: [ + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "subject_token": idpIDToken, + "subject_token_type": "urn:ietf:params:oauth:token-type:id_token", + "requested_token_type": "urn:ietf:params:oauth:token-type:id-jag", + "audience": authorizationServer.absoluteString, + "resource": resource, + "client_id": idpClientID, + ], + authorizationHeader: nil, + session: session + ) + + var accessTokenParameters: [String: String] = [ + "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", + "assertion": idJag, + "resource": resource, + ] + if let requestedScopes = discovery.requestedScopes, !requestedScopes.isEmpty { + accessTokenParameters["scope"] = requestedScopes.sorted().joined(separator: " ") + } + + let basicCredentials = Data("\(clientID):\(clientSecret)".utf8).base64EncodedString() + let accessToken = try await requestOAuthToken( + url: tokenEndpoint, + parameters: accessTokenParameters, + authorizationHeader: "Basic \(basicCredentials)", + session: session + ) + + return accessToken + } +} + +func runAuthorizationScenario(scenario: String, args: [String]) async throws { + var logger = Logger( + label: "mcp.conformance.client.auth", + factory: { StreamLogHandler.standardError(label: $0) } + ) + logger.logLevel = .debug + + logger.debug("Starting auth scenario", metadata: ["scenario": "\(scenario)"]) + + guard let serverURLString = args.last, + let serverURL = URL(string: serverURLString) + else { + throw ConformanceError.invalidArguments("Valid server URL is required") + } + + let context = loadConformanceContext() + let oauthConfig = makeOAuthConfiguration(for: scenario, context: context) + + let transport = HTTPClientTransport( + endpoint: serverURL, + streaming: true, + authorizer: OAuthAuthorizer(configuration: oauthConfig), + logger: logger + ) + + let client = Client(name: "test-client", version: "1.0.0") + + // Scenarios that expect the connection to fail with a specific error. + if scenario == "auth/resource-mismatch" { + do { + _ = try await client.connect(transport: transport) + throw ConformanceError.invalidArguments( + "Expected authorization to fail with resource mismatch, but connection succeeded" + ) + } catch let error as MCPError { + guard case .internalError(let detail) = error, + detail?.contains("resource mismatch") == true + else { + throw ConformanceError.invalidArguments( + "Connection failed, but not due to resource mismatch: \(error.localizedDescription)" + ) + } + logger.debug("Client correctly rejected mismatched PRM resource") + } + return + } + + _ = try await client.connect(transport: transport) + + // Exercise both initialization and regular request paths. + let (tools, _) = try await client.listTools() + logger.debug("Auth scenario listed tools", metadata: ["count": "\(tools.count)"]) + + // Trigger an additional request for scenarios that involve runtime scope behavior. + if scenario.contains("scope"), let firstTool = tools.first { + _ = try? await client.callTool(name: firstTool.name, arguments: [:]) + } + + await client.disconnect() + logger.debug("Auth scenario completed", metadata: ["scenario": "\(scenario)"]) +} + // MARK: - Basic Scenarios /// Basic client that connects, initializes, and lists tools @@ -315,7 +664,6 @@ nonisolated(unsafe) let scenarioHandlers: [String: ScenarioHandler] = [ "tools_call": runToolsCallScenario, "sse-retry": runSSEScenario, "elicitation-sep1034-client-defaults": runElicitationSEP1034ClientDefaults, - // Note: Other scenarios (auth/*) will use the default handler ] // MARK: - Error Types @@ -353,11 +701,16 @@ struct ConformanceClient { Foundation.exit(1) } - // Get handler for scenario, or use default if not implemented - let handler = scenarioHandlers[scenario] ?? runDefaultScenario - - // Log if using default handler - if scenarioHandlers[scenario] == nil { + // Get handler for scenario + let handler: ScenarioHandler + if let explicitHandler = scenarioHandlers[scenario] { + handler = explicitHandler + } else if scenario.hasPrefix("auth/") { + handler = { args in + try await runAuthorizationScenario(scenario: scenario, args: args) + } + } else { + handler = runDefaultScenario var stderr = StandardError() print("⚠️ Scenario '\(scenario)' not fully implemented - using default handler", to: &stderr) } diff --git a/Sources/MCPConformance/Server/HTTPApp.swift b/Sources/MCPConformance/Server/HTTPApp.swift index ddd281cc..6edba56f 100644 --- a/Sources/MCPConformance/Server/HTTPApp.swift +++ b/Sources/MCPConformance/Server/HTTPApp.swift @@ -352,10 +352,13 @@ private final class HTTPHandler: ChannelInboundHandler, @unchecked Sendable { body = nil } + let path = String(state.head.uri.split(separator: "?").first ?? Substring(state.head.uri)) + return HTTPRequest( method: state.head.method.rawValue, headers: headers, - body: body + body: body, + path: path ) } diff --git a/Tests/MCPTests/HTTPClientTransportTests.swift b/Tests/MCPTests/HTTPClientTransportTests.swift index c2701337..bd8c38f7 100644 --- a/Tests/MCPTests/HTTPClientTransportTests.swift +++ b/Tests/MCPTests/HTTPClientTransportTests.swift @@ -36,17 +36,19 @@ import Testing // MARK: - Mock Handler Registry Actor actor RequestHandlerStorage { - private var requestHandler: + var requestHandler: (@Sendable (URLRequest) async throws -> (HTTPURLResponse, Data))? + private var callCounts: [URL: Int] = [:] func setHandler( _ handler: @Sendable @escaping (URLRequest) async throws -> (HTTPURLResponse, Data) - ) async { + ) { requestHandler = handler } - func clearHandler() async { + func clearHandler() { requestHandler = nil + callCounts = [:] } func executeHandler(for request: URLRequest) async throws -> (HTTPURLResponse, Data) { @@ -57,8 +59,15 @@ import Testing NSLocalizedDescriptionKey: "No request handler set" ]) } + if let url = request.url { + callCounts[url, default: 0] += 1 + } return try await handler(request) } + + func callCount(for url: URL) -> Int { + callCounts[url, default: 0] + } } // MARK: - Helper Methods @@ -125,13 +134,27 @@ import Testing } override func stopLoading() {} + + static func verifyCallCounts( + _ expected: [URL: Int], + sourceLocation: SourceLocation = #_sourceLocation + ) async { + for (url, expectedCount) in expected { + let actual = await requestHandlerStorage.callCount(for: url) + #expect( + actual == expectedCount, + "Expected \(expectedCount) call(s) to \(url.lastPathComponent), got \(actual)", + sourceLocation: sourceLocation + ) + } + } } // MARK: - @Suite("HTTP Client Transport Tests", .serialized) struct HTTPClientTransportTests { - let testEndpoint = URL(string: "http://localhost:8080/test")! + let testEndpoint = URL(string: "https://localhost:8080/test")! @Test("Connect and Disconnect", .httpClientTransportSetup) func testConnectAndDisconnect() async throws { @@ -350,6 +373,551 @@ import Testing } } + @Test("HTTP 400 Bad Request Error", .httpClientTransportSetup) + func testHTTPBadRequestError() async throws { + let configuration = URLSessionConfiguration.ephemeral + configuration.protocolClasses = [MockURLProtocol.self] + + let messageData = #"{"jsonrpc":"2.0","method":"test","id":40}"#.data(using: .utf8)! + + await MockURLProtocol.requestHandlerStorage.setHandler { + [testEndpoint] (_: URLRequest) in + let response = HTTPURLResponse( + url: testEndpoint, statusCode: 400, httpVersion: "HTTP/1.1", headerFields: nil)! + return (response, Data("Bad Request".utf8)) + } + + let transport = HTTPClientTransport( + endpoint: testEndpoint, + configuration: configuration, + streaming: false, + logger: nil + ) + try await transport.connect() + + do { + try await transport.send(messageData) + Issue.record("Expected send to throw an error for 400") + } catch let error as MCPError { + guard case .internalError(let message) = error else { + Issue.record("Expected MCPError.internalError, got \(error)") + throw error + } + #expect(message?.contains("Bad request") ?? false) + } catch { + Issue.record("Expected MCPError, got \(error)") + throw error + } + } + + @Test("HTTP 401 Unauthorized Error Without OAuth", .httpClientTransportSetup) + func testHTTPUnauthorizedErrorWithoutOAuth() async throws { + let configuration = URLSessionConfiguration.ephemeral + configuration.protocolClasses = [MockURLProtocol.self] + + let messageData = #"{"jsonrpc":"2.0","method":"test","id":41}"#.data(using: .utf8)! + + await MockURLProtocol.requestHandlerStorage.setHandler { + [testEndpoint] (_: URLRequest) in + let response = HTTPURLResponse( + url: testEndpoint, + statusCode: 401, + httpVersion: "HTTP/1.1", + headerFields: [ + "WWW-Authenticate": "Bearer scope=\"files:read\"" + ])! + return (response, Data()) + } + + let transport = HTTPClientTransport( + endpoint: testEndpoint, + configuration: configuration, + streaming: false, + logger: nil + ) + try await transport.connect() + + do { + try await transport.send(messageData) + Issue.record("Expected send to throw an error for 401 without OAuth") + } catch let error as MCPError { + guard case .internalError(let message) = error else { + Issue.record("Expected MCPError.internalError, got \(error)") + throw error + } + #expect(message?.contains("Authentication required") ?? false) + } catch { + Issue.record("Expected MCPError, got \(error)") + throw error + } + } + + @Test("HTTP 403 Forbidden Error Without OAuth", .httpClientTransportSetup) + func testHTTPForbiddenErrorWithoutOAuth() async throws { + let configuration = URLSessionConfiguration.ephemeral + configuration.protocolClasses = [MockURLProtocol.self] + + let messageData = #"{"jsonrpc":"2.0","method":"test","id":42}"#.data(using: .utf8)! + + await MockURLProtocol.requestHandlerStorage.setHandler { + [testEndpoint] (_: URLRequest) in + let response = HTTPURLResponse( + url: testEndpoint, + statusCode: 403, + httpVersion: "HTTP/1.1", + headerFields: [ + "WWW-Authenticate": + "Bearer error=\"insufficient_scope\", scope=\"files:write\"" + ])! + return (response, Data()) + } + + let transport = HTTPClientTransport( + endpoint: testEndpoint, + configuration: configuration, + streaming: false, + logger: nil + ) + try await transport.connect() + + do { + try await transport.send(messageData) + Issue.record("Expected send to throw an error for 403 without OAuth") + } catch let error as MCPError { + guard case .internalError(let message) = error else { + Issue.record("Expected MCPError.internalError, got \(error)") + throw error + } + #expect(message?.contains("Access forbidden") ?? false) + } catch { + Issue.record("Expected MCPError, got \(error)") + throw error + } + } + + @Test("OAuth scope step-up retries after 403 insufficient_scope", .httpClientTransportSetup) + func testOAuthStepUpRetryAfter403InsufficientScope() async throws { + let configuration = URLSessionConfiguration.ephemeral + configuration.protocolClasses = [MockURLProtocol.self] + + let testEndpoint = URL(string: "https://localhost:8080/step-up")! + let resourceMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-protected-resource/step-up")! + let asMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-authorization-server/auth")! + let tokenEndpointURL = URL(string: "https://localhost:8080/oauth/token")! + let finalResponseData = #"{"jsonrpc":"2.0","result":{"ok":true},"id":43}"#.data( + using: .utf8)! + + actor CallTracker { + var tokenCalls = 0 + + func nextTokenCall() -> Int { + tokenCalls += 1 + return tokenCalls + } + } + let tracker = CallTracker() + + await MockURLProtocol.requestHandlerStorage.setHandler { + [tracker, testEndpoint, resourceMetadataURL, asMetadataURL, tokenEndpointURL, finalResponseData] request in + guard let url = request.url else { + throw NSError( + domain: "MockURLProtocolError", + code: 0, + userInfo: [NSLocalizedDescriptionKey: "Missing request URL"]) + } + + switch url { + case testEndpoint: + switch request.value(forHTTPHeaderField: "Authorization") { + case nil: + let response = HTTPURLResponse( + url: url, + statusCode: 401, + httpVersion: "HTTP/1.1", + headerFields: [ + "WWW-Authenticate": + "Bearer resource_metadata=\"\(resourceMetadataURL.absoluteString)\", scope=\"files:read\"" + ])! + return (response, Data()) + + case "Bearer access-token-read": + let response = HTTPURLResponse( + url: url, + statusCode: 403, + httpVersion: "HTTP/1.1", + headerFields: [ + "WWW-Authenticate": + "Bearer error=\"insufficient_scope\", scope=\"files:write\", resource_metadata=\"\(resourceMetadataURL.absoluteString)\", error_description=\"Additional file write permission required\"" + ])! + return (response, Data()) + + case "Bearer access-token-read-write": + let response = HTTPURLResponse( + url: url, + statusCode: 200, + httpVersion: "HTTP/1.1", + headerFields: ["Content-Type": "application/json"])! + return (response, finalResponseData) + + default: + throw NSError( + domain: "MockURLProtocolError", + code: 0, + userInfo: [ + NSLocalizedDescriptionKey: + "Unexpected Authorization value: \(request.value(forHTTPHeaderField: "Authorization") ?? "")" + ]) + } + + case resourceMetadataURL: + let metadata = + #"{ "authorization_servers": ["https://localhost:8080/auth"], "scopes_supported": ["files:read","files:write"] }"# + .data(using: .utf8)! + let response = HTTPURLResponse( + url: url, + statusCode: 200, + httpVersion: "HTTP/1.1", + headerFields: ["Content-Type": "application/json"])! + return (response, metadata) + + case asMetadataURL: + let metadata = #"{ "issuer": "https://localhost:8080/auth", "token_endpoint": "https://localhost:8080/oauth/token" }"# + .data(using: .utf8)! + let response = HTTPURLResponse( + url: url, + statusCode: 200, + httpVersion: "HTTP/1.1", + headerFields: ["Content-Type": "application/json"])! + return (response, metadata) + + case tokenEndpointURL: + let tokenCall = await tracker.nextTokenCall() + let body = String(data: request.readBody() ?? Data(), encoding: .utf8) ?? "" + #expect(body.contains("grant_type=client_credentials")) + #expect(body.contains("client_id=test-client")) + #expect(body.contains("resource=https%3A%2F%2Flocalhost%3A8080%2Fstep-up")) + + if tokenCall == 1 { + #expect(body.contains("scope=files%3Aread")) + let tokenResponse = + #"{ "access_token": "access-token-read", "token_type": "Bearer", "expires_in": 3600 }"# + .data(using: .utf8)! + let response = HTTPURLResponse( + url: url, + statusCode: 200, + httpVersion: "HTTP/1.1", + headerFields: ["Content-Type": "application/json"])! + return (response, tokenResponse) + } + + #expect(tokenCall == 2) + #expect(body.contains("scope=files%3Aread%20files%3Awrite")) + let tokenResponse = + #"{ "access_token": "access-token-read-write", "token_type": "Bearer", "expires_in": 3600 }"# + .data(using: .utf8)! + let response = HTTPURLResponse( + url: url, + statusCode: 200, + httpVersion: "HTTP/1.1", + headerFields: ["Content-Type": "application/json"])! + return (response, tokenResponse) + + default: + throw NSError( + domain: "MockURLProtocolError", + code: 0, + userInfo: [ + NSLocalizedDescriptionKey: + "Unexpected URL: \(url.absoluteString)" + ]) + } + } + + let transport = HTTPClientTransport( + endpoint: testEndpoint, + configuration: configuration, + streaming: false, + authorizer: OAuthAuthorizer(configuration: .init(authentication: .none(clientID: "test-client"))), + logger: nil + ) + + try await transport.connect() + let messageData = #"{"jsonrpc":"2.0","method":"ping","id":43}"#.data(using: .utf8)! + try await transport.send(messageData) + + let stream = await transport.receive() + var iterator = stream.makeAsyncIterator() + let received = try await iterator.next() + #expect(received == finalResponseData) + #expect(await tracker.tokenCalls == 2) + + await transport.disconnect() + } + + @Test("OAuth scope upgrade tracking is scoped per operation", .httpClientTransportSetup) + func testOAuthScopeUpgradeTrackingPerOperation() async throws { + let configuration = URLSessionConfiguration.ephemeral + configuration.protocolClasses = [MockURLProtocol.self] + + let testEndpoint = URL(string: "https://localhost:8080/operation-tracking")! + let resourceMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-protected-resource/operation-tracking")! + let asMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-authorization-server/auth")! + let tokenEndpointURL = URL(string: "https://localhost:8080/oauth/token")! + let finalResponseData = #"{"jsonrpc":"2.0","result":{"ok":true},"id":62}"#.data( + using: .utf8)! + + actor CallTracker { + var tokenCalls = 0 + var opAForbiddenCalls = 0 + var opBForbiddenCalls = 0 + + func nextTokenCall() -> Int { + tokenCalls += 1 + return tokenCalls + } + + func incrementOpAForbiddenCalls() { + opAForbiddenCalls += 1 + } + + func incrementOpBForbiddenCalls() { + opBForbiddenCalls += 1 + } + } + let tracker = CallTracker() + + await MockURLProtocol.requestHandlerStorage.setHandler { + [tracker, testEndpoint, resourceMetadataURL, asMetadataURL, tokenEndpointURL, finalResponseData] request in + guard let url = request.url else { + throw NSError( + domain: "MockURLProtocolError", + code: 0, + userInfo: [NSLocalizedDescriptionKey: "Missing request URL"]) + } + + switch url { + case testEndpoint: + let body = String(data: request.readBody() ?? Data(), encoding: .utf8) ?? "" + let isOperationA = body.contains(#""method":"tools/callA""#) + let isOperationB = body.contains(#""method":"tools/callB""#) + let authorization = request.value(forHTTPHeaderField: "Authorization") + + if authorization == nil { + let response = HTTPURLResponse( + url: url, + statusCode: 401, + httpVersion: "HTTP/1.1", + headerFields: [ + "WWW-Authenticate": + "Bearer resource_metadata=\"\(resourceMetadataURL.absoluteString)\", scope=\"files:read\"" + ])! + return (response, Data()) + } + + if isOperationA { + if authorization == "Bearer access-token-read" + || authorization == "Bearer access-token-read-write" + { + await tracker.incrementOpAForbiddenCalls() + let response = HTTPURLResponse( + url: url, + statusCode: 403, + httpVersion: "HTTP/1.1", + headerFields: [ + "WWW-Authenticate": + "Bearer error=\"insufficient_scope\", scope=\"files:write\", resource_metadata=\"\(resourceMetadataURL.absoluteString)\"" + ])! + return (response, Data()) + } + + throw NSError( + domain: "MockURLProtocolError", + code: 0, + userInfo: [ + NSLocalizedDescriptionKey: + "Unexpected Authorization for opA: \(authorization ?? "")" + ]) + } + + if isOperationB { + if authorization == "Bearer access-token-read-write" { + await tracker.incrementOpBForbiddenCalls() + let response = HTTPURLResponse( + url: url, + statusCode: 403, + httpVersion: "HTTP/1.1", + headerFields: [ + "WWW-Authenticate": + "Bearer error=\"insufficient_scope\", scope=\"files:write\", resource_metadata=\"\(resourceMetadataURL.absoluteString)\"" + ])! + return (response, Data()) + } + + if authorization == "Bearer access-token-opb" { + let response = HTTPURLResponse( + url: url, + statusCode: 200, + httpVersion: "HTTP/1.1", + headerFields: ["Content-Type": "application/json"])! + return (response, finalResponseData) + } + + throw NSError( + domain: "MockURLProtocolError", + code: 0, + userInfo: [ + NSLocalizedDescriptionKey: + "Unexpected Authorization for opB: \(authorization ?? "")" + ]) + } + + throw NSError( + domain: "MockURLProtocolError", + code: 0, + userInfo: [ + NSLocalizedDescriptionKey: + "Unexpected request body: \(body)" + ]) + + case resourceMetadataURL: + let metadata = + #"{ "authorization_servers": ["https://localhost:8080/auth"], "scopes_supported": ["files:read","files:write"] }"# + .data(using: .utf8)! + let response = HTTPURLResponse( + url: url, + statusCode: 200, + httpVersion: "HTTP/1.1", + headerFields: ["Content-Type": "application/json"])! + return (response, metadata) + + case asMetadataURL: + let metadata = #"{ "issuer": "https://localhost:8080/auth", "token_endpoint": "https://localhost:8080/oauth/token" }"# + .data(using: .utf8)! + let response = HTTPURLResponse( + url: url, + statusCode: 200, + httpVersion: "HTTP/1.1", + headerFields: ["Content-Type": "application/json"])! + return (response, metadata) + + case tokenEndpointURL: + let tokenCall = await tracker.nextTokenCall() + let body = String(data: request.readBody() ?? Data(), encoding: .utf8) ?? "" + #expect(body.contains("grant_type=client_credentials")) + #expect(body.contains("client_id=test-client")) + #expect( + body.contains( + "resource=https%3A%2F%2Flocalhost%3A8080%2Foperation-tracking") + ) + + switch tokenCall { + case 1: + #expect(body.contains("scope=files%3Aread")) + let tokenResponse = + #"{ "access_token": "access-token-read", "token_type": "Bearer", "expires_in": 3600 }"# + .data(using: .utf8)! + let response = HTTPURLResponse( + url: url, + statusCode: 200, + httpVersion: "HTTP/1.1", + headerFields: ["Content-Type": "application/json"])! + return (response, tokenResponse) + + case 2: + #expect(body.contains("scope=files%3Aread%20files%3Awrite")) + let tokenResponse = + #"{ "access_token": "access-token-read-write", "token_type": "Bearer", "expires_in": 3600 }"# + .data(using: .utf8)! + let response = HTTPURLResponse( + url: url, + statusCode: 200, + httpVersion: "HTTP/1.1", + headerFields: ["Content-Type": "application/json"])! + return (response, tokenResponse) + + case 3: + #expect(body.contains("scope=files%3Aread%20files%3Awrite")) + let tokenResponse = + #"{ "access_token": "access-token-opb", "token_type": "Bearer", "expires_in": 3600 }"# + .data(using: .utf8)! + let response = HTTPURLResponse( + url: url, + statusCode: 200, + httpVersion: "HTTP/1.1", + headerFields: ["Content-Type": "application/json"])! + return (response, tokenResponse) + + default: + throw NSError( + domain: "MockURLProtocolError", + code: 0, + userInfo: [ + NSLocalizedDescriptionKey: + "Unexpected token call count: \(tokenCall)" + ]) + } + + default: + throw NSError( + domain: "MockURLProtocolError", + code: 0, + userInfo: [ + NSLocalizedDescriptionKey: + "Unexpected URL: \(url.absoluteString)" + ]) + } + } + + let transport = HTTPClientTransport( + endpoint: testEndpoint, + configuration: configuration, + streaming: false, + authorizer: OAuthAuthorizer(configuration: .init( + authentication: .none(clientID: "test-client"), + retryPolicy: .init(maxAuthorizationAttempts: 8, maxScopeUpgradeAttempts: 1) + )), + logger: nil + ) + + try await transport.connect() + + let stream = await transport.receive() + var iterator = stream.makeAsyncIterator() + + let operationAData = #"{"jsonrpc":"2.0","method":"tools/callA","id":61}"#.data( + using: .utf8)! + do { + try await transport.send(operationAData) + Issue.record("Expected operation A to fail after scope-upgrade retry limit") + } catch let error as MCPError { + guard case .internalError(let message) = error else { + Issue.record("Expected MCPError.internalError, got \(error)") + throw error + } + #expect(message?.contains("Access forbidden") ?? false) + } catch { + Issue.record("Expected MCPError, got \(error)") + throw error + } + + let operationBData = #"{"jsonrpc":"2.0","method":"tools/callB","id":62}"#.data( + using: .utf8)! + try await transport.send(operationBData) + + let received = try await iterator.next() + #expect(received == finalResponseData) + + #expect(await tracker.tokenCalls == 3) + #expect(await tracker.opAForbiddenCalls == 2) + #expect(await tracker.opBForbiddenCalls == 1) + + await transport.disconnect() + } + @Test("Session Expired Error (404 with Session ID)", .httpClientTransportSetup) func testSessionExpiredError() async throws { let configuration = URLSessionConfiguration.ephemeral @@ -725,6 +1293,1073 @@ import Testing await transport.disconnect() } + @Test("OAuth client credentials performs discovery and retries after 401", .httpClientTransportSetup) + func testOAuthClientCredentialsRetryAfter401() async throws { + let scenario = await MockURLProtocol.requestHandlerStorage.configureOAuthClientCredentialsRetryAfter401() + + let transport = HTTPClientTransport( + endpoint: scenario.testEndpoint, + configuration: MockResponses.ephemeralConfiguration(), + streaming: scenario.streaming, + authorizer: OAuthAuthorizer(configuration: scenario.oauthConfiguration), + logger: nil + ) + + try await transport.connect() + try await transport.send(scenario.messageData) + + let stream = await transport.receive() + var iterator = stream.makeAsyncIterator() + let received = try await iterator.next() + #expect(received == scenario.expectedResponseData) + + await MockURLProtocol.verifyCallCounts(scenario.expectedCallCounts) + await transport.disconnect() + } + + @Test("OAuth scope selection falls back to scopes_supported when challenge scope is absent", .httpClientTransportSetup) + func testOAuthScopeSelectionFallsBackToScopesSupportedWhenChallengeScopeMissing() + async throws + { + let scenario = await MockURLProtocol.requestHandlerStorage.configureOAuthScopeSelectionFallsBackToScopesSupported() + + let transport = HTTPClientTransport( + endpoint: scenario.testEndpoint, + configuration: MockResponses.ephemeralConfiguration(), + streaming: scenario.streaming, + authorizer: OAuthAuthorizer(configuration: scenario.oauthConfiguration), + logger: nil + ) + + try await transport.connect() + try await transport.send(scenario.messageData) + + let stream = await transport.receive() + var iterator = stream.makeAsyncIterator() + let received = try await iterator.next() + #expect(received == scenario.expectedResponseData) + + await MockURLProtocol.verifyCallCounts(scenario.expectedCallCounts) + await transport.disconnect() + } + + @Test("OAuth omits scope parameter when challenge scope and scopes_supported are unavailable", .httpClientTransportSetup) + func testOAuthScopeSelectionOmitsScopeWhenNoHintsAvailable() async throws { + let scenario = await MockURLProtocol.requestHandlerStorage.configureOAuthScopeOmittedWhenNoHints() + + let transport = HTTPClientTransport( + endpoint: scenario.testEndpoint, + configuration: MockResponses.ephemeralConfiguration(), + streaming: scenario.streaming, + authorizer: OAuthAuthorizer(configuration: scenario.oauthConfiguration), + logger: nil + ) + + try await transport.connect() + try await transport.send(scenario.messageData) + + let stream = await transport.receive() + var iterator = stream.makeAsyncIterator() + let received = try await iterator.next() + #expect(received == scenario.expectedResponseData) + + await MockURLProtocol.verifyCallCounts(scenario.expectedCallCounts) + await transport.disconnect() + } + + @Test("OAuth includes canonical resource in both authorization and token requests", .httpClientTransportSetup) + func testOAuthResourceParameterIncludedInAuthorizationAndTokenRequests() async throws { + let scenario = await MockURLProtocol.requestHandlerStorage.configureOAuthResourceParameterInAuthorizationAndToken() + + let transport = HTTPClientTransport( + endpoint: scenario.testEndpoint, + configuration: MockResponses.ephemeralConfiguration(), + streaming: scenario.streaming, + authorizer: OAuthAuthorizer(configuration: scenario.oauthConfiguration), + logger: nil + ) + + try await transport.connect() + try await transport.send(scenario.messageData) + + let stream = await transport.receive() + var iterator = stream.makeAsyncIterator() + let received = try await iterator.next() + #expect(received == scenario.expectedResponseData) + + await MockURLProtocol.verifyCallCounts(scenario.expectedCallCounts) + await transport.disconnect() + } + + @Test( + "OAuth rejects authorization when AS metadata omits code_challenge_methods_supported", + .httpClientTransportSetup + ) + func testOAuthRejectsAuthorizationWithoutPKCEMetadata() async throws { + let scenario = await MockURLProtocol.requestHandlerStorage.configureOAuthRejectsAuthorizationWithoutPKCEMetadata() + + let transport = HTTPClientTransport( + endpoint: scenario.testEndpoint, + configuration: MockResponses.ephemeralConfiguration(), + streaming: false, + authorizer: OAuthAuthorizer(configuration: scenario.oauthConfiguration), + logger: nil + ) + + try await transport.connect() + + do { + try await transport.send(scenario.messageData) + Issue.record("Expected send to throw an error") + } catch let error as MCPError { + guard case .internalError(let detail) = error else { + Issue.record("Expected MCPError.internalError, got \(error)") + throw error + } + #expect(detail?.contains(scenario.expectedErrorSubstring!) == true) + for unexpected in scenario.unexpectedErrorSubstrings { + #expect(detail?.contains(unexpected) == false) + } + } catch { + Issue.record("Expected MCPError, got \(error)") + throw error + } + + await MockURLProtocol.verifyCallCounts(scenario.expectedCallCounts) + await transport.disconnect() + } + + @Test( + "OAuth rejects authorization when AS metadata lacks PKCE S256 support", + .httpClientTransportSetup + ) + func testOAuthRejectsAuthorizationWithoutS256PKCE() async throws { + let scenario = await MockURLProtocol.requestHandlerStorage.configureOAuthRejectsAuthorizationWithoutS256PKCE() + + let transport = HTTPClientTransport( + endpoint: scenario.testEndpoint, + configuration: MockResponses.ephemeralConfiguration(), + streaming: false, + authorizer: OAuthAuthorizer(configuration: scenario.oauthConfiguration), + logger: nil + ) + + try await transport.connect() + + do { + try await transport.send(scenario.messageData) + Issue.record("Expected send to throw an error") + } catch let error as MCPError { + guard case .internalError(let detail) = error else { + Issue.record("Expected MCPError.internalError, got \(error)") + throw error + } + #expect(detail?.contains(scenario.expectedErrorSubstring!) == true) + for unexpected in scenario.unexpectedErrorSubstrings { + #expect(detail?.contains(unexpected) == false) + } + } catch { + Issue.record("Expected MCPError, got \(error)") + throw error + } + + await MockURLProtocol.verifyCallCounts(scenario.expectedCallCounts) + await transport.disconnect() + } + + @Test( + "OAuth rejects authorization response redirect URI mismatch", + .httpClientTransportSetup + ) + func testOAuthRejectsAuthorizationResponseRedirectMismatch() async throws { + let scenario = await MockURLProtocol.requestHandlerStorage.configureOAuthRejectsAuthorizationResponseRedirectMismatch() + + let transport = HTTPClientTransport( + endpoint: scenario.testEndpoint, + configuration: MockResponses.ephemeralConfiguration(), + streaming: false, + authorizer: OAuthAuthorizer(configuration: scenario.oauthConfiguration), + logger: nil + ) + + try await transport.connect() + + do { + try await transport.send(scenario.messageData) + Issue.record("Expected send to throw an error") + } catch let error as MCPError { + guard case .internalError(let detail) = error else { + Issue.record("Expected MCPError.internalError, got \(error)") + throw error + } + #expect(detail?.contains(scenario.expectedErrorSubstring!) == true) + for unexpected in scenario.unexpectedErrorSubstrings { + #expect(detail?.contains(unexpected) == false) + } + } catch { + Issue.record("Expected MCPError, got \(error)") + throw error + } + + await MockURLProtocol.verifyCallCounts(scenario.expectedCallCounts) + await transport.disconnect() + } + + @Test( + "OAuth rejects authorization response state mismatch", + .httpClientTransportSetup + ) + func testOAuthRejectsAuthorizationResponseStateMismatch() async throws { + let scenario = await MockURLProtocol.requestHandlerStorage.configureOAuthRejectsAuthorizationResponseStateMismatch() + + let transport = HTTPClientTransport( + endpoint: scenario.testEndpoint, + configuration: MockResponses.ephemeralConfiguration(), + streaming: false, + authorizer: OAuthAuthorizer(configuration: scenario.oauthConfiguration), + logger: nil + ) + + try await transport.connect() + + do { + try await transport.send(scenario.messageData) + Issue.record("Expected send to throw an error") + } catch let error as MCPError { + guard case .internalError(let detail) = error else { + Issue.record("Expected MCPError.internalError, got \(error)") + throw error + } + #expect(detail?.contains(scenario.expectedErrorSubstring!) == true) + for unexpected in scenario.unexpectedErrorSubstrings { + #expect(detail?.contains(unexpected) == false) + } + } catch { + Issue.record("Expected MCPError, got \(error)") + throw error + } + + await MockURLProtocol.verifyCallCounts(scenario.expectedCallCounts) + await transport.disconnect() + } + + @Test("OAuth sends access token only via Authorization header", .httpClientTransportSetup) + func testOAuthDoesNotSendAccessTokenInBodyOrQuery() async throws { + let scenario = await MockURLProtocol.requestHandlerStorage.configureOAuthAccessTokenOnlyViaAuthorizationHeader() + + let transport = HTTPClientTransport( + endpoint: scenario.testEndpoint, + configuration: MockResponses.ephemeralConfiguration(), + streaming: scenario.streaming, + authorizer: OAuthAuthorizer(configuration: scenario.oauthConfiguration), + logger: nil + ) + + try await transport.connect() + try await transport.send(scenario.messageData) + + let stream = await transport.receive() + var iterator = stream.makeAsyncIterator() + let received = try await iterator.next() + #expect(received == scenario.expectedResponseData) + + await MockURLProtocol.verifyCallCounts(scenario.expectedCallCounts) + await transport.disconnect() + } + + @Test("OAuth sends Bearer Authorization header on every request in a logical session", .httpClientTransportSetup) + func testOAuthUsesAuthorizationHeaderForEveryRequestInSession() async throws { + let scenario = await MockURLProtocol.requestHandlerStorage.configureOAuthAuthorizationHeaderForEveryRequestInSession() + + let transport = HTTPClientTransport( + endpoint: scenario.testEndpoint, + configuration: MockResponses.ephemeralConfiguration(), + streaming: false, + authorizer: OAuthAuthorizer(configuration: scenario.oauthConfiguration), + logger: nil + ) + + try await transport.connect() + + let stream = await transport.receive() + var iterator = stream.makeAsyncIterator() + + try await transport.send(scenario.messageData) + let receivedFirst = try await iterator.next() + #expect(receivedFirst == scenario.expectedResponseData) + + try await transport.send(scenario.secondMessageData!) + let receivedSecond = try await iterator.next() + #expect(receivedSecond == scenario.secondExpectedResponseData) + + await MockURLProtocol.verifyCallCounts(scenario.expectedCallCounts) + await transport.disconnect() + } + + @Test("OAuth streaming GET requests use Authorization header and not query token", .httpClientTransportSetup) + func testOAuthStreamingGETUsesAuthorizationHeaderOnly() async throws { + let scenario = await MockURLProtocol.requestHandlerStorage.configureOAuthStreamingGETUsesAuthorizationHeaderOnly() + + let transport = HTTPClientTransport( + endpoint: scenario.testEndpoint, + configuration: MockResponses.ephemeralConfiguration(), + streaming: scenario.streaming, + sseInitializationTimeout: scenario.sseInitializationTimeout!, + authorizer: OAuthAuthorizer(configuration: scenario.oauthConfiguration), + logger: nil + ) + + try await transport.connect() + try await transport.send(scenario.messageData) + + let stream = await transport.receive() + var iterator = stream.makeAsyncIterator() + let received = try await iterator.next() + #expect(received == scenario.expectedResponseData) + + await transport.disconnect() + } + + @Test("OAuth rejects non-Bearer token_type for MCP resource requests", .httpClientTransportSetup) + func testOAuthRejectsNonBearerTokenType() async throws { + let scenario = await MockURLProtocol.requestHandlerStorage.configureOAuthRejectsNonBearerTokenType() + + let transport = HTTPClientTransport( + endpoint: scenario.testEndpoint, + configuration: MockResponses.ephemeralConfiguration(), + streaming: false, + authorizer: OAuthAuthorizer(configuration: scenario.oauthConfiguration), + logger: nil + ) + + try await transport.connect() + + do { + try await transport.send(scenario.messageData) + Issue.record("Expected send to throw an error") + } catch let error as MCPError { + guard case .internalError(let detail) = error else { + Issue.record("Expected MCPError.internalError, got \(error)") + throw error + } + #expect(detail?.contains(scenario.expectedErrorSubstring!) == true) + for unexpected in scenario.unexpectedErrorSubstrings { + #expect(detail?.contains(unexpected) == false) + } + } catch { + Issue.record("Expected MCPError, got \(error)") + throw error + } + + await MockURLProtocol.verifyCallCounts(scenario.expectedCallCounts) + await transport.disconnect() + } + + @Test("OAuth token endpoint failures redact raw response body", .httpClientTransportSetup) + func testOAuthTokenEndpointFailureRedactsResponseBody() async throws { + let scenario = await MockURLProtocol.requestHandlerStorage.configureOAuthTokenEndpointFailureRedactsResponseBody() + + let transport = HTTPClientTransport( + endpoint: scenario.testEndpoint, + configuration: MockResponses.ephemeralConfiguration(), + streaming: false, + authorizer: OAuthAuthorizer(configuration: scenario.oauthConfiguration), + logger: nil + ) + + try await transport.connect() + + do { + try await transport.send(scenario.messageData) + Issue.record("Expected send to throw an error") + } catch let error as MCPError { + guard case .internalError(let detail) = error else { + Issue.record("Expected MCPError.internalError, got \(error)") + throw error + } + #expect(detail?.contains(scenario.expectedErrorSubstring!) == true) + for unexpected in scenario.unexpectedErrorSubstrings { + #expect(detail?.contains(unexpected) == false) + } + } catch { + Issue.record("Expected MCPError, got \(error)") + throw error + } + + await MockURLProtocol.verifyCallCounts(scenario.expectedCallCounts) + await transport.disconnect() + } + + @Test("OAuth rejects non-HTTPS token endpoint from AS metadata", .httpClientTransportSetup) + func testOAuthRejectsNonHTTPSTokenEndpoint() async throws { + let scenario = await MockURLProtocol.requestHandlerStorage.configureOAuthRejectsNonHTTPSTokenEndpoint() + + let transport = HTTPClientTransport( + endpoint: scenario.testEndpoint, + configuration: MockResponses.ephemeralConfiguration(), + streaming: false, + authorizer: OAuthAuthorizer(configuration: scenario.oauthConfiguration), + logger: nil + ) + + try await transport.connect() + + do { + try await transport.send(scenario.messageData) + Issue.record("Expected send to throw an error") + } catch let error as MCPError { + guard case .internalError(let detail) = error else { + Issue.record("Expected MCPError.internalError, got \(error)") + throw error + } + #expect(detail?.contains(scenario.expectedErrorSubstring!) == true) + for unexpected in scenario.unexpectedErrorSubstrings { + #expect(detail?.contains(unexpected) == false) + } + } catch { + Issue.record("Expected MCPError, got \(error)") + throw error + } + + await MockURLProtocol.verifyCallCounts(scenario.expectedCallCounts) + await transport.disconnect() + } + + @Test( + "OAuth allows loopback http authorization server endpoints when explicitly enabled", + .httpClientTransportSetup + ) + func testOAuthAllowsLoopbackHTTPAuthorizationServerEndpointsWhenEnabled() async throws { + let scenario = await MockURLProtocol.requestHandlerStorage.configureOAuthAllowsLoopbackHTTPAuthorizationServerEndpoints() + + let transport = HTTPClientTransport( + endpoint: scenario.testEndpoint, + configuration: MockResponses.ephemeralConfiguration(), + streaming: false, + authorizer: OAuthAuthorizer(configuration: scenario.oauthConfiguration), + logger: nil + ) + + try await transport.connect() + try await transport.send(scenario.messageData) + await MockURLProtocol.verifyCallCounts(scenario.expectedCallCounts) + await transport.disconnect() + } + + @Test("OAuth accessTokenProvider receives SDK discovery context", .httpClientTransportSetup) + func testOAuthAccessTokenProviderReceivesDiscoveryContext() async throws { + let (scenario, providerTracker) = await MockURLProtocol.requestHandlerStorage + .configureOAuthAccessTokenProviderReceivesDiscoveryContext() + + let transport = HTTPClientTransport( + endpoint: scenario.testEndpoint, + configuration: MockResponses.ephemeralConfiguration(), + streaming: false, + authorizer: OAuthAuthorizer(configuration: scenario.oauthConfiguration), + logger: nil + ) + + try await transport.connect() + try await transport.send(scenario.messageData) + + let stream = await transport.receive() + var iterator = stream.makeAsyncIterator() + let received = try await iterator.next() + #expect(received == scenario.expectedResponseData) + + let capturedContext = await providerTracker.capturedContext + #expect(capturedContext != nil) + #expect(capturedContext?.statusCode == 401) + #expect(capturedContext?.endpoint == scenario.testEndpoint) + #expect(capturedContext?.resource == scenario.testEndpoint) + #expect(capturedContext?.authorizationServer == URL(string: "https://localhost:8080/auth")) + #expect(capturedContext?.tokenEndpoint == URL(string: "https://localhost:8080/oauth/token")) + #expect(capturedContext?.challengedScope == "files:read files:write") + #expect(Set(capturedContext?.scopesSupported ?? []) == Set(["files:read", "files:write"])) + #expect(capturedContext?.requestedScopes == Set(["files:read", "files:write"])) + + await MockURLProtocol.verifyCallCounts(scenario.expectedCallCounts) + await transport.disconnect() + } + + @Test("OAuth discovery uses resource_metadata URL from WWW-Authenticate when present", .httpClientTransportSetup) + func testOAuthDiscoveryUsesHeaderResourceMetadataWhenPresent() async throws { + let scenario = await MockURLProtocol.requestHandlerStorage + .configureOAuthDiscoveryUsesHeaderResourceMetadata() + + let transport = HTTPClientTransport( + endpoint: scenario.testEndpoint, + configuration: MockResponses.ephemeralConfiguration(), + streaming: false, + authorizer: OAuthAuthorizer(configuration: scenario.oauthConfiguration), + logger: nil + ) + + try await transport.connect() + try await transport.send(scenario.messageData) + + let stream = await transport.receive() + var iterator = stream.makeAsyncIterator() + let received = try await iterator.next() + #expect(received == scenario.expectedResponseData) + + await MockURLProtocol.verifyCallCounts(scenario.expectedCallCounts) + await transport.disconnect() + } + + @Test("OAuth discovery falls back to well-known metadata URLs in required order", .httpClientTransportSetup) + func testOAuthDiscoveryFallbackWellKnownOrder() async throws { + let (scenario, tracker) = await MockURLProtocol.requestHandlerStorage + .configureOAuthDiscoveryFallbackWellKnownOrder() + + let transport = HTTPClientTransport( + endpoint: scenario.testEndpoint, + configuration: MockResponses.ephemeralConfiguration(), + streaming: false, + authorizer: OAuthAuthorizer(configuration: scenario.oauthConfiguration), + logger: nil + ) + + try await transport.connect() + try await transport.send(scenario.messageData) + + let stream = await transport.receive() + var iterator = stream.makeAsyncIterator() + let received = try await iterator.next() + #expect(received == scenario.expectedResponseData) + + let metadataRequests = await tracker.requests + let fallbackPathMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-protected-resource/public/mcp")! + let fallbackRootMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-protected-resource")! + #expect(metadataRequests == [fallbackPathMetadataURL, fallbackRootMetadataURL]) + + await transport.disconnect() + } + + @Test("OAuth discovery fails when protected resource metadata is unavailable", .httpClientTransportSetup) + func testOAuthDiscoveryFailsWhenMetadataUnavailable() async throws { + let (scenario, tracker) = await MockURLProtocol.requestHandlerStorage + .configureOAuthDiscoveryFailsWhenMetadataUnavailable() + + let transport = HTTPClientTransport( + endpoint: scenario.testEndpoint, + configuration: MockResponses.ephemeralConfiguration(), + streaming: false, + authorizer: OAuthAuthorizer(configuration: scenario.oauthConfiguration), + logger: nil + ) + + try await transport.connect() + + do { + try await transport.send(scenario.messageData) + Issue.record("Expected send to fail when PRM discovery fails") + } catch let error as MCPError { + guard case .internalError(let detail) = error else { + Issue.record("Expected MCPError.internalError, got \(error)") + throw error + } + #expect(detail?.contains(scenario.expectedErrorSubstring!) == true) + } catch { + Issue.record("Expected MCPError, got \(error)") + throw error + } + + let protectedResourceMetadataRequests = await tracker.requests + let fallbackPathMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-protected-resource/public/mcp")! + let fallbackRootMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-protected-resource")! + #expect(protectedResourceMetadataRequests == [fallbackPathMetadataURL, fallbackRootMetadataURL]) + + await transport.disconnect() + } + + @Test("OAuth authorization server metadata discovery tries path issuer URLs in RFC order", .httpClientTransportSetup) + func testOAuthAuthorizationServerMetadataDiscoveryOrderForPathIssuer() async throws { + let (scenario, tracker) = await MockURLProtocol.requestHandlerStorage + .configureOAuthASMetadataDiscoveryOrderForPathIssuer() + + let transport = HTTPClientTransport( + endpoint: scenario.testEndpoint, + configuration: MockResponses.ephemeralConfiguration(), + streaming: false, + authorizer: OAuthAuthorizer(configuration: scenario.oauthConfiguration), + logger: nil + ) + + try await transport.connect() + try await transport.send(scenario.messageData) + + let stream = await transport.receive() + var iterator = stream.makeAsyncIterator() + let received = try await iterator.next() + #expect(received == scenario.expectedResponseData) + + let requests = await tracker.requests + let asMetadataOAuthInsertedURL = URL( + string: "https://localhost:8080/.well-known/oauth-authorization-server/tenant1")! + let asMetadataOIDCInsertedURL = URL( + string: "https://localhost:8080/.well-known/openid-configuration/tenant1")! + let asMetadataOIDCAppendedURL = URL( + string: "https://localhost:8080/tenant1/.well-known/openid-configuration")! + #expect(requests == [asMetadataOAuthInsertedURL, asMetadataOIDCInsertedURL, asMetadataOIDCAppendedURL]) + + await transport.disconnect() + } + + @Test("OAuth authorization server metadata discovery tries root issuer URLs in RFC order", .httpClientTransportSetup) + func testOAuthAuthorizationServerMetadataDiscoveryOrderForRootIssuer() async throws { + let (scenario, tracker) = await MockURLProtocol.requestHandlerStorage + .configureOAuthASMetadataDiscoveryOrderForRootIssuer() + + let transport = HTTPClientTransport( + endpoint: scenario.testEndpoint, + configuration: MockResponses.ephemeralConfiguration(), + streaming: false, + authorizer: OAuthAuthorizer(configuration: scenario.oauthConfiguration), + logger: nil + ) + + try await transport.connect() + try await transport.send(scenario.messageData) + + let stream = await transport.receive() + var iterator = stream.makeAsyncIterator() + let received = try await iterator.next() + #expect(received == scenario.expectedResponseData) + + let requests = await tracker.requests + let asMetadataOAuthURL = URL( + string: "https://localhost:8080/.well-known/oauth-authorization-server")! + let asMetadataOIDCURL = URL( + string: "https://localhost:8080/.well-known/openid-configuration")! + #expect(requests == [asMetadataOAuthURL, asMetadataOIDCURL]) + + await transport.disconnect() + } + + @Test("OAuth registration prefers CIMD when AS advertises support", .httpClientTransportSetup) + func testOAuthRegistrationPrefersCIMDWhenAdvertised() async throws { + let scenario = await MockURLProtocol.requestHandlerStorage + .configureOAuthRegistrationPrefersCIMDWhenAdvertised() + + let transport = HTTPClientTransport( + endpoint: scenario.testEndpoint, + configuration: MockResponses.ephemeralConfiguration(), + streaming: false, + authorizer: OAuthAuthorizer(configuration: scenario.oauthConfiguration), + logger: nil + ) + + try await transport.connect() + try await transport.send(scenario.messageData) + + let stream = await transport.receive() + var iterator = stream.makeAsyncIterator() + let received = try await iterator.next() + #expect(received == scenario.expectedResponseData) + + await MockURLProtocol.verifyCallCounts(scenario.expectedCallCounts) + await transport.disconnect() + } + + @Test("OAuth pre-registration uses static client credentials without dynamic registration", .httpClientTransportSetup) + func testOAuthPreRegistrationUsesStaticCredentialsWithoutDynamicRegistration() async throws { + let scenario = await MockURLProtocol.requestHandlerStorage + .configureOAuthPreRegistrationUsesStaticCredentials() + + let transport = HTTPClientTransport( + endpoint: scenario.testEndpoint, + configuration: MockResponses.ephemeralConfiguration(), + streaming: false, + authorizer: OAuthAuthorizer(configuration: scenario.oauthConfiguration), + logger: nil + ) + + try await transport.connect() + try await transport.send(scenario.messageData) + + let stream = await transport.receive() + var iterator = stream.makeAsyncIterator() + let received = try await iterator.next() + #expect(received == scenario.expectedResponseData) + + await MockURLProtocol.verifyCallCounts(scenario.expectedCallCounts) + await transport.disconnect() + } + + @Test("OAuth registration falls back to dynamic registration when CIMD is not advertised", .httpClientTransportSetup) + func testOAuthRegistrationFallsBackToDynamicRegistrationWhenCIMDNotAdvertised() async throws { + let scenario = await MockURLProtocol.requestHandlerStorage + .configureOAuthRegistrationFallsBackToDynamicRegistrationCIMDNotAdvertised() + + let transport = HTTPClientTransport( + endpoint: scenario.testEndpoint, + configuration: MockResponses.ephemeralConfiguration(), + streaming: false, + authorizer: OAuthAuthorizer(configuration: scenario.oauthConfiguration), + logger: nil + ) + + try await transport.connect() + try await transport.send(scenario.messageData) + + let stream = await transport.receive() + var iterator = stream.makeAsyncIterator() + let received = try await iterator.next() + #expect(received == scenario.expectedResponseData) + + await MockURLProtocol.verifyCallCounts(scenario.expectedCallCounts) + await transport.disconnect() + } + + @Test("OAuth registration falls back to dynamic registration when CIMD capability is absent", .httpClientTransportSetup) + func testOAuthRegistrationFallsBackToDynamicRegistrationWhenCIMDCapabilityMissing() async throws { + let scenario = await MockURLProtocol.requestHandlerStorage + .configureOAuthRegistrationFallsBackToDynamicRegistrationCIMDCapabilityMissing() + + let transport = HTTPClientTransport( + endpoint: scenario.testEndpoint, + configuration: MockResponses.ephemeralConfiguration(), + streaming: false, + authorizer: OAuthAuthorizer(configuration: scenario.oauthConfiguration), + logger: nil + ) + + try await transport.connect() + try await transport.send(scenario.messageData) + + let stream = await transport.receive() + var iterator = stream.makeAsyncIterator() + let received = try await iterator.next() + #expect(received == scenario.expectedResponseData) + + await MockURLProtocol.verifyCallCounts(scenario.expectedCallCounts) + await transport.disconnect() + } + + @Test("OAuth registration surfaces actionable error when no supported mechanism is available", .httpClientTransportSetup) + func testOAuthRegistrationMissingMechanismReturnsActionableError() async throws { + let scenario = await MockURLProtocol.requestHandlerStorage + .configureOAuthRegistrationMissingMechanismReturnsActionableError() + + let transport = HTTPClientTransport( + endpoint: scenario.testEndpoint, + configuration: MockResponses.ephemeralConfiguration(), + streaming: false, + authorizer: OAuthAuthorizer(configuration: scenario.oauthConfiguration), + logger: nil + ) + + try await transport.connect() + + do { + try await transport.send(scenario.messageData) + Issue.record("Expected send to throw an error") + } catch let error as MCPError { + guard case .internalError(let detail) = error else { + Issue.record("Expected MCPError.internalError, got \(error)") + throw error + } + #expect(detail?.contains(scenario.expectedErrorSubstring!) == true) + for unexpected in scenario.unexpectedErrorSubstrings { + #expect(detail?.contains(unexpected) == false) + } + } catch { + Issue.record("Expected MCPError, got \(error)") + throw error + } + + await MockURLProtocol.verifyCallCounts(scenario.expectedCallCounts) + await transport.disconnect() + } + + @Test("OAuth CIMD rejects non-HTTPS client_id URL when AS advertises support", .httpClientTransportSetup) + func testOAuthCIMDRejectsNonHTTPSClientIDURL() async throws { + let scenario = await MockURLProtocol.requestHandlerStorage + .configureOAuthCIMDRejectsNonHTTPSClientIDURL() + + let transport = HTTPClientTransport( + endpoint: scenario.testEndpoint, + configuration: MockResponses.ephemeralConfiguration(), + streaming: false, + authorizer: OAuthAuthorizer(configuration: scenario.oauthConfiguration), + logger: nil + ) + + try await transport.connect() + + do { + try await transport.send(scenario.messageData) + Issue.record("Expected send to throw an error") + } catch let error as MCPError { + guard case .internalError(let detail) = error else { + Issue.record("Expected MCPError.internalError, got \(error)") + throw error + } + #expect(detail?.contains(scenario.expectedErrorSubstring!) == true) + for unexpected in scenario.unexpectedErrorSubstrings { + #expect(detail?.contains(unexpected) == false) + } + } catch { + Issue.record("Expected MCPError, got \(error)") + throw error + } + + await MockURLProtocol.verifyCallCounts(scenario.expectedCallCounts) + await transport.disconnect() + } + + @Test("OAuth rejects insecure MCP endpoint URL", .httpClientTransportSetup) + func testOAuthRejectsInsecureMCPEndpointURL() async throws { + let scenario = await MockURLProtocol.requestHandlerStorage + .configureOAuthRejectsInsecureMCPEndpointURL() + + let transport = HTTPClientTransport( + endpoint: scenario.testEndpoint, + configuration: MockResponses.ephemeralConfiguration(), + streaming: false, + authorizer: OAuthAuthorizer(configuration: scenario.oauthConfiguration), + logger: nil + ) + + try await transport.connect() + + do { + try await transport.send(scenario.messageData) + Issue.record("Expected send to throw an error") + } catch let error as MCPError { + guard case .internalError(let detail) = error else { + Issue.record("Expected MCPError.internalError, got \(error)") + throw error + } + #expect(detail?.contains(scenario.expectedErrorSubstring!) == true) + for unexpected in scenario.unexpectedErrorSubstrings { + #expect(detail?.contains(unexpected) == false) + } + } catch { + Issue.record("Expected MCPError, got \(error)") + throw error + } + + await MockURLProtocol.verifyCallCounts(scenario.expectedCallCounts) + await transport.disconnect() + } + + @Test("OAuth rejects non-loopback http redirect URI", .httpClientTransportSetup) + func testOAuthRejectsNonLoopbackHTTPRedirectURI() async throws { + let scenario = await MockURLProtocol.requestHandlerStorage + .configureOAuthRejectsNonLoopbackHTTPRedirectURI() + + let transport = HTTPClientTransport( + endpoint: scenario.testEndpoint, + configuration: MockResponses.ephemeralConfiguration(), + streaming: false, + authorizer: OAuthAuthorizer(configuration: scenario.oauthConfiguration), + logger: nil + ) + + try await transport.connect() + + do { + try await transport.send(scenario.messageData) + Issue.record("Expected send to throw an error") + } catch let error as MCPError { + guard case .internalError(let detail) = error else { + Issue.record("Expected MCPError.internalError, got \(error)") + throw error + } + #expect(detail?.contains(scenario.expectedErrorSubstring!) == true) + for unexpected in scenario.unexpectedErrorSubstrings { + #expect(detail?.contains(unexpected) == false) + } + } catch { + Issue.record("Expected MCPError, got \(error)") + throw error + } + + await MockURLProtocol.verifyCallCounts(scenario.expectedCallCounts) + await transport.disconnect() + } + + @Test("OAuth PRM cache is invalidated when resource_metadata URL changes between challenges", .httpClientTransportSetup) + func testOAuthPRMCacheInvalidatedOnResourceMetadataURLChange() async throws { + let scenario = await MockURLProtocol.requestHandlerStorage + .configureOAuthPRMCacheInvalidatedOnResourceMetadataURLChange() + + let transport = HTTPClientTransport( + endpoint: scenario.testEndpoint, + configuration: MockResponses.ephemeralConfiguration(), + streaming: false, + authorizer: OAuthAuthorizer(configuration: scenario.oauthConfiguration), + logger: nil + ) + + try await transport.connect() + + let stream = await transport.receive() + var iterator = stream.makeAsyncIterator() + + try await transport.send(scenario.messageData) + let receivedFirst = try await iterator.next() + #expect(receivedFirst == scenario.expectedResponseData) + + try await transport.send(scenario.secondMessageData!) + let receivedSecond = try await iterator.next() + #expect(receivedSecond == scenario.secondExpectedResponseData) + + await MockURLProtocol.verifyCallCounts(scenario.expectedCallCounts) + await transport.disconnect() + } + + @Test( + "OAuth token request uses PRM resource field as resource indicator", + .httpClientTransportSetup) + func testOAuthResourceUsesPRMResourceField() async throws { + let scenario = await MockURLProtocol.requestHandlerStorage + .configureOAuthResourceUsesPRMResourceField() + + let transport = HTTPClientTransport( + endpoint: scenario.testEndpoint, + configuration: MockResponses.ephemeralConfiguration(), + streaming: scenario.streaming, + authorizer: OAuthAuthorizer(configuration: scenario.oauthConfiguration), + logger: nil + ) + + try await transport.connect() + try await transport.send(scenario.messageData) + + let stream = await transport.receive() + var iterator = stream.makeAsyncIterator() + let received = try await iterator.next() + #expect(received == scenario.expectedResponseData) + + await MockURLProtocol.verifyCallCounts(scenario.expectedCallCounts) + await transport.disconnect() + } + + @Test( + "OAuth tries second authorization server when first returns no metadata", + .httpClientTransportSetup) + func testOAuthSecondAuthorizationServerTriedWhenFirstFails() async throws { + let scenario = await MockURLProtocol.requestHandlerStorage + .configureOAuthSecondAuthorizationServerTriedWhenFirstFails() + + let transport = HTTPClientTransport( + endpoint: scenario.testEndpoint, + configuration: MockResponses.ephemeralConfiguration(), + streaming: scenario.streaming, + authorizer: OAuthAuthorizer(configuration: scenario.oauthConfiguration), + logger: nil + ) + + try await transport.connect() + try await transport.send(scenario.messageData) + + let stream = await transport.receive() + var iterator = stream.makeAsyncIterator() + let received = try await iterator.next() + #expect(received == scenario.expectedResponseData) + + await MockURLProtocol.verifyCallCounts(scenario.expectedCallCounts) + await transport.disconnect() + } + + @Test( + "OAuth re-registers client after client_secret_expires_at has passed", + .httpClientTransportSetup) + func testOAuthReRegistersAfterClientSecretExpiry() async throws { + let scenario = await MockURLProtocol.requestHandlerStorage + .configureOAuthReRegistersAfterClientSecretExpiry() + + let transport = HTTPClientTransport( + endpoint: scenario.testEndpoint, + configuration: MockResponses.ephemeralConfiguration(), + streaming: false, + authorizer: OAuthAuthorizer(configuration: scenario.oauthConfiguration), + logger: nil + ) + + try await transport.connect() + + let stream = await transport.receive() + var iterator = stream.makeAsyncIterator() + + try await transport.send(scenario.messageData) + let receivedFirst = try await iterator.next() + #expect(receivedFirst == scenario.expectedResponseData) + + try await transport.send(scenario.secondMessageData!) + let receivedSecond = try await iterator.next() + #expect(receivedSecond == scenario.secondExpectedResponseData) + + await MockURLProtocol.verifyCallCounts(scenario.expectedCallCounts) + await transport.disconnect() + } + + @Test( + "OAuth skips AS metadata with wrong issuer and uses next URL variant", + .httpClientTransportSetup) + func testOAuthIssuerMismatchTriesNextURLVariant() async throws { + let scenario = await MockURLProtocol.requestHandlerStorage + .configureOAuthIssuerMismatchTriesNextURLVariant() + + let transport = HTTPClientTransport( + endpoint: scenario.testEndpoint, + configuration: MockResponses.ephemeralConfiguration(), + streaming: scenario.streaming, + authorizer: OAuthAuthorizer(configuration: scenario.oauthConfiguration), + logger: nil + ) + + try await transport.connect() + try await transport.send(scenario.messageData) + + let stream = await transport.receive() + var iterator = stream.makeAsyncIterator() + let received = try await iterator.next() + #expect(received == scenario.expectedResponseData) + + await MockURLProtocol.verifyCallCounts(scenario.expectedCallCounts) + await transport.disconnect() + } + + @Test( + "OAuth proactively refreshes token when within proactive refresh window", + .httpClientTransportSetup) + func testOAuthProactiveTokenRefreshWithinWindow() async throws { + let scenario = await MockURLProtocol.requestHandlerStorage + .configureOAuthProactiveTokenRefreshWithinWindow() + + let transport = HTTPClientTransport( + endpoint: scenario.testEndpoint, + configuration: MockResponses.ephemeralConfiguration(), + streaming: false, + authorizer: OAuthAuthorizer(configuration: scenario.oauthConfiguration), + logger: nil + ) + + try await transport.connect() + + let stream = await transport.receive() + var iterator = stream.makeAsyncIterator() + + try await transport.send(scenario.messageData) + let receivedFirst = try await iterator.next() + #expect(receivedFirst == scenario.expectedResponseData) + + try await transport.send(scenario.secondMessageData!) + let receivedSecond = try await iterator.next() + #expect(receivedSecond == scenario.secondExpectedResponseData) + + await MockURLProtocol.verifyCallCounts(scenario.expectedCallCounts) + await transport.disconnect() + } + @Test("Send With Protocol Version Header", .httpClientTransportSetup) func testProtocolVersionHeader() async throws { let configuration = URLSessionConfiguration.ephemeral diff --git a/Tests/MCPTests/HTTPServerTransportTests.swift b/Tests/MCPTests/HTTPServerTransportTests.swift index d92b6973..31275343 100644 --- a/Tests/MCPTests/HTTPServerTransportTests.swift +++ b/Tests/MCPTests/HTTPServerTransportTests.swift @@ -48,7 +48,11 @@ private func makeResponseBody(id: String = "2") -> Data { return try! JSONSerialization.data(withJSONObject: json) } -private func makeStatefulPOSTRequest(body: Data, sessionID: String? = nil) -> HTTPRequest { +private func makeStatefulPOSTRequest( + body: Data, + sessionID: String? = nil, + authorization: String? = nil +) -> HTTPRequest { var headers: [String: String] = [ "Content-Type": "application/json", "Accept": "application/json, text/event-stream", @@ -56,6 +60,9 @@ private func makeStatefulPOSTRequest(body: Data, sessionID: String? = nil) -> HT if let sessionID { headers["Mcp-Session-Id"] = sessionID } + if let authorization { + headers[HTTPHeaderName.authorization] = authorization + } return HTTPRequest(method: "POST", headers: headers, body: body) } @@ -65,7 +72,7 @@ private func makeGETRequest(sessionID: String, lastEventID: String? = nil) -> HT "Mcp-Session-Id": sessionID, ] if let lastEventID { - headers["Last-Event-Id"] = lastEventID + headers["Last-Event-ID"] = lastEventID } return HTTPRequest(method: "GET", headers: headers) } @@ -97,6 +104,28 @@ private func makeStatefulTransport( ) } +private let authResourceMetadataURL = + URL(string: "https://mcp.example.com/.well-known/oauth-protected-resource/mcp")! +private let authResourceIdentifier = URL(string: "https://mcp.example.com/mcp")! + +private func makeAuthenticatedStatefulTransport( + challengeScopes: Set? = nil, + tokenValidator: @escaping BearerTokenValidator.TokenValidator +) -> StatefulHTTPServerTransport { + let validator = BearerTokenValidator( + resourceMetadataURL: authResourceMetadataURL, + resourceIdentifier: authResourceIdentifier, + tokenValidator: tokenValidator, + challengeScopeProvider: { _, _ in challengeScopes } + ) + return StatefulHTTPServerTransport( + validationPipeline: StandardValidationPipeline(validators: [ + validator, + SessionValidator(), + ]) + ) +} + private func makeStatelessTransport() -> StatelessHTTPServerTransport { StatelessHTTPServerTransport( validationPipeline: StandardValidationPipeline(validators: []) @@ -628,6 +657,132 @@ struct StatefulHTTPServerTransportTests { await transport.disconnect() } + // MARK: - OAuth Bearer Validation + + @Test("Bearer auth validator returns 401 with challenge when authorization is missing") + func testBearerAuthValidatorMissingAuthorizationReturns401() async throws { + let transport = makeAuthenticatedStatefulTransport( + challengeScopes: ["files:read"] + ) { _, _, _ in + .valid(BearerTokenInfo()) + } + try await transport.connect() + + let response = await transport.handleRequest( + makeStatefulPOSTRequest(body: makeInitializeBody()) + ) + + #expect(response.statusCode == 401) + let challenge = response.headers[HTTPHeaderName.wwwAuthenticate] + #expect(challenge?.contains("Bearer ") == true) + #expect( + challenge?.contains("resource_metadata=\"\(authResourceMetadataURL.absoluteString)\"") + == true + ) + #expect(challenge?.contains("scope=\"files:read\"") == true) + + await transport.disconnect() + } + + @Test("Bearer auth validator returns 400 when authorization header is malformed") + func testBearerAuthValidatorMalformedAuthorizationReturns400() async throws { + let transport = makeAuthenticatedStatefulTransport { _, _, _ in .valid(BearerTokenInfo()) } + try await transport.connect() + + let response = await transport.handleRequest( + makeStatefulPOSTRequest( + body: makeInitializeBody(), + authorization: "Basic dGVzdA==" + ) + ) + + #expect(response.statusCode == 400) + #expect(response.headers[HTTPHeaderName.wwwAuthenticate] == nil) + + await transport.disconnect() + } + + @Test("Bearer auth validator returns 401 invalid_token for rejected token") + func testBearerAuthValidatorInvalidTokenReturns401() async throws { + let transport = makeAuthenticatedStatefulTransport { _, _, _ in + .invalidToken(errorDescription: "Token expired") + } + try await transport.connect() + + let response = await transport.handleRequest( + makeStatefulPOSTRequest( + body: makeInitializeBody(), + authorization: "Bearer expired-token" + ) + ) + + #expect(response.statusCode == 401) + let challenge = response.headers[HTTPHeaderName.wwwAuthenticate] + #expect(challenge?.contains("error=\"invalid_token\"") == true) + #expect(challenge?.contains("error_description=\"Token expired\"") == true) + #expect( + challenge?.contains("resource_metadata=\"\(authResourceMetadataURL.absoluteString)\"") + == true + ) + + await transport.disconnect() + } + + @Test("Bearer auth validator returns 403 insufficient_scope with scope challenge") + func testBearerAuthValidatorInsufficientScopeReturns403() async throws { + let transport = makeAuthenticatedStatefulTransport { token, _, context in + if context.isInitializationRequest, token == "init-token" { + return .valid(BearerTokenInfo()) + } + if token == "read-only-token" { + return .insufficientScope( + requiredScopes: ["files:read", "files:write"], + errorDescription: "Additional file write permission required" + ) + } + return .invalidToken(errorDescription: "Unknown token") + } + try await transport.connect() + + let initResponse = await transport.handleRequest( + makeStatefulPOSTRequest( + body: makeInitializeBody(), + authorization: "Bearer init-token" + ) + ) + let sessionID = initResponse.headers[HTTPHeaderName.sessionID] + #expect(sessionID != nil) + + guard let sessionID else { + await transport.disconnect() + return + } + + let response = await transport.handleRequest( + makeStatefulPOSTRequest( + body: makeNotificationBody(), + sessionID: sessionID, + authorization: "Bearer read-only-token" + ) + ) + + #expect(response.statusCode == 403) + let challenge = response.headers[HTTPHeaderName.wwwAuthenticate] + #expect(challenge?.contains("error=\"insufficient_scope\"") == true) + #expect(challenge?.contains("scope=\"files:read files:write\"") == true) + #expect( + challenge?.contains("resource_metadata=\"\(authResourceMetadataURL.absoluteString)\"") + == true + ) + #expect( + challenge?.contains( + "error_description=\"Additional file write permission required\"" + ) == true + ) + + await transport.disconnect() + } + // MARK: - Resumability @Test("GET with Last-Event-ID replays stored events") diff --git a/Tests/MCPTests/MockResponses.swift b/Tests/MCPTests/MockResponses.swift new file mode 100644 index 00000000..1c7728c9 --- /dev/null +++ b/Tests/MCPTests/MockResponses.swift @@ -0,0 +1,216 @@ +@preconcurrency import Foundation +import Testing + +@testable import MCP + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + +#if swift(>=6.1) + + enum MockResponses { + typealias Route = @Sendable (URLRequest) async throws -> (HTTPURLResponse, Data) + + static func ephemeralConfiguration() -> URLSessionConfiguration { + let config = URLSessionConfiguration.ephemeral + config.protocolClasses = [MockURLProtocol.self] + return config + } + + static func mockError(_ message: String) -> NSError { + NSError( + domain: "MockURLProtocolError", code: 0, + userInfo: [NSLocalizedDescriptionKey: message] + ) + } + + static func jsonRPCResult(id: Int) -> Data { + #"{"jsonrpc":"2.0","result":{"ok":true},"id":\#(id)}"#.data(using: .utf8)! + } + + // MARK: - Route Builders + + static func jsonSuccess(body: Data) -> Route { + { request in + let response = HTTPURLResponse( + url: request.url!, statusCode: 200, httpVersion: "HTTP/1.1", + headerFields: ["Content-Type": "application/json"])! + return (response, body) + } + } + + static func bearerChallenge( + statusCode: Int = 401, + resourceMetadataURL: URL? = nil, + scope: String? = nil, + error: String? = nil, + errorDescription: String? = nil + ) -> Route { + { request in + var params: [String] = [] + if let url = resourceMetadataURL { + params.append("resource_metadata=\"\(url.absoluteString)\"") + } + if let scope { params.append("scope=\"\(scope)\"") } + if let error { params.append("error=\"\(error)\"") } + if let errorDescription { params.append("error_description=\"\(errorDescription)\"") } + let headerValue = "Bearer \(params.joined(separator: ", "))" + let response = HTTPURLResponse( + url: request.url!, statusCode: statusCode, httpVersion: "HTTP/1.1", + headerFields: ["WWW-Authenticate": headerValue])! + return (response, Data()) + } + } + + static func resourceMetadata( + authorizationServers: [String], + scopesSupported: [String]? = nil, + resource: String? = nil + ) -> Route { + { request in + var dict: [String: Any] = ["authorization_servers": authorizationServers] + if let scopes = scopesSupported { dict["scopes_supported"] = scopes } + if let resource { dict["resource"] = resource } + let data = try JSONSerialization.data(withJSONObject: dict) + let response = HTTPURLResponse( + url: request.url!, statusCode: 200, httpVersion: "HTTP/1.1", + headerFields: ["Content-Type": "application/json"])! + return (response, data) + } + } + + static func asMetadata( + issuer: String, + tokenEndpoint: String, + authorizationEndpoint: String? = nil, + registrationEndpoint: String? = nil, + codeChallengeMethodsSupported: [String]? = nil, + tokenEndpointAuthMethodsSupported: [String]? = nil, + clientIDMetadataDocumentSupported: Bool? = nil + ) -> Route { + { request in + var dict: [String: Any] = ["issuer": issuer, "token_endpoint": tokenEndpoint] + if let v = authorizationEndpoint { dict["authorization_endpoint"] = v } + if let v = registrationEndpoint { dict["registration_endpoint"] = v } + if let v = codeChallengeMethodsSupported { + dict["code_challenge_methods_supported"] = v + } + if let v = tokenEndpointAuthMethodsSupported { + dict["token_endpoint_auth_methods_supported"] = v + } + if let v = clientIDMetadataDocumentSupported { + dict["client_id_metadata_document_supported"] = v + } + let data = try JSONSerialization.data(withJSONObject: dict) + let response = HTTPURLResponse( + url: request.url!, statusCode: 200, httpVersion: "HTTP/1.1", + headerFields: ["Content-Type": "application/json"])! + return (response, data) + } + } + + static func tokenSuccess( + accessToken: String, + expiresIn: Int = 3600, + scope: String? = nil, + refreshToken: String? = nil + ) -> Route { + { request in + var dict: [String: Any] = [ + "access_token": accessToken, "token_type": "Bearer", "expires_in": expiresIn, + ] + if let scope { dict["scope"] = scope } + if let refreshToken { dict["refresh_token"] = refreshToken } + let data = try JSONSerialization.data(withJSONObject: dict) + let response = HTTPURLResponse( + url: request.url!, statusCode: 200, httpVersion: "HTTP/1.1", + headerFields: ["Content-Type": "application/json"])! + return (response, data) + } + } + + static func tokenResponse( + accessToken: String, + tokenType: String, + expiresIn: Int = 3600 + ) -> Route { + { request in + let dict: [String: Any] = [ + "access_token": accessToken, "token_type": tokenType, "expires_in": expiresIn, + ] + let data = try JSONSerialization.data(withJSONObject: dict) + let response = HTTPURLResponse( + url: request.url!, statusCode: 200, httpVersion: "HTTP/1.1", + headerFields: ["Content-Type": "application/json"])! + return (response, data) + } + } + + static func tokenError( + statusCode: Int = 400, + error: String, + errorDescription: String? = nil, + extraFields: [String: String] = [:] + ) -> Route { + { request in + var dict: [String: Any] = ["error": error] + if let errorDescription { dict["error_description"] = errorDescription } + for (key, value) in extraFields { dict[key] = value } + let data = try JSONSerialization.data(withJSONObject: dict) + let response = HTTPURLResponse( + url: request.url!, statusCode: statusCode, httpVersion: "HTTP/1.1", + headerFields: ["Content-Type": "application/json"])! + return (response, data) + } + } + + static func httpError(statusCode: Int) -> Route { + { request in + let response = HTTPURLResponse( + url: request.url!, statusCode: statusCode, httpVersion: "HTTP/1.1", + headerFields: nil)! + return (response, Data()) + } + } + + static func registrationSuccess( + clientID: String, + clientSecret: String? = nil + ) -> Route { + { request in + var dict: [String: Any] = ["client_id": clientID] + if let clientSecret { dict["client_secret"] = clientSecret } + let data = try JSONSerialization.data(withJSONObject: dict) + let response = HTTPURLResponse( + url: request.url!, statusCode: 201, httpVersion: "HTTP/1.1", + headerFields: ["Content-Type": "application/json"])! + return (response, data) + } + } + + static func redirect(to location: String, statusCode: Int = 302) -> Route { + { request in + let response = HTTPURLResponse( + url: request.url!, statusCode: statusCode, httpVersion: "HTTP/1.1", + headerFields: ["Location": location])! + return (response, Data()) + } + } + + // MARK: - Routing Handler + + static func routingHandler( + routes: [URL: Route] + ) -> @Sendable (URLRequest) async throws -> (HTTPURLResponse, Data) { + { request in + guard let url = request.url else { throw mockError("Missing request URL") } + guard let handler = routes[url] else { + throw mockError("Unexpected URL: \(url.absoluteString)") + } + return try await handler(request) + } + } + } + +#endif // swift(>=6.1) diff --git a/Tests/MCPTests/OAuthAuthorizationCodeFlowTests.swift b/Tests/MCPTests/OAuthAuthorizationCodeFlowTests.swift new file mode 100644 index 00000000..85c610e4 --- /dev/null +++ b/Tests/MCPTests/OAuthAuthorizationCodeFlowTests.swift @@ -0,0 +1,178 @@ +import Foundation +import Testing + +@testable import MCP + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + +@Suite("OAuthAuthorizationCodeFlow") +struct OAuthAuthorizationCodeFlowTests { + + let flow = OAuthAuthorizationCodeFlow() + let authEndpoint = URL(string: "https://auth.example.com/authorize")! + let resource = URL(string: "https://api.example.com")! + let redirectURI = URL(string: "https://app.example.com/callback")! + let scopeSelector = DefaultOAuthScopeSelector() + + // MARK: - buildURL + + @Test("buildURL includes all required parameters") + func testBuildURLRequiredParams() throws { + let url = try flow.buildURL( + authorizationEndpoint: authEndpoint, + resource: resource, + redirectURI: redirectURI, + clientID: "my-client", + codeChallenge: "abc123", + scopes: nil, + state: "state-xyz", + scopeSerializer: scopeSelector + ) + let components = URLComponents(url: url, resolvingAgainstBaseURL: false)! + let items = Dictionary( + uniqueKeysWithValues: components.queryItems!.map { ($0.name, $0.value ?? "") }) + + #expect(items["response_type"] == "code") + #expect(items["client_id"] == "my-client") + #expect(items["redirect_uri"] == redirectURI.absoluteString) + #expect(items["state"] == "state-xyz") + #expect(items["resource"] == resource.absoluteString) + #expect(items["code_challenge"] == "abc123") + #expect(items["code_challenge_method"] == "S256") + } + + @Test("buildURL includes scope when provided") + func testBuildURLWithScope() throws { + let url = try flow.buildURL( + authorizationEndpoint: authEndpoint, + resource: resource, + redirectURI: redirectURI, + clientID: "my-client", + codeChallenge: "abc123", + scopes: Set(["read", "write"]), + state: "state-xyz", + scopeSerializer: scopeSelector + ) + let components = URLComponents(url: url, resolvingAgainstBaseURL: false)! + let items = Dictionary( + uniqueKeysWithValues: components.queryItems!.map { ($0.name, $0.value ?? "") }) + + let scope = items["scope"] ?? "" + #expect(scope.contains("read")) + #expect(scope.contains("write")) + } + + @Test("buildURL omits scope for nil scopes") + func testBuildURLOmitsScopeWhenNil() throws { + let url = try flow.buildURL( + authorizationEndpoint: authEndpoint, + resource: resource, + redirectURI: redirectURI, + clientID: "my-client", + codeChallenge: "abc123", + scopes: nil, + state: "state-xyz", + scopeSerializer: scopeSelector + ) + let components = URLComponents(url: url, resolvingAgainstBaseURL: false)! + let items = components.queryItems ?? [] + #expect(!items.contains(where: { $0.name == "scope" })) + } + + @Test("buildURL omits scope for empty scope set") + func testBuildURLOmitsScopeWhenEmpty() throws { + let url = try flow.buildURL( + authorizationEndpoint: authEndpoint, + resource: resource, + redirectURI: redirectURI, + clientID: "my-client", + codeChallenge: "abc123", + scopes: Set(), + state: "state-xyz", + scopeSerializer: scopeSelector + ) + let components = URLComponents(url: url, resolvingAgainstBaseURL: false)! + let items = components.queryItems ?? [] + #expect(!items.contains(where: { $0.name == "scope" })) + } + + // MARK: - extractCode + + @Test("extractCode returns code from valid redirect URL") + func testExtractCodeSuccess() throws { + let redirectURL = URL(string: + "https://app.example.com/callback?code=auth-code-123&state=my-state")! + let code = try flow.extractCode( + from: redirectURL, + expectedRedirectURI: redirectURI, + expectedState: "my-state" + ) + #expect(code == "auth-code-123") + } + + @Test("extractCode throws state mismatch") + func testExtractCodeThrowsStateMismatch() { + let redirectURL = URL(string: + "https://app.example.com/callback?code=auth-code-123&state=wrong-state")! + #expect(throws: OAuthAuthorizationError.self) { + try flow.extractCode( + from: redirectURL, + expectedRedirectURI: redirectURI, + expectedState: "my-state" + ) + } + } + + @Test("extractCode throws missing state") + func testExtractCodeThrowsMissingState() { + let redirectURL = URL(string: + "https://app.example.com/callback?code=auth-code-123")! + #expect(throws: OAuthAuthorizationError.self) { + try flow.extractCode( + from: redirectURL, + expectedRedirectURI: redirectURI, + expectedState: "my-state" + ) + } + } + + @Test("extractCode throws missing code") + func testExtractCodeThrowsMissingCode() { + let redirectURL = URL(string: + "https://app.example.com/callback?state=my-state")! + #expect(throws: OAuthAuthorizationError.self) { + try flow.extractCode( + from: redirectURL, + expectedRedirectURI: redirectURI, + expectedState: "my-state" + ) + } + } + + @Test("extractCode throws redirect URI mismatch") + func testExtractCodeThrowsRedirectMismatch() { + let redirectURL = URL(string: + "https://evil.example.com/callback?code=auth-code-123&state=my-state")! + #expect(throws: OAuthAuthorizationError.self) { + try flow.extractCode( + from: redirectURL, + expectedRedirectURI: redirectURI, + expectedState: "my-state" + ) + } + } + + @Test("extractCode normalizes host case in redirect URI comparison") + func testExtractCodeCaseInsensitiveHost() throws { + let redirectURL = URL(string: + "https://APP.EXAMPLE.COM/callback?code=auth-code-123&state=my-state")! + let code = try flow.extractCode( + from: redirectURL, + expectedRedirectURI: redirectURI, + expectedState: "my-state" + ) + #expect(code == "auth-code-123") + } +} diff --git a/Tests/MCPTests/OAuthAuthorizationTests.swift b/Tests/MCPTests/OAuthAuthorizationTests.swift new file mode 100644 index 00000000..b28026f9 --- /dev/null +++ b/Tests/MCPTests/OAuthAuthorizationTests.swift @@ -0,0 +1,735 @@ +import Foundation +import Testing + +#if canImport(CryptoKit) +import CryptoKit +#endif + +@testable import MCP + +@Suite("OAuth Authorization Helpers") +struct OAuthAuthorizationTests { + #if canImport(CryptoKit) + // Generated fresh each test run — no hardcoded key material in source. + private static let testPrivateKeyPEM: String = P256.Signing.PrivateKey().pemRepresentation + #else + private static let testPrivateKeyPEM: String = "" + #endif + + private var metadataDiscovery: DefaultOAuthMetadataDiscovery { DefaultOAuthMetadataDiscovery() } + + private static func decodeBase64URL(_ input: Substring) -> Data? { + var base64 = String(input) + .replacingOccurrences(of: "-", with: "+") + .replacingOccurrences(of: "_", with: "/") + let remainder = base64.count % 4 + if remainder != 0 { + base64 += String(repeating: "=", count: 4 - remainder) + } + return Data(base64Encoded: base64) + } + + @Test("Parse Bearer challenge with resource metadata and scope") + func testParseBearerChallenge() { + let headers = [ + "WWW-Authenticate": + "Bearer resource_metadata=\"https://mcp.example.com/.well-known/oauth-protected-resource\", scope=\"files:read files:write\", error=\"insufficient_scope\"" + ] + + let challenge = DefaultOAuthWWWAuthenticateParser().parseBearer(from: headers) + + #expect(challenge != nil) + #expect( + challenge?.resourceMetadataURL + == URL(string: "https://mcp.example.com/.well-known/oauth-protected-resource")) + #expect(challenge?.scope == "files:read files:write") + #expect(challenge?.error == "insufficient_scope") + } + + @Test("Parse Bearer challenge with optional error description") + func testParseBearerChallengeErrorDescription() { + let headers = [ + "WWW-Authenticate": + "Bearer error=\"insufficient_scope\", scope=\"files:read files:write\", resource_metadata=\"https://mcp.example.com/.well-known/oauth-protected-resource\", error_description=\"Additional file write permission required\"" + ] + + let challenge = DefaultOAuthWWWAuthenticateParser().parseBearer(from: headers) + + #expect(challenge != nil) + #expect(challenge?.error == "insufficient_scope") + #expect(challenge?.scope == "files:read files:write") + #expect( + challenge?.resourceMetadataURL + == URL(string: "https://mcp.example.com/.well-known/oauth-protected-resource")) + #expect(challenge?.errorDescription == "Additional file write permission required") + } + + @Test("Parse Bearer challenge when another auth scheme appears first") + func testParseBearerChallengeWhenBearerIsNotFirst() { + let headers = [ + "WWW-Authenticate": + "Basic realm=\"legacy\", Bearer resource_metadata=\"https://mcp.example.com/.well-known/oauth-protected-resource\", scope=\"files:read\"" + ] + + let challenge = DefaultOAuthWWWAuthenticateParser().parseBearer(from: headers) + + #expect(challenge != nil) + #expect( + challenge?.resourceMetadataURL + == URL(string: "https://mcp.example.com/.well-known/oauth-protected-resource")) + #expect(challenge?.scope == "files:read") + } + + @Test("Parse Bearer challenge stops before the next auth scheme") + func testParseBearerChallengeStopsBeforeNextScheme() { + let headers = [ + "WWW-Authenticate": + "Bearer resource_metadata=\"https://mcp.example.com/.well-known/oauth-protected-resource\", scope=\"files:read\", DPoP algs=\"ES256\"" + ] + + let challenge = DefaultOAuthWWWAuthenticateParser().parseBearer(from: headers) + + #expect(challenge != nil) + #expect( + challenge?.resourceMetadataURL + == URL(string: "https://mcp.example.com/.well-known/oauth-protected-resource")) + #expect(challenge?.scope == "files:read") + #expect(challenge?.parameters["algs"] == nil) + } + + @Test("Protected resource metadata discovery fallback URLs") + func testProtectedResourceMetadataURLs() { + let endpoint = URL(string: "https://example.com/public/mcp")! + let urls = metadataDiscovery.protectedResourceMetadataURLs(for: endpoint) + + #expect( + urls == [ + URL(string: "https://example.com/.well-known/oauth-protected-resource/public/mcp")!, + URL(string: "https://example.com/.well-known/oauth-protected-resource")!, + ]) + } + + @Test("Authorization server metadata discovery URLs for issuer with path") + func testAuthorizationServerMetadataURLsWithPath() { + let issuer = URL(string: "https://auth.example.com/tenant1")! + let urls = metadataDiscovery.authorizationServerMetadataURLs(for: issuer) + + #expect( + urls == [ + URL(string: "https://auth.example.com/.well-known/oauth-authorization-server/tenant1")!, + URL(string: "https://auth.example.com/.well-known/openid-configuration/tenant1")!, + URL(string: "https://auth.example.com/tenant1/.well-known/openid-configuration")!, + ]) + } + + @Test("Authorization server metadata discovery URLs for issuer without path") + func testAuthorizationServerMetadataURLsWithoutPath() { + let issuer = URL(string: "https://auth.example.com")! + let urls = metadataDiscovery.authorizationServerMetadataURLs(for: issuer) + + #expect( + urls == [ + URL(string: "https://auth.example.com/.well-known/oauth-authorization-server")!, + URL(string: "https://auth.example.com/.well-known/openid-configuration")!, + ]) + } + + @Test("Canonical resource URI normalization") + func testCanonicalResourceURINormalization() throws { + let endpoint = URL(string: "HTTPS://MCP.EXAMPLE.COM/?q=1")! + let canonical = try metadataDiscovery.canonicalResourceURI(from: endpoint) + #expect(canonical.absoluteString == "https://mcp.example.com") + } + + @Test("Canonical resource URI supports explicit port and root slash normalization") + func testCanonicalResourceURIWithExplicitPort() throws { + let endpoint = URL(string: "HTTPS://MCP.EXAMPLE.COM:8443/")! + let canonical = try metadataDiscovery.canonicalResourceURI(from: endpoint) + #expect(canonical.absoluteString == "https://mcp.example.com:8443") + } + + @Test("Canonical resource URI preserves specific server path") + func testCanonicalResourceURIPreservesPath() throws { + let endpoint = URL(string: "https://mcp.example.com/server/mcp")! + let canonical = try metadataDiscovery.canonicalResourceURI(from: endpoint) + #expect(canonical.absoluteString == "https://mcp.example.com/server/mcp") + } + + @Test("Canonical resource URI rejects missing scheme") + func testCanonicalResourceURIRejectsMissingScheme() { + #expect(throws: OAuthAuthorizationError.self) { + _ = try metadataDiscovery.canonicalResourceURI(from: URL(string: "mcp.example.com")!) + } + } + + @Test("Canonical resource URI rejects insecure non-loopback http scheme") + func testCanonicalResourceURIRejectsNonLoopbackHTTP() { + #expect(throws: OAuthAuthorizationError.self) { + _ = try metadataDiscovery.canonicalResourceURI( + from: URL(string: "http://mcp.example.com/resource")! + ) + } + } + + @Test("Canonical resource URI allows loopback http scheme") + func testCanonicalResourceURIAllowsLoopbackHTTP() throws { + let canonical = try metadataDiscovery.canonicalResourceURI( + from: URL(string: "http://localhost:8080/mcp")! + ) + #expect(canonical.absoluteString == "http://localhost:8080/mcp") + } + + @Test("Canonical resource URI rejects fragment") + func testCanonicalResourceURIRejectsFragment() { + #expect(throws: OAuthAuthorizationError.self) { + _ = try metadataDiscovery.canonicalResourceURI( + from: URL(string: "https://mcp.example.com#fragment")! + ) + } + } + + @Test("Protected resource matching allows same-origin parent resource") + func testProtectedResourceMatchingParentResource() { + let resource = URL(string: "https://mcp.example.com")! + let endpoint = URL(string: "https://mcp.example.com/mcp")! + #expect(metadataDiscovery.protectedResourceMatches(resource: resource, endpoint: endpoint)) + } + + @Test("Protected resource matching enforces path boundaries") + func testProtectedResourceMatchingPathBoundary() { + let resource = URL(string: "https://mcp.example.com/mcp")! + let validEndpoint = URL(string: "https://mcp.example.com/mcp/tools")! + let invalidEndpoint = URL(string: "https://mcp.example.com/mcp2")! + + #expect(metadataDiscovery.protectedResourceMatches(resource: resource, endpoint: validEndpoint)) + #expect(!metadataDiscovery.protectedResourceMatches(resource: resource, endpoint: invalidEndpoint)) + } + + @Test("Protected resource matching rejects origin mismatches") + func testProtectedResourceMatchingOriginMismatch() { + let resource = URL(string: "https://evil.example.com/mcp")! + let endpoint = URL(string: "https://mcp.example.com/mcp")! + #expect(!metadataDiscovery.protectedResourceMatches(resource: resource, endpoint: endpoint)) + } + + @Test("Scope selection prefers challenge scope") + func testScopeSelection() { + let selected = DefaultOAuthScopeSelector().selectScopes( + challengeScope: "files:read", + scopesSupported: ["files:read", "files:write"] + ) + #expect(selected == Set(["files:read"])) + + let fallback = DefaultOAuthScopeSelector().selectScopes( + challengeScope: nil, + scopesSupported: ["files:read", "files:write"] + ) + #expect(fallback == Set(["files:read", "files:write"])) + + let omitted = DefaultOAuthScopeSelector().selectScopes(challengeScope: nil, scopesSupported: nil) + #expect(omitted == nil) + } + + #if canImport(CryptoKit) + @Test("private_key_jwt helper builds signed JWT with expected claims") + func testPrivateKeyJWTAssertionHelper() throws { + let tokenEndpoint = URL(string: "https://auth.example.com/oauth/token")! + let assertion = try OAuthConfiguration.makePrivateKeyJWTAssertion( + clientID: "test-client", + tokenEndpoint: tokenEndpoint, + privateKeyPEM: Self.testPrivateKeyPEM, + audience: "https://auth.example.com", + issuedAt: Date(timeIntervalSince1970: 1_700_000_000), + expiresIn: 300 + ) + + let parts = assertion.split(separator: ".") + #expect(parts.count == 3) + + let headerData = Self.decodeBase64URL(parts[0]) + let payloadData = Self.decodeBase64URL(parts[1]) + #expect(headerData != nil) + #expect(payloadData != nil) + + let header = try JSONSerialization.jsonObject(with: headerData!) as? [String: Any] + let payload = try JSONSerialization.jsonObject(with: payloadData!) as? [String: Any] + + #expect(header?["alg"] as? String == "ES256") + #expect(header?["typ"] as? String == "JWT") + #expect(payload?["iss"] as? String == "test-client") + #expect(payload?["sub"] as? String == "test-client") + #expect(payload?["aud"] as? String == "https://auth.example.com") + #expect(payload?["iat"] as? Int == 1_700_000_000) + #expect(payload?["exp"] as? Int == 1_700_000_300) + #expect((payload?["jti"] as? String)?.isEmpty == false) + } + #endif + + @Test("private_key_jwt helper rejects non-positive lifetime") + func testPrivateKeyJWTAssertionRejectsNonPositiveLifetime() { + #expect(throws: OAuthConfiguration.PrivateKeyJWTAssertionError.self) { + _ = try OAuthConfiguration.makePrivateKeyJWTAssertion( + clientID: "test-client", + tokenEndpoint: URL(string: "https://auth.example.com/token")!, + privateKeyPEM: Self.testPrivateKeyPEM, + expiresIn: 0 + ) + } + } + + // MARK: - WWW-Authenticate Parser Edge Cases + + @Test("Parse Bearer returns nil when WWW-Authenticate header is absent") + func testParseBearerReturnsNilWhenHeaderAbsent() { + let challenge = DefaultOAuthWWWAuthenticateParser().parseBearer(from: ["Content-Type": "text/plain"]) + #expect(challenge == nil) + } + + @Test("Parse Bearer returns nil for empty header value") + func testParseBearerReturnsNilForEmptyValue() { + let challenge = DefaultOAuthWWWAuthenticateParser().parseBearer(from: ["WWW-Authenticate": ""]) + #expect(challenge == nil) + } + + @Test("Parse Bearer returns empty challenge for bare Bearer with no parameters") + func testParseBearerBareBearerScheme() { + let challenge = DefaultOAuthWWWAuthenticateParser().parseBearer(from: ["WWW-Authenticate": "Bearer"]) + #expect(challenge != nil) + #expect(challenge?.parameters.isEmpty == true) + } + + @Test("Parse Bearer header name lookup is case-insensitive") + func testParseBearerCaseInsensitiveHeaderName() { + let challenge = DefaultOAuthWWWAuthenticateParser().parseBearer( + from: ["www-authenticate": "Bearer scope=\"read\""] + ) + #expect(challenge != nil) + #expect(challenge?.scope == "read") + } + + @Test("Parse Bearer returns nil when only non-Bearer schemes present") + func testParseBearerReturnsNilForNonBearerScheme() { + let challenge = DefaultOAuthWWWAuthenticateParser().parseBearer( + from: ["WWW-Authenticate": "Basic realm=\"example\""] + ) + #expect(challenge == nil) + } + + // MARK: - Metadata Discovery Edge Cases + + @Test("Protected resource metadata returns empty for non-HTTPS endpoint") + func testProtectedResourceMetadataRejectsHTTP() { + let urls = metadataDiscovery.protectedResourceMetadataURLs( + for: URL(string: "http://remote.example.com/mcp")! + ) + #expect(urls.isEmpty) + } + + @Test("Protected resource metadata for root endpoint produces path-specific and root URLs") + func testProtectedResourceMetadataRootEndpoint() { + let urls = metadataDiscovery.protectedResourceMetadataURLs( + for: URL(string: "https://example.com")! + ) + #expect(urls.count == 2) + #expect(urls.allSatisfy { + $0 == URL(string: "https://example.com/.well-known/oauth-protected-resource")! + }) + } + + @Test("Protected resource metadata allows loopback HTTP") + func testProtectedResourceMetadataAllowsLoopbackHTTP() { + let urls = metadataDiscovery.protectedResourceMetadataURLs( + for: URL(string: "http://localhost:8080/mcp")! + ) + #expect(!urls.isEmpty) + } + + @Test("Authorization server metadata returns empty for non-HTTPS issuer") + func testAuthorizationServerMetadataRejectsHTTP() { + let urls = metadataDiscovery.authorizationServerMetadataURLs( + for: URL(string: "http://remote.example.com")! + ) + #expect(urls.isEmpty) + } + + @Test("Authorization server fallback issuer derives origin from endpoint") + func testAuthorizationServerFallbackIssuer() throws { + let issuer = try metadataDiscovery.authorizationServerFallbackIssuer( + from: URL(string: "https://mcp.example.com:8443/server/mcp")! + ) + #expect(issuer.absoluteString == "https://mcp.example.com:8443") + } + + @Test("Authorization server fallback issuer normalizes case and strips query") + func testAuthorizationServerFallbackIssuerNormalization() throws { + let issuer = try metadataDiscovery.authorizationServerFallbackIssuer( + from: URL(string: "HTTPS://AUTH.EXAMPLE.COM/path?q=1")! + ) + #expect(issuer.absoluteString == "https://auth.example.com") + } + + @Test("Authorization server fallback issuer rejects non-HTTPS non-loopback") + func testAuthorizationServerFallbackIssuerRejectsInsecure() { + #expect(throws: OAuthAuthorizationError.self) { + _ = try metadataDiscovery.authorizationServerFallbackIssuer( + from: URL(string: "http://remote.example.com/mcp")! + ) + } + } + + @Test("Protected resource matching with port mismatch") + func testProtectedResourceMatchingPortMismatch() { + let resource = URL(string: "https://mcp.example.com:8443")! + let endpoint = URL(string: "https://mcp.example.com:9443/mcp")! + #expect(!metadataDiscovery.protectedResourceMatches(resource: resource, endpoint: endpoint)) + } + + @Test("Protected resource matching with exact path") + func testProtectedResourceMatchingExactPath() { + let resource = URL(string: "https://mcp.example.com/mcp")! + let endpoint = URL(string: "https://mcp.example.com/mcp")! + #expect(metadataDiscovery.protectedResourceMatches(resource: resource, endpoint: endpoint)) + } + + // MARK: - Scope Selector Edge Cases + + @Test("parseScopeString splits on whitespace and ignores empty tokens") + func testParseScopeString() { + let scopes = DefaultOAuthScopeSelector().parseScopeString("files:read files:write\tprofile") + #expect(scopes == Set(["files:read", "files:write", "profile"])) + } + + @Test("parseScopeString returns empty set for blank string") + func testParseScopeStringEmpty() { + let scopes = DefaultOAuthScopeSelector().parseScopeString(" ") + #expect(scopes.isEmpty) + } + + @Test("serialize produces sorted space-separated string") + func testSerializeScopes() { + let result = DefaultOAuthScopeSelector().serialize(Set(["write", "read", "admin"])) + #expect(result == "admin read write") + } + + @Test("serialize returns nil for empty set") + func testSerializeScopesEmpty() { + let result = DefaultOAuthScopeSelector().serialize(Set()) + #expect(result == nil) + } + + @Test("selectScopes returns nil for empty challenge scope string") + func testSelectScopesEmptyChallengeScope() { + let result = DefaultOAuthScopeSelector().selectScopes( + challengeScope: " ", + scopesSupported: ["files:read"] + ) + #expect(result == nil) + } + + @Test("selectScopes returns nil for empty scopesSupported array") + func testSelectScopesEmptyScopesSupported() { + let result = DefaultOAuthScopeSelector().selectScopes( + challengeScope: nil, + scopesSupported: [] + ) + #expect(result == nil) + } + + // MARK: - Token Type Validation + + @Test("makePrivateKeyJWTAssertion — token type empty string is rejected") + func testEmptyTokenTypeIsRejected() throws { + // OAuthTokenResponse decodes token_type from JSON; simulate an empty value via + // direct struct construction using the internal initializer path. + let json = #"{"access_token":"tok","token_type":"","expires_in":3600}"# + let data = json.data(using: .utf8)! + let decoded = try JSONDecoder().decode(OAuthTokenResponse.self, from: data) + // The token_type is empty; the authorizer guard must reject this. + let tokenType = decoded.tokenType.trimmingCharacters(in: .whitespacesAndNewlines) + let isValid = !tokenType.isEmpty + && tokenType.caseInsensitiveCompare("Bearer") == .orderedSame + #expect(!isValid, "Empty token_type must not be accepted as Bearer") + } + + @Test("Token type whitespace-only is rejected") + func testWhitespaceTokenTypeIsRejected() throws { + let json = #"{"access_token":"tok","token_type":" ","expires_in":3600}"# + let data = json.data(using: .utf8)! + let decoded = try JSONDecoder().decode(OAuthTokenResponse.self, from: data) + let tokenType = decoded.tokenType.trimmingCharacters(in: .whitespacesAndNewlines) + let isValid = !tokenType.isEmpty + && tokenType.caseInsensitiveCompare("Bearer") == .orderedSame + #expect(!isValid, "Whitespace-only token_type must not be accepted as Bearer") + } + + // MARK: - InMemoryTokenStorage + + @Test("InMemoryTokenStorage save and load round-trip") + func testTokenStorageSaveAndLoad() { + let storage = InMemoryTokenStorage() + #expect(storage.load() == nil) + + let token = OAuthAccessToken( + value: "access-123", + tokenType: "Bearer", + expiresAt: nil, + scopes: ["read"], + authorizationServer: nil, + refreshToken: nil + ) + storage.save(token) + + let loaded = storage.load() + #expect(loaded?.value == "access-123") + #expect(loaded?.tokenType == "Bearer") + #expect(loaded?.scopes == Set(["read"])) + } + + @Test("InMemoryTokenStorage clear removes stored token") + func testTokenStorageClear() { + let storage = InMemoryTokenStorage() + let token = OAuthAccessToken( + value: "access-456", + tokenType: "Bearer", + expiresAt: nil, + scopes: [], + authorizationServer: nil, + refreshToken: nil + ) + storage.save(token) + #expect(storage.load() != nil) + + storage.clear() + #expect(storage.load() == nil) + } + + @Test("InMemoryTokenStorage save overwrites previous token") + func testTokenStorageSaveOverwrites() { + let storage = InMemoryTokenStorage() + let first = OAuthAccessToken( + value: "first", + tokenType: "Bearer", + expiresAt: nil, + scopes: [], + authorizationServer: nil, + refreshToken: nil + ) + let second = OAuthAccessToken( + value: "second", + tokenType: "Bearer", + expiresAt: nil, + scopes: [], + authorizationServer: nil, + refreshToken: nil + ) + storage.save(first) + storage.save(second) + + #expect(storage.load()?.value == "second") + } + + // MARK: - Proactive Token Refresh + + @Test("prepareAuthorization does nothing when proactiveRefreshWindowSeconds is zero") + func testPrepareAuthorizationSkipsWhenWindowIsZero() async throws { + let storage = InMemoryTokenStorage() + // Token expires in 50 s — within a large proactive window but not within the 30 s default skew. + storage.save(OAuthAccessToken( + value: "original-token", + tokenType: "Bearer", + expiresAt: Date().addingTimeInterval(50), + scopes: [], + authorizationServer: nil, + refreshToken: "some-refresh-token" + )) + let config = OAuthConfiguration( + authentication: .none(clientID: "client"), + proactiveRefreshWindowSeconds: 0 + ) + let authorizer = OAuthAuthorizer(configuration: config, tokenStorage: storage) + + try await authorizer.prepareAuthorization( + for: URL(string: "https://example.com/mcp")!, + session: .shared + ) + + // Token must be unchanged — proactive refresh is disabled. + #expect(storage.load()?.value == "original-token") + } + + @Test("prepareAuthorization does nothing without cached authorization server metadata") + func testPrepareAuthorizationSkipsWithoutCachedASMetadata() async throws { + let storage = InMemoryTokenStorage() + // Token expires in 50 s — within the 400 s proactive window. + storage.save(OAuthAccessToken( + value: "original-token", + tokenType: "Bearer", + expiresAt: Date().addingTimeInterval(50), + scopes: [], + authorizationServer: nil, + refreshToken: "some-refresh-token" + )) + let config = OAuthConfiguration( + authentication: .none(clientID: "client"), + proactiveRefreshWindowSeconds: 400 + ) + // No handleChallenge call → authorizationServerMetadata remains nil. + let authorizer = OAuthAuthorizer(configuration: config, tokenStorage: storage) + + try await authorizer.prepareAuthorization( + for: URL(string: "https://example.com/mcp")!, + session: .shared + ) + + // Token must be unchanged — refresh requires cached AS metadata. + #expect(storage.load()?.value == "original-token") + } + + // MARK: - WWW-Authenticate Bare Scheme Detection + + @Test("Bare scheme Digest terminates preceding Bearer parameter collection") + func testBareSchemeTerminatesBearer() { + // "Digest" has no params and no '=', so it must start a new challenge and stop + // the Bearer parameter collector before "scope" from the second Bearer leaks in. + let header = #"Bearer scope="a", Digest, Bearer scope="b""# + let challenge = DefaultOAuthWWWAuthenticateParser().parseBearer( + from: ["WWW-Authenticate": header]) + #expect(challenge != nil) + // The outer loop stops at the first Bearer, so scope must be "a", not "b". + #expect(challenge?.scope == "a") + } + + @Test("Non-Bearer leading scheme followed by Bearer is parsed correctly") + func testNonBearerLeadingSchemeFollowedByBearer() { + let header = #"SomeScheme, Bearer scope="test""# + let challenge = DefaultOAuthWWWAuthenticateParser().parseBearer( + from: ["WWW-Authenticate": header]) + #expect(challenge != nil) + #expect(challenge?.scope == "test") + } + + // MARK: - Private IP Blocking + + @Test("privateIPAddressBlocked error has informative description") + func testPrivateIPAddressBlockedErrorDescription() { + let url = "https://169.254.169.254/.well-known/oauth-protected-resource" + let error = OAuthAuthorizationError.privateIPAddressBlocked( + context: "Protected resource metadata URL", url: url) + let description = error.errorDescription ?? "" + #expect(description.contains("private or reserved IP")) + #expect(description.contains(url)) + } + + @Test("cimdNotSupported is thrown when CIMD URL provided but server does not support it") + func testCIMDNotSupportedError() { + // Verify the error case exists and its description is meaningful. + let error = OAuthAuthorizationError.cimdNotSupported( + clientID: "https://client.example.com/client-metadata.json") + let description = error.errorDescription ?? "" + #expect(description.contains("Client ID Metadata Document")) + #expect(description.contains("client.example.com")) + } + + // MARK: - BearerTokenValidator Audience and Expiry Tests + + private let testResourceMetadataURL = + URL(string: "https://api.example.com/.well-known/oauth-protected-resource")! + private let testResourceIdentifier = URL(string: "https://api.example.com")! + + private func makeBearerValidator( + tokenValidator: @escaping BearerTokenValidator.TokenValidator + ) -> BearerTokenValidator { + BearerTokenValidator( + resourceMetadataURL: testResourceMetadataURL, + resourceIdentifier: testResourceIdentifier, + tokenValidator: tokenValidator + ) + } + + private func makeRequest(authorization: String) -> HTTPRequest { + HTTPRequest( + method: "POST", + headers: [HTTPHeaderName.authorization: authorization], + path: "/mcp" + ) + } + + private func makeContext() -> HTTPValidationContext { + HTTPValidationContext(httpMethod: "POST", isInitializationRequest: false) + } + + @Test("BearerTokenValidator allows valid token with nil audience (opaque token)") + func testBearerValidatorNilAudienceAllows() { + let validator = makeBearerValidator { _, _, _ in + .valid(BearerTokenInfo(audience: nil)) + } + let result = validator.validate(makeRequest(authorization: "Bearer tok"), context: makeContext()) + #expect(result == nil) + } + + @Test("BearerTokenValidator allows token with matching audience") + func testBearerValidatorMatchingAudienceAllows() { + let validator = makeBearerValidator { _, _, _ in + .valid(BearerTokenInfo(audience: ["https://api.example.com"])) + } + let result = validator.validate(makeRequest(authorization: "Bearer tok"), context: makeContext()) + #expect(result == nil) + } + + @Test("BearerTokenValidator allows token when one aud entry matches") + func testBearerValidatorOneOfMultipleAudienceMatches() { + let validator = makeBearerValidator { _, _, _ in + .valid(BearerTokenInfo(audience: ["https://other.example.com", "https://api.example.com"])) + } + let result = validator.validate(makeRequest(authorization: "Bearer tok"), context: makeContext()) + #expect(result == nil) + } + + @Test("BearerTokenValidator returns 401 invalid_token when audience does not match") + func testBearerValidatorAudienceMismatchReturns401() { + let validator = makeBearerValidator { _, _, _ in + .valid(BearerTokenInfo(audience: ["https://other.example.com"])) + } + let result = validator.validate(makeRequest(authorization: "Bearer tok"), context: makeContext()) + #expect(result?.statusCode == 401) + let challenge = result?.headers[HTTPHeaderName.wwwAuthenticate] + #expect(challenge?.contains("error=\"invalid_token\"") == true) + #expect(challenge?.contains("Token audience mismatch") == true) + } + + @Test("BearerTokenValidator returns 401 invalid_token when token is expired") + func testBearerValidatorExpiredTokenReturns401() { + let pastDate = Date(timeIntervalSinceNow: -3600) + let validator = makeBearerValidator { _, _, _ in + .valid(BearerTokenInfo(expiresAt: pastDate)) + } + let result = validator.validate(makeRequest(authorization: "Bearer tok"), context: makeContext()) + #expect(result?.statusCode == 401) + let challenge = result?.headers[HTTPHeaderName.wwwAuthenticate] + #expect(challenge?.contains("error=\"invalid_token\"") == true) + #expect(challenge?.contains("Token has expired") == true) + } + + @Test("BearerTokenValidator allows token that has not yet expired") + func testBearerValidatorNonExpiredTokenAllows() { + let futureDate = Date(timeIntervalSinceNow: 3600) + let validator = makeBearerValidator { _, _, _ in + .valid(BearerTokenInfo(expiresAt: futureDate)) + } + let result = validator.validate(makeRequest(authorization: "Bearer tok"), context: makeContext()) + #expect(result == nil) + } + + @Test("BearerTokenValidator checks expiry before audience") + func testBearerValidatorExpiryCheckedBeforeAudience() { + let pastDate = Date(timeIntervalSinceNow: -1) + let validator = makeBearerValidator { _, _, _ in + // Expired token but matching audience — expiry should win + .valid(BearerTokenInfo(audience: ["https://api.example.com"], expiresAt: pastDate)) + } + let result = validator.validate(makeRequest(authorization: "Bearer tok"), context: makeContext()) + #expect(result?.statusCode == 401) + let challenge = result?.headers[HTTPHeaderName.wwwAuthenticate] + #expect(challenge?.contains("Token has expired") == true) + } +} diff --git a/Tests/MCPTests/OAuthAuthorizerTests.swift b/Tests/MCPTests/OAuthAuthorizerTests.swift new file mode 100644 index 00000000..b079652b --- /dev/null +++ b/Tests/MCPTests/OAuthAuthorizerTests.swift @@ -0,0 +1,384 @@ +@preconcurrency import Foundation +import Testing + +@testable import MCP + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + +// MARK: - Mock Implementations + +final class MockURLValidator: OAuthURLValidating, @unchecked Sendable { + var validateHTTPSOrLoopbackCallCount = 0 + var validateAuthorizationServerCallCount = 0 + var validateRedirectURICallCount = 0 + var shouldThrow: Error? + + func validateHTTPSOrLoopback(_ url: URL, context: String) throws { + validateHTTPSOrLoopbackCallCount += 1 + if let error = shouldThrow { throw error } + } + + func validateAuthorizationServer(_ url: URL, context: String) throws { + validateAuthorizationServerCallCount += 1 + if let error = shouldThrow { throw error } + } + + func validateRedirectURI(_ url: URL) throws { + validateRedirectURICallCount += 1 + if let error = shouldThrow { throw error } + } + + func isPrivateIPHost(_ host: String) -> Bool { false } +} + +final class MockDiscoveryClient: OAuthDiscoveryFetching, @unchecked Sendable { + var fetchProtectedResourceMetadataCallCount = 0 + var fetchAuthorizationServerMetadataCallCount = 0 + let metadataDiscovery: any OAuthMetadataDiscovering = DefaultOAuthMetadataDiscovery() + + var protectedResourceMetadataResult: OAuthProtectedResourceMetadata + var authorizationServerMetadataResult: (server: URL, metadata: OAuthAuthorizationServerMetadata) + + init( + authorizationServer: URL = URL(string: "https://auth.example.com")!, + tokenEndpoint: URL = URL(string: "https://auth.example.com/token")! + ) { + self.protectedResourceMetadataResult = OAuthProtectedResourceMetadata( + resource: nil, + authorizationServers: [authorizationServer], + scopesSupported: nil + ) + self.authorizationServerMetadataResult = ( + server: authorizationServer, + metadata: OAuthAuthorizationServerMetadata( + issuer: authorizationServer, + authorizationEndpoint: URL(string: "https://auth.example.com/authorize"), + tokenEndpoint: tokenEndpoint, + registrationEndpoint: nil, + codeChallengeMethodsSupported: ["S256"], + tokenEndpointAuthMethodsSupported: nil, + clientIDMetadataDocumentSupported: nil + ) + ) + } + + func fetchProtectedResourceMetadata(candidates: [URL], session: URLSession) async throws -> OAuthProtectedResourceMetadata { + fetchProtectedResourceMetadataCallCount += 1 + return protectedResourceMetadataResult + } + + func fetchAuthorizationServerMetadata(candidates: [URL], session: URLSession) async throws -> (server: URL, metadata: OAuthAuthorizationServerMetadata) { + fetchAuthorizationServerMetadataCallCount += 1 + return authorizationServerMetadataResult + } +} + +final class MockTokenClient: OAuthTokenRequesting, @unchecked Sendable { + var requestCallCount = 0 + var capturedParameters: [String: String]? + var tokenResponse = OAuthTokenResponse( + accessToken: "mock-access-token", + tokenType: "Bearer", + expiresIn: 3600, + scope: nil, + refreshToken: nil + ) + + func request( + parameters: inout [String: String], + endpoint: URL, + authentication: OAuthConfiguration.TokenEndpointAuthentication, + session: URLSession + ) async throws -> OAuthTokenResponse { + requestCallCount += 1 + capturedParameters = parameters + return tokenResponse + } +} + +final class MockClientRegistrar: OAuthClientRegistering, @unchecked Sendable { + var registerCallCount = 0 + + func register( + configuration: OAuthConfiguration, + asMetadata: OAuthAuthorizationServerMetadata, + session: URLSession + ) async throws -> ( + response: OAuthClientRegistrationResponse, + updatedAuthentication: OAuthConfiguration.TokenEndpointAuthentication + )? { + registerCallCount += 1 + return nil + } +} + +final class MockAuthCodeFlow: OAuthAuthorizationCodeFlowing, @unchecked Sendable { + var buildURLCallCount = 0 + var performCallCount = 0 + var authorizationCode = "mock-auth-code" + + func buildURL( + authorizationEndpoint: URL, + resource: URL, + redirectURI: URL, + clientID: String, + codeChallenge: String, + scopes: Set?, + state: String, + scopeSerializer: any OAuthScopeSelecting + ) throws -> URL { + buildURLCallCount += 1 + return URL(string: "https://auth.example.com/authorize?code=stub")! + } + + func perform( + authorizationURL: URL, + redirectURI: URL, + state: String, + delegate: (any OAuthAuthorizationDelegate)?, + session: URLSession + ) async throws -> String { + performCallCount += 1 + return authorizationCode + } +} + +// MARK: - OAuthAuthorizer Invocation Tests + +@Suite("OAuthAuthorizer dependency invocations") +struct OAuthAuthorizerTests { + + let endpoint = URL(string: "https://mcp.example.com/mcp")! + let headers401 = [ + "WWW-Authenticate": + "Bearer resource_metadata=\"https://mcp.example.com/.well-known/oauth-protected-resource\"" + ] + + func makeAuthorizer( + grantType: OAuthConfiguration.GrantType = .clientCredentials, + urlValidator: MockURLValidator = MockURLValidator(), + discoveryClient: MockDiscoveryClient = MockDiscoveryClient(), + tokenClient: MockTokenClient = MockTokenClient(), + registrar: MockClientRegistrar = MockClientRegistrar(), + authCodeFlow: MockAuthCodeFlow = MockAuthCodeFlow() + ) -> OAuthAuthorizer { + let config = OAuthConfiguration( + grantType: grantType, + authentication: .clientSecretBasic(clientID: "client", clientSecret: "secret") + ) + return OAuthAuthorizer( + configuration: config, + urlValidator: urlValidator, + discoveryClient: discoveryClient, + tokenEndpointClient: tokenClient, + clientRegistrar: registrar, + authCodeFlow: authCodeFlow + ) + } + + // MARK: - validateEndpointSecurity + + @Test("validateEndpointSecurity calls urlValidator") + func testValidateEndpointSecurityCallsURLValidator() throws { + let validator = MockURLValidator() + let authorizer = makeAuthorizer(urlValidator: validator) + + try authorizer.validateEndpointSecurity(for: endpoint) + + #expect(validator.validateHTTPSOrLoopbackCallCount == 1) + } + + @Test("validateEndpointSecurity propagates validation error") + func testValidateEndpointSecurityPropagatesError() { + let validator = MockURLValidator() + validator.shouldThrow = OAuthAuthorizationError.insecureOAuthEndpoint( + context: "test", url: "http://example.com") + let authorizer = makeAuthorizer(urlValidator: validator) + + #expect(throws: OAuthAuthorizationError.self) { + try authorizer.validateEndpointSecurity(for: endpoint) + } + } + + // MARK: - handleChallenge (401 — client_credentials) + + @Test("handleChallenge 401 calls discovery and token clients") + func testHandleChallenge401CallsDiscoveryAndTokenClient() async throws { + let discovery = MockDiscoveryClient() + let tokenClient = MockTokenClient() + + let authorizer = makeAuthorizer( + discoveryClient: discovery, + tokenClient: tokenClient + ) + + let handled = try await authorizer.handleChallenge( + statusCode: 401, + headers: headers401, + endpoint: endpoint, + operationKey: nil, + session: .shared + ) + + #expect(handled == true) + #expect(discovery.fetchProtectedResourceMetadataCallCount >= 1) + #expect(discovery.fetchAuthorizationServerMetadataCallCount >= 1) + #expect(tokenClient.requestCallCount == 1) + } + + @Test("handleChallenge 401 uses client_credentials grant type parameter") + func testHandleChallenge401ClientCredentialsGrantType() async throws { + let tokenClient = MockTokenClient() + let authorizer = makeAuthorizer(tokenClient: tokenClient) + + _ = try await authorizer.handleChallenge( + statusCode: 401, + headers: headers401, + endpoint: endpoint, + operationKey: nil, + session: .shared + ) + + #expect(tokenClient.capturedParameters?["grant_type"] == "client_credentials") + } + + @Test("handleChallenge 401 attaches resource parameter") + func testHandleChallenge401AttachesResourceParameter() async throws { + let tokenClient = MockTokenClient() + let authorizer = makeAuthorizer(tokenClient: tokenClient) + + _ = try await authorizer.handleChallenge( + statusCode: 401, + headers: headers401, + endpoint: endpoint, + operationKey: nil, + session: .shared + ) + + #expect(tokenClient.capturedParameters?["resource"] != nil) + } + + // MARK: - handleChallenge (authorization_code) + + #if canImport(CryptoKit) + @Test("handleChallenge 401 calls authCodeFlow for authorization_code grant") + func testHandleChallenge401AuthorizationCodeCallsFlow() async throws { + let authCodeFlow = MockAuthCodeFlow() + let tokenClient = MockTokenClient() + + let config = OAuthConfiguration( + grantType: .authorizationCode, + authentication: .none(clientID: "my-client"), + authorizationRedirectURI: URL(string: "https://app.example.com/callback")! + ) + let authorizer = OAuthAuthorizer( + configuration: config, + urlValidator: MockURLValidator(), + discoveryClient: MockDiscoveryClient(), + tokenEndpointClient: tokenClient, + clientRegistrar: MockClientRegistrar(), + authCodeFlow: authCodeFlow + ) + + _ = try await authorizer.handleChallenge( + statusCode: 401, + headers: headers401, + endpoint: endpoint, + operationKey: nil, + session: .shared + ) + + #expect(authCodeFlow.buildURLCallCount == 1) + #expect(authCodeFlow.performCallCount == 1) + #expect(tokenClient.capturedParameters?["grant_type"] == "authorization_code") + #expect(tokenClient.capturedParameters?["code"] == "mock-auth-code") + } + #endif + + // MARK: - handleChallenge (403) + + @Test("handleChallenge 403 returns false for non-insufficient_scope error") + func testHandleChallenge403NonInsufficientScope() async throws { + let authorizer = makeAuthorizer() + + let handled = try await authorizer.handleChallenge( + statusCode: 403, + headers: ["WWW-Authenticate": "Bearer error=\"access_denied\""], + endpoint: endpoint, + operationKey: nil, + session: .shared + ) + + #expect(handled == false) + } + + @Test("handleChallenge 403 insufficient_scope acquires token with upgraded scopes") + func testHandleChallenge403InsufficientScope() async throws { + let tokenClient = MockTokenClient() + let discovery = MockDiscoveryClient() + + let authorizer = makeAuthorizer( + discoveryClient: discovery, + tokenClient: tokenClient + ) + + let handled = try await authorizer.handleChallenge( + statusCode: 403, + headers: [ + "WWW-Authenticate": + "Bearer error=\"insufficient_scope\", scope=\"admin\"" + ], + endpoint: endpoint, + operationKey: nil, + session: .shared + ) + + #expect(handled == true) + #expect(tokenClient.requestCallCount == 1) + } + + // MARK: - Client registration + + @Test("handleChallenge calls client registrar when authentication is .none") + func testHandleChallengeCallsRegistrar() async throws { + let registrar = MockClientRegistrar() + let config = OAuthConfiguration( + authentication: .none(clientID: "plain-client")) + let authorizer = OAuthAuthorizer( + configuration: config, + urlValidator: MockURLValidator(), + discoveryClient: MockDiscoveryClient(), + tokenEndpointClient: MockTokenClient(), + clientRegistrar: registrar, + authCodeFlow: MockAuthCodeFlow() + ) + + _ = try await authorizer.handleChallenge( + statusCode: 401, + headers: headers401, + endpoint: endpoint, + operationKey: nil, + session: .shared + ) + + #expect(registrar.registerCallCount == 1) + } + + @Test("handleChallenge skips client registrar when credentials are already configured") + func testHandleChallengeSkipsRegistrarWithCredentials() async throws { + let registrar = MockClientRegistrar() + let authorizer = makeAuthorizer(registrar: registrar) + + _ = try await authorizer.handleChallenge( + statusCode: 401, + headers: headers401, + endpoint: endpoint, + operationKey: nil, + session: .shared + ) + + #expect(registrar.registerCallCount == 0) + } +} diff --git a/Tests/MCPTests/OAuthClientRegistrarTests.swift b/Tests/MCPTests/OAuthClientRegistrarTests.swift new file mode 100644 index 00000000..7bdec5dd --- /dev/null +++ b/Tests/MCPTests/OAuthClientRegistrarTests.swift @@ -0,0 +1,194 @@ +@preconcurrency import Foundation +import Testing + +@testable import MCP + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + +#if swift(>=6.1) && !os(Linux) + + @Suite("OAuthClientRegistrar", .serialized) + struct OAuthClientRegistrarTests { + + let registrar = OAuthClientRegistrar(urlValidator: OAuthURLValidator()) + let registrationEndpoint = URL(string: "https://auth.example.com/register")! + + func makeASMetadata(registrationEndpoint: URL? = nil) -> OAuthAuthorizationServerMetadata { + OAuthAuthorizationServerMetadata( + issuer: URL(string: "https://auth.example.com"), + authorizationEndpoint: URL(string: "https://auth.example.com/authorize"), + tokenEndpoint: URL(string: "https://auth.example.com/token"), + registrationEndpoint: registrationEndpoint, + codeChallengeMethodsSupported: ["S256"], + tokenEndpointAuthMethodsSupported: nil, + clientIDMetadataDocumentSupported: nil + ) + } + + func makeConfig( + authentication: OAuthConfiguration.TokenEndpointAuthentication = .none(clientID: "") + ) -> OAuthConfiguration { + OAuthConfiguration(authentication: authentication) + } + + func successRegistrationBody(clientID: String = "registered-client") throws -> Data { + let dict: [String: Any] = ["client_id": clientID] + return try JSONSerialization.data(withJSONObject: dict) + } + + // MARK: - Skip Conditions + + @Test("Returns nil when authentication is not .none") + func testRegisterReturnsNilForNonNoneAuth() async throws { + let config = makeConfig( + authentication: .clientSecretBasic(clientID: "id", clientSecret: "secret")) + let (session, _) = makeIsolatedSession() + let result = try await registrar.register( + configuration: config, + asMetadata: makeASMetadata(), + session: session + ) + #expect(result == nil) + } + + @Test("Returns nil when no registration endpoint and no CIMD") + func testRegisterReturnsNilWithoutRegistrationEndpoint() async throws { + let config = makeConfig(authentication: .none(clientID: "plain-client-id")) + let (session, _) = makeIsolatedSession() + let result = try await registrar.register( + configuration: config, + asMetadata: makeASMetadata(registrationEndpoint: nil), + session: session + ) + #expect(result == nil) + } + + // MARK: - CIMD Errors + + @Test("Throws when clientID is HTTPS URL with path but server does not support CIMD") + func testRegisterThrowsCIMDNotSupported() async throws { + let config = makeConfig( + authentication: .none(clientID: "https://client.example.com/metadata.json")) + let asMetadata = OAuthAuthorizationServerMetadata( + issuer: nil, authorizationEndpoint: nil, tokenEndpoint: nil, + registrationEndpoint: nil, codeChallengeMethodsSupported: nil, + tokenEndpointAuthMethodsSupported: nil, clientIDMetadataDocumentSupported: false + ) + let (session, _) = makeIsolatedSession() + + await #expect(throws: OAuthAuthorizationError.self) { + try await registrar.register( + configuration: config, + asMetadata: asMetadata, + session: session + ) + } + } + + // MARK: - Successful Registration + + @Test("Returns registration response and updated authentication on success") + func testRegisterSucceeds() async throws { + let body = try successRegistrationBody(clientID: "registered-client") + let (session, key) = makeIsolatedSession() + await IsolatedMockURLProtocol.setHandler(key: key) { _ in + let response = HTTPURLResponse( + url: self.registrationEndpoint, statusCode: 201, + httpVersion: nil, headerFields: nil)! + return (response, body) + } + + let config = makeConfig(authentication: .none(clientID: "")) + let result = try await registrar.register( + configuration: config, + asMetadata: makeASMetadata(registrationEndpoint: registrationEndpoint), + session: session + ) + + let resultValue = try #require(result) + let expected = OAuthConfiguration.TokenEndpointAuthentication.none(clientID: "registered-client") + #expect(resultValue.updatedAuthentication == expected) + } + + @Test("Throws on 4xx registration response") + func testRegisterThrowsOn4xx() async throws { + let errorBody = try JSONSerialization.data( + withJSONObject: ["error": "invalid_client_metadata"]) + let (session, key) = makeIsolatedSession() + await IsolatedMockURLProtocol.setHandler(key: key) { _ in + let response = HTTPURLResponse( + url: self.registrationEndpoint, statusCode: 400, + httpVersion: nil, headerFields: nil)! + return (response, errorBody) + } + + let config = makeConfig(authentication: .none(clientID: "")) + await #expect(throws: OAuthAuthorizationError.self) { + try await registrar.register( + configuration: config, + asMetadata: makeASMetadata(registrationEndpoint: registrationEndpoint), + session: session + ) + } + } + + @Test("Throws tokenRequestFailed on non-2xx non-4xx response") + func testRegisterThrowsOn5xx() async throws { + let (session, key) = makeIsolatedSession() + await IsolatedMockURLProtocol.setHandler(key: key) { _ in + let response = HTTPURLResponse( + url: self.registrationEndpoint, statusCode: 503, + httpVersion: nil, headerFields: nil)! + return (response, Data()) + } + + let config = makeConfig(authentication: .none(clientID: "")) + let error = await #expect(throws: OAuthAuthorizationError.self) { + try await registrar.register( + configuration: config, + asMetadata: makeASMetadata(registrationEndpoint: registrationEndpoint), + session: session + ) + } + guard case .tokenRequestFailed(let statusCode, let oauthError) = error else { + Issue.record("Expected tokenRequestFailed, got \(String(describing: error))") + return + } + #expect(statusCode == 503) + #expect(oauthError == nil) + } + + // MARK: - updatedAuthentication helper + + @Test("updatedAuthentication updates client ID and secret for basic auth") + func testUpdatedAuthenticationBasic() { + let registration = OAuthClientRegistrationResponse( + clientID: "new-id", clientSecret: "new-secret", + tokenEndpointAuthMethod: nil, clientSecretExpiresAt: nil + ) + let result = OAuthClientRegistrar.updatedAuthentication( + from: registration, + current: .clientSecretBasic(clientID: "old-id", clientSecret: "old-secret") + ) + let expected = OAuthConfiguration.TokenEndpointAuthentication.clientSecretBasic(clientID: "new-id", clientSecret: "new-secret") + #expect(result == expected) + } + + @Test("updatedAuthentication falls back to existing secret when not returned") + func testUpdatedAuthenticationFallsBackToCurrentSecret() { + let registration = OAuthClientRegistrationResponse( + clientID: "new-id", clientSecret: nil, + tokenEndpointAuthMethod: nil, clientSecretExpiresAt: nil + ) + let result = OAuthClientRegistrar.updatedAuthentication( + from: registration, + current: .clientSecretBasic(clientID: "old-id", clientSecret: "kept-secret") + ) + let expected = OAuthConfiguration.TokenEndpointAuthentication.clientSecretBasic(clientID: "new-id", clientSecret: "kept-secret") + #expect(result == expected) + } + } + +#endif diff --git a/Tests/MCPTests/OAuthComponentTestHelpers.swift b/Tests/MCPTests/OAuthComponentTestHelpers.swift new file mode 100644 index 00000000..a89cd0f9 --- /dev/null +++ b/Tests/MCPTests/OAuthComponentTestHelpers.swift @@ -0,0 +1,90 @@ +@preconcurrency import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + +/// Key-isolated mock URLProtocol. +/// +/// Each test creates a session with a unique key and registers its handler under that key. +/// Multiple suites run concurrently without interference because each test's requests carry +/// a different key, so they never read the wrong handler. +#if swift(>=6.1) + + actor IsolatedMockStorage { + static let shared = IsolatedMockStorage() + private var handlers: + [String: @Sendable (URLRequest) async throws -> (HTTPURLResponse, Data)] = [:] + + func set( + key: String, + handler: @escaping @Sendable (URLRequest) async throws -> (HTTPURLResponse, Data) + ) { + handlers[key] = handler + } + + func execute(key: String, request: URLRequest) async throws -> (HTTPURLResponse, Data) { + guard let handler = handlers[key] else { + throw NSError( + domain: "IsolatedMockError", code: 0, + userInfo: [NSLocalizedDescriptionKey: "No handler for key '\(key)'"]) + } + return try await handler(request) + } + } + + final class IsolatedMockURLProtocol: URLProtocol, @unchecked Sendable { + static let sessionKeyHeader = "X-Mock-Key" + + static func makeSession(key: String) -> URLSession { + let config = URLSessionConfiguration.ephemeral + config.protocolClasses = [IsolatedMockURLProtocol.self] + config.httpAdditionalHeaders = [sessionKeyHeader: key] + return URLSession(configuration: config) + } + + static func setHandler( + key: String, + _ handler: @escaping @Sendable (URLRequest) async throws -> (HTTPURLResponse, Data) + ) async { + await IsolatedMockStorage.shared.set(key: key, handler: handler) + } + + func executeHandler(for request: URLRequest) async throws -> (HTTPURLResponse, Data) { + let key = request.value(forHTTPHeaderField: IsolatedMockURLProtocol.sessionKeyHeader) + ?? "unknown" + return try await IsolatedMockStorage.shared.execute(key: key, request: request) + } + + override class func canInit(with request: URLRequest) -> Bool { true } + override class func canonicalRequest(for request: URLRequest) -> URLRequest { request } + override func startLoading() { + Task { + do { + let (response, data) = try await self.executeHandler(for: request) + client?.urlProtocol(self, didReceive: response, cacheStoragePolicy: .notAllowed) + client?.urlProtocol(self, didLoad: data) + client?.urlProtocolDidFinishLoading(self) + } catch { + client?.urlProtocol(self, didFailWithError: error) + } + } + } + override func stopLoading() {} + } + + // MARK: - Per-test session factory + + /// Creates an isolated URLSession + unique key for one test. + /// + /// Usage: + /// ```swift + /// let (session, key) = makeIsolatedSession() + /// await IsolatedMockURLProtocol.setHandler(key: key) { _ in (response, data) } + /// ``` + func makeIsolatedSession() -> (session: URLSession, key: String) { + let key = UUID().uuidString + return (IsolatedMockURLProtocol.makeSession(key: key), key) + } + +#endif diff --git a/Tests/MCPTests/OAuthDiscoveryClientTests.swift b/Tests/MCPTests/OAuthDiscoveryClientTests.swift new file mode 100644 index 00000000..31b74b83 --- /dev/null +++ b/Tests/MCPTests/OAuthDiscoveryClientTests.swift @@ -0,0 +1,213 @@ +@preconcurrency import Foundation +import Testing + +@testable import MCP + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + +#if swift(>=6.1) && !os(Linux) + + @Suite("OAuthDiscoveryClient", .serialized) + struct OAuthDiscoveryClientTests { + + let urlValidator = OAuthURLValidator(allowLoopbackHTTPForAuthorizationServer: true) + let metadataDiscovery = DefaultOAuthMetadataDiscovery() + + func makeClient() -> OAuthDiscoveryClient { + OAuthDiscoveryClient(metadataDiscovery: metadataDiscovery, urlValidator: urlValidator) + } + + func makeProtectedResourceBody(authorizationServers: [String]) throws -> Data { + let dict: [String: Any] = ["authorization_servers": authorizationServers] + return try JSONSerialization.data(withJSONObject: dict) + } + + func makeASMetadataBody(issuer: String) throws -> Data { + let dict: [String: Any] = [ + "issuer": issuer, + "token_endpoint": "https://auth.example.com/token", + "code_challenge_methods_supported": ["S256"], + ] + return try JSONSerialization.data(withJSONObject: dict) + } + + // MARK: - fetchProtectedResourceMetadata + + @Test("Returns metadata from first successful candidate") + func testFetchProtectedResourceMetadataSuccess() async throws { + let body = try makeProtectedResourceBody( + authorizationServers: ["https://auth.example.com"]) + let (session, key) = makeIsolatedSession() + await IsolatedMockURLProtocol.setHandler(key: key) { _ in + let response = HTTPURLResponse( + url: URL(string: "https://example.com/.well-known/oauth-protected-resource")!, + statusCode: 200, httpVersion: nil, headerFields: nil)! + return (response, body) + } + + let metadata = try await makeClient().fetchProtectedResourceMetadata( + candidates: [URL(string: "https://example.com/.well-known/oauth-protected-resource")!], + session: session + ) + let expected = OAuthProtectedResourceMetadata( + resource: nil, + authorizationServers: [URL(string: "https://auth.example.com")!], + scopesSupported: nil) + #expect(metadata == expected) + } + + @Test("Skips candidates that return non-2xx status") + func testFetchProtectedResourceMetadataSkipsNon2xx() async throws { + let body = try makeProtectedResourceBody( + authorizationServers: ["https://auth.example.com"]) + let (session, key) = makeIsolatedSession() + await IsolatedMockURLProtocol.setHandler(key: key) { request in + let statusCode = request.url?.lastPathComponent == "mcp" ? 404 : 200 + let response = HTTPURLResponse( + url: request.url!, statusCode: statusCode, + httpVersion: nil, headerFields: nil)! + return (response, statusCode == 200 ? body : Data()) + } + + let metadata = try await makeClient().fetchProtectedResourceMetadata( + candidates: [ + URL(string: "https://example.com/.well-known/oauth-protected-resource/mcp")!, + URL(string: "https://example.com/.well-known/oauth-protected-resource")!, + ], + session: session + ) + let expected = OAuthProtectedResourceMetadata( + resource: nil, + authorizationServers: [URL(string: "https://auth.example.com")!], + scopesSupported: nil) + #expect(metadata == expected) + } + + @Test("Skips candidates with empty authorizationServers array") + func testFetchProtectedResourceMetadataSkipsEmptyAuthServers() async throws { + let emptyBody = try makeProtectedResourceBody(authorizationServers: []) + let validBody = try makeProtectedResourceBody( + authorizationServers: ["https://auth.example.com"]) + let (session, key) = makeIsolatedSession() + await IsolatedMockURLProtocol.setHandler(key: key) { request in + let body = request.url?.lastPathComponent == "mcp" ? emptyBody : validBody + let response = HTTPURLResponse( + url: request.url!, statusCode: 200, httpVersion: nil, headerFields: nil)! + return (response, body) + } + + let metadata = try await makeClient().fetchProtectedResourceMetadata( + candidates: [ + URL(string: "https://example.com/.well-known/oauth-protected-resource/mcp")!, + URL(string: "https://example.com/.well-known/oauth-protected-resource")!, + ], + session: session + ) + let expected = OAuthProtectedResourceMetadata( + resource: nil, + authorizationServers: [URL(string: "https://auth.example.com")!], + scopesSupported: nil) + #expect(metadata == expected) + } + + @Test("Throws metadataDiscoveryFailed when all candidates fail") + func testFetchProtectedResourceMetadataThrowsWhenAllFail() async throws { + let (session, key) = makeIsolatedSession() + await IsolatedMockURLProtocol.setHandler(key: key) { request in + let response = HTTPURLResponse( + url: request.url!, statusCode: 404, httpVersion: nil, headerFields: nil)! + return (response, Data()) + } + + await #expect(throws: OAuthAuthorizationError.self) { + try await makeClient().fetchProtectedResourceMetadata( + candidates: [ + URL(string: "https://example.com/.well-known/oauth-protected-resource")! + ], + session: session + ) + } + } + + // MARK: - fetchAuthorizationServerMetadata + + @Test("Returns server and metadata when issuer matches") + func testFetchAuthorizationServerMetadataSuccess() async throws { + let issuer = "https://auth.example.com" + let body = try makeASMetadataBody(issuer: issuer) + let (session, key) = makeIsolatedSession() + await IsolatedMockURLProtocol.setHandler(key: key) { _ in + let response = HTTPURLResponse( + url: URL(string: "\(issuer)/.well-known/oauth-authorization-server")!, + statusCode: 200, httpVersion: nil, headerFields: nil)! + return (response, body) + } + + let (server, metadata) = try await makeClient().fetchAuthorizationServerMetadata( + candidates: [URL(string: issuer)!], + session: session + ) + let expectedServer = URL(string: issuer)! + let expectedMetadata = OAuthAuthorizationServerMetadata( + issuer: URL(string: issuer), + authorizationEndpoint: nil, + tokenEndpoint: URL(string: "https://auth.example.com/token"), + registrationEndpoint: nil, + codeChallengeMethodsSupported: ["S256"], + tokenEndpointAuthMethodsSupported: nil, + clientIDMetadataDocumentSupported: nil) + #expect(server == expectedServer) + #expect(metadata == expectedMetadata) + } + + @Test("Skips candidate when issuer field does not match") + func testFetchAuthorizationServerMetadataSkipsIssuerMismatch() async throws { + let wrongIssuerBody = try makeASMetadataBody(issuer: "https://other.example.com") + let (session, key) = makeIsolatedSession() + await IsolatedMockURLProtocol.setHandler(key: key) { _ in + let response = HTTPURLResponse( + url: URL(string: "https://auth.example.com")!, + statusCode: 200, httpVersion: nil, headerFields: nil)! + return (response, wrongIssuerBody) + } + + await #expect(throws: OAuthAuthorizationError.self) { + try await makeClient().fetchAuthorizationServerMetadata( + candidates: [URL(string: "https://auth.example.com")!], + session: session + ) + } + } + + @Test("Skips private IP candidates without making HTTP calls") + func testFetchAuthorizationServerMetadataSkipsPrivateIP() async throws { + let (session, _) = makeIsolatedSession() + await #expect(throws: OAuthAuthorizationError.self) { + try await makeClient().fetchAuthorizationServerMetadata( + candidates: [URL(string: "https://10.0.0.1")!], + session: session + ) + } + } + + @Test("Throws when all candidates return non-2xx") + func testFetchAuthorizationServerMetadataThrowsWhenAllFail() async throws { + let (session, key) = makeIsolatedSession() + await IsolatedMockURLProtocol.setHandler(key: key) { request in + let response = HTTPURLResponse( + url: request.url!, statusCode: 500, httpVersion: nil, headerFields: nil)! + return (response, Data()) + } + + await #expect(throws: OAuthAuthorizationError.self) { + try await makeClient().fetchAuthorizationServerMetadata( + candidates: [URL(string: "https://auth.example.com")!], + session: session + ) + } + } + } + +#endif diff --git a/Tests/MCPTests/OAuthTestScenarios.swift b/Tests/MCPTests/OAuthTestScenarios.swift new file mode 100644 index 00000000..62d74571 --- /dev/null +++ b/Tests/MCPTests/OAuthTestScenarios.swift @@ -0,0 +1,2361 @@ +@preconcurrency import Foundation +import Testing + +@testable import MCP + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + +#if swift(>=6.1) + + // MARK: - Scenario Context + + struct OAuthScenarioContext: Sendable { + let testEndpoint: URL + let oauthConfiguration: OAuthConfiguration + let messageData: Data + let expectedResponseData: Data? + let expectedCallCounts: [URL: Int] + let streaming: Bool + let sseInitializationTimeout: TimeInterval? + let expectedErrorSubstring: String? + let unexpectedErrorSubstrings: [String] + let secondMessageData: Data? + let secondExpectedResponseData: Data? + + init( + testEndpoint: URL, + oauthConfiguration: OAuthConfiguration, + messageData: Data, + expectedResponseData: Data? = nil, + expectedCallCounts: [URL: Int] = [:], + streaming: Bool = false, + sseInitializationTimeout: TimeInterval? = nil, + expectedErrorSubstring: String? = nil, + unexpectedErrorSubstrings: [String] = [], + secondMessageData: Data? = nil, + secondExpectedResponseData: Data? = nil + ) { + self.testEndpoint = testEndpoint + self.oauthConfiguration = oauthConfiguration + self.messageData = messageData + self.expectedResponseData = expectedResponseData + self.expectedCallCounts = expectedCallCounts + self.streaming = streaming + self.sseInitializationTimeout = sseInitializationTimeout + self.expectedErrorSubstring = expectedErrorSubstring + self.unexpectedErrorSubstrings = unexpectedErrorSubstrings + self.secondMessageData = secondMessageData + self.secondExpectedResponseData = secondExpectedResponseData + } + } + + // MARK: - Request Body Helper + + func readRequestBody(_ request: URLRequest) -> Data? { + if let data = request.httpBody { return data } + guard let stream = request.httpBodyStream else { return nil } + stream.open() + defer { stream.close() } + let bufferSize = 4096 + let buffer = UnsafeMutablePointer.allocate(capacity: bufferSize) + defer { buffer.deallocate() } + var data = Data() + while stream.hasBytesAvailable { + let bytesRead = stream.read(buffer, maxLength: bufferSize) + data.append(buffer, count: bytesRead) + } + return data + } + + // MARK: - Trackers + + actor ProviderTracker: Sendable { + var capturedContext: OAuthConfiguration.AccessTokenProviderContext? + func capture(_ context: OAuthConfiguration.AccessTokenProviderContext) { + capturedContext = context + } + } + + actor OrderTracker: Sendable { + var requests: [URL] = [] + func append(_ url: URL) { requests.append(url) } + func count() -> Int { requests.count } + } + + /// Dispenses tokens in order; repeats the last token when the list is exhausted. + actor TokenDispenser: Sendable { + private var tokens: [String] + init(_ tokens: String...) { self.tokens = Array(tokens) } + func next() -> String { + if tokens.count > 1 { return tokens.removeFirst() } + return tokens.first ?? "fallback-token" + } + } + + // MARK: - Scenario Definitions + + #if !canImport(FoundationNetworking) + + extension RequestHandlerStorage { + + // MARK: 1 - Client Credentials Retry After 401 + + func configureOAuthClientCredentialsRetryAfter401() -> OAuthScenarioContext { + let testEndpoint = URL(string: "https://localhost:8080/test")! + let resourceMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-protected-resource/test")! + let asMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-authorization-server/auth")! + let tokenEndpointURL = URL(string: "https://localhost:8080/oauth/token")! + let finalResponseData = MockResponses.jsonRPCResult(id: 1) + + requestHandler = MockResponses.routingHandler(routes: [ + testEndpoint: { request in + if request.value(forHTTPHeaderField: "Authorization") == nil { + return try await MockResponses.bearerChallenge( + resourceMetadataURL: resourceMetadataURL, scope: "files:read" + )(request) + } + #expect( + request.value(forHTTPHeaderField: "Authorization") + == "Bearer access-token-123") + return try await MockResponses.jsonSuccess(body: finalResponseData)(request) + }, + resourceMetadataURL: MockResponses.resourceMetadata( + authorizationServers: ["https://localhost:8080/auth"], + scopesSupported: ["files:read", "files:write"] + ), + asMetadataURL: MockResponses.asMetadata( + issuer: "https://localhost:8080/auth", + tokenEndpoint: "https://localhost:8080/oauth/token" + ), + tokenEndpointURL: { request in + #expect(request.httpMethod == "POST") + let body = + String(data: readRequestBody(request) ?? Data(), encoding: .utf8) ?? "" + #expect(body.contains("grant_type=client_credentials")) + #expect( + body.contains("resource=https%3A%2F%2Flocalhost%3A8080%2Ftest")) + #expect(body.contains("scope=files%3Aread")) + #expect(body.contains("client_id=test-client")) + return try await MockResponses.tokenSuccess( + accessToken: "access-token-123")(request) + }, + ]) + + return OAuthScenarioContext( + testEndpoint: testEndpoint, + oauthConfiguration: .init(authentication: .none(clientID: "test-client")), + messageData: #"{"jsonrpc":"2.0","method":"ping","id":1}"#.data(using: .utf8)!, + expectedResponseData: finalResponseData, + expectedCallCounts: [ + testEndpoint: 2, resourceMetadataURL: 1, asMetadataURL: 1, + tokenEndpointURL: 1, + ] + ) + } + + // MARK: 2 - Scope Fallback to scopes_supported + + func configureOAuthScopeSelectionFallsBackToScopesSupported() -> OAuthScenarioContext { + let testEndpoint = URL(string: "https://localhost:8080/fallback-scope")! + let resourceMetadataURL = URL( + string: + "https://localhost:8080/.well-known/oauth-protected-resource/fallback-scope" + )! + let asMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-authorization-server/auth")! + let tokenEndpointURL = URL(string: "https://localhost:8080/oauth/token")! + let finalResponseData = MockResponses.jsonRPCResult(id: 21) + + requestHandler = MockResponses.routingHandler(routes: [ + testEndpoint: { request in + if request.value(forHTTPHeaderField: "Authorization") == nil { + return try await MockResponses.bearerChallenge( + resourceMetadataURL: resourceMetadataURL + )(request) + } + #expect( + request.value(forHTTPHeaderField: "Authorization") + == "Bearer access-token-fallback") + return try await MockResponses.jsonSuccess(body: finalResponseData)(request) + }, + resourceMetadataURL: MockResponses.resourceMetadata( + authorizationServers: ["https://localhost:8080/auth"], + scopesSupported: ["files:write", "files:read"] + ), + asMetadataURL: MockResponses.asMetadata( + issuer: "https://localhost:8080/auth", + tokenEndpoint: "https://localhost:8080/oauth/token" + ), + tokenEndpointURL: { request in + let body = + String(data: readRequestBody(request) ?? Data(), encoding: .utf8) ?? "" + #expect(body.contains("grant_type=client_credentials")) + #expect( + body.contains( + "resource=https%3A%2F%2Flocalhost%3A8080%2Ffallback-scope")) + #expect(body.contains("scope=files%3Aread%20files%3Awrite")) + #expect(body.contains("client_id=test-client")) + return try await MockResponses.tokenSuccess( + accessToken: "access-token-fallback")(request) + }, + ]) + + return OAuthScenarioContext( + testEndpoint: testEndpoint, + oauthConfiguration: .init(authentication: .none(clientID: "test-client")), + messageData: #"{"jsonrpc":"2.0","method":"ping","id":21}"#.data(using: .utf8)!, + expectedResponseData: finalResponseData, + expectedCallCounts: [ + testEndpoint: 2, resourceMetadataURL: 1, asMetadataURL: 1, + tokenEndpointURL: 1, + ] + ) + } + + // MARK: 3 - Scope Omitted When No Hints + + func configureOAuthScopeOmittedWhenNoHints() -> OAuthScenarioContext { + let testEndpoint = URL(string: "https://localhost:8080/no-scope-hints")! + let resourceMetadataURL = URL( + string: + "https://localhost:8080/.well-known/oauth-protected-resource/no-scope-hints" + )! + let asMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-authorization-server/auth")! + let tokenEndpointURL = URL(string: "https://localhost:8080/oauth/token")! + let finalResponseData = MockResponses.jsonRPCResult(id: 22) + + requestHandler = MockResponses.routingHandler(routes: [ + testEndpoint: { request in + if request.value(forHTTPHeaderField: "Authorization") == nil { + return try await MockResponses.bearerChallenge( + resourceMetadataURL: resourceMetadataURL + )(request) + } + #expect( + request.value(forHTTPHeaderField: "Authorization") + == "Bearer access-token-no-scope") + return try await MockResponses.jsonSuccess(body: finalResponseData)(request) + }, + resourceMetadataURL: MockResponses.resourceMetadata( + authorizationServers: ["https://localhost:8080/auth"] + ), + asMetadataURL: MockResponses.asMetadata( + issuer: "https://localhost:8080/auth", + tokenEndpoint: "https://localhost:8080/oauth/token" + ), + tokenEndpointURL: { request in + let body = + String(data: readRequestBody(request) ?? Data(), encoding: .utf8) ?? "" + #expect(body.contains("grant_type=client_credentials")) + #expect( + body.contains( + "resource=https%3A%2F%2Flocalhost%3A8080%2Fno-scope-hints")) + #expect(!body.contains("scope=")) + #expect(body.contains("client_id=test-client")) + return try await MockResponses.tokenSuccess( + accessToken: "access-token-no-scope")(request) + }, + ]) + + return OAuthScenarioContext( + testEndpoint: testEndpoint, + oauthConfiguration: .init(authentication: .none(clientID: "test-client")), + messageData: #"{"jsonrpc":"2.0","method":"ping","id":22}"#.data(using: .utf8)!, + expectedResponseData: finalResponseData, + expectedCallCounts: [ + testEndpoint: 2, resourceMetadataURL: 1, asMetadataURL: 1, + tokenEndpointURL: 1, + ] + ) + } + + // MARK: 4 - Resource Parameter in Authorization Code Flow + + func configureOAuthResourceParameterInAuthorizationAndToken() -> OAuthScenarioContext { + let testEndpoint = URL(string: "https://localhost:8080/public/mcp?foo=bar")! + let canonicalResource = "https://localhost:8080/public/mcp" + let resourceMetadataURL = URL( + string: + "https://localhost:8080/.well-known/oauth-protected-resource/public/mcp")! + let asMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-authorization-server/auth")! + let authorizationEndpointURL = URL( + string: "https://localhost:8080/oauth/authorize")! + let tokenEndpointURL = URL(string: "https://localhost:8080/oauth/token")! + let finalResponseData = MockResponses.jsonRPCResult(id: 23) + + requestHandler = { + [testEndpoint, canonicalResource, resourceMetadataURL, asMetadataURL, authorizationEndpointURL, tokenEndpointURL, finalResponseData] + request in + guard let url = request.url else { + throw MockResponses.mockError("Missing request URL") + } + + if url.scheme == authorizationEndpointURL.scheme, + url.host == authorizationEndpointURL.host, + url.port == authorizationEndpointURL.port, + url.path == authorizationEndpointURL.path + { + let queryItems = + URLComponents(url: url, resolvingAgainstBaseURL: false)?.queryItems + ?? [] + #expect( + queryItems.contains(where: { + $0.name == "resource" && $0.value == canonicalResource + })) + #expect( + queryItems.contains(where: { + $0.name == "response_type" && $0.value == "code" + })) + #expect( + queryItems.contains(where: { + $0.name == "client_id" && $0.value == "test-client" + })) + #expect( + queryItems.contains(where: { + $0.name == "scope" && $0.value == "files:read" + })) + let state = queryItems.first(where: { $0.name == "state" })?.value + let redirectURI = + queryItems.first(where: { $0.name == "redirect_uri" })?.value + #expect(state != nil) + #expect(redirectURI != nil) + + var redirectComponents = URLComponents(string: redirectURI ?? "") + var redirectQueryItems = redirectComponents?.queryItems ?? [] + redirectQueryItems.append(.init(name: "code", value: "test")) + redirectQueryItems.append(.init(name: "state", value: state)) + redirectComponents?.queryItems = redirectQueryItems + let locationValue = + redirectComponents?.url?.absoluteString + ?? "http://127.0.0.1:3000/callback?code=test&state=\(state ?? "")" + let response = HTTPURLResponse( + url: url, statusCode: 302, httpVersion: "HTTP/1.1", + headerFields: ["Location": locationValue])! + return (response, Data()) + } + + switch url { + case testEndpoint: + if request.value(forHTTPHeaderField: "Authorization") == nil { + return try await MockResponses.bearerChallenge( + resourceMetadataURL: resourceMetadataURL, scope: "files:read" + )(request) + } + #expect( + request.value(forHTTPHeaderField: "Authorization") + == "Bearer access-token-resource") + #expect(request.url?.query == "foo=bar") + return try await MockResponses.jsonSuccess(body: finalResponseData)( + request) + + case resourceMetadataURL: + return try await MockResponses.resourceMetadata( + authorizationServers: ["https://localhost:8080/auth"], + scopesSupported: ["files:read"] + )(request) + + case asMetadataURL: + return try await MockResponses.asMetadata( + issuer: "https://localhost:8080/auth", + tokenEndpoint: "https://localhost:8080/oauth/token", + authorizationEndpoint: "https://localhost:8080/oauth/authorize", + codeChallengeMethodsSupported: ["S256"] + )(request) + + case tokenEndpointURL: + let body = + String(data: readRequestBody(request) ?? Data(), encoding: .utf8) ?? "" + #expect(body.contains("grant_type=authorization_code")) + #expect( + body.contains( + "resource=https%3A%2F%2Flocalhost%3A8080%2Fpublic%2Fmcp")) + #expect(!body.contains("%3Ffoo%3Dbar")) + #expect(body.contains("scope=files%3Aread")) + #expect(body.contains("client_id=test-client")) + #expect(body.contains("code_verifier=")) + #expect(body.contains("code=test")) + return try await MockResponses.tokenSuccess( + accessToken: "access-token-resource")(request) + + default: + throw MockResponses.mockError("Unexpected URL: \(url.absoluteString)") + } + } + + return OAuthScenarioContext( + testEndpoint: testEndpoint, + oauthConfiguration: .init( + grantType: .authorizationCode, + authentication: .none(clientID: "test-client") + ), + messageData: #"{"jsonrpc":"2.0","method":"ping","id":23}"#.data(using: .utf8)!, + expectedResponseData: finalResponseData, + expectedCallCounts: [ + testEndpoint: 2, resourceMetadataURL: 1, asMetadataURL: 1, + tokenEndpointURL: 1, + ] + ) + } + + // MARK: 5 - Rejects Authorization Without PKCE Metadata + + func configureOAuthRejectsAuthorizationWithoutPKCEMetadata() -> OAuthScenarioContext { + let testEndpoint = URL(string: "https://localhost:8080/pkce-metadata-missing")! + let resourceMetadataURL = URL( + string: + "https://localhost:8080/.well-known/oauth-protected-resource/pkce-metadata-missing" + )! + let asMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-authorization-server/auth")! + let authorizationEndpointURL = URL( + string: "https://localhost:8080/oauth/authorize")! + let tokenEndpointURL = URL(string: "https://localhost:8080/oauth/token")! + + requestHandler = { + [testEndpoint, resourceMetadataURL, asMetadataURL, authorizationEndpointURL, tokenEndpointURL] + request in + guard let url = request.url else { + throw MockResponses.mockError("Missing request URL") + } + + if url.scheme == authorizationEndpointURL.scheme, + url.host == authorizationEndpointURL.host, + url.port == authorizationEndpointURL.port, + url.path == authorizationEndpointURL.path + { + let response = HTTPURLResponse( + url: url, statusCode: 302, httpVersion: "HTTP/1.1", + headerFields: [ + "Location": "http://127.0.0.1:3000/callback?code=test" + ])! + return (response, Data()) + } + + switch url { + case testEndpoint: + return try await MockResponses.bearerChallenge( + resourceMetadataURL: resourceMetadataURL, scope: "files:read" + )(request) + + case resourceMetadataURL: + return try await MockResponses.resourceMetadata( + authorizationServers: ["https://localhost:8080/auth"], + scopesSupported: ["files:read"] + )(request) + + case asMetadataURL: + return try await MockResponses.asMetadata( + issuer: "https://localhost:8080/auth", + tokenEndpoint: "https://localhost:8080/oauth/token", + authorizationEndpoint: "https://localhost:8080/oauth/authorize" + )(request) + + case tokenEndpointURL: + return try await MockResponses.tokenSuccess( + accessToken: "should-not-be-issued")(request) + + default: + throw MockResponses.mockError("Unexpected URL: \(url.absoluteString)") + } + } + + return OAuthScenarioContext( + testEndpoint: testEndpoint, + oauthConfiguration: .init( + grantType: .authorizationCode, + authentication: .none(clientID: "test-client") + ), + messageData: #"{"jsonrpc":"2.0","method":"ping","id":24}"#.data(using: .utf8)!, + expectedCallCounts: [tokenEndpointURL: 0], + expectedErrorSubstring: "code_challenge_methods_supported" + ) + } + + // MARK: 6 - Rejects Authorization Without S256 PKCE + + func configureOAuthRejectsAuthorizationWithoutS256PKCE() -> OAuthScenarioContext { + let testEndpoint = URL(string: "https://localhost:8080/pkce-s256-missing")! + let resourceMetadataURL = URL( + string: + "https://localhost:8080/.well-known/oauth-protected-resource/pkce-s256-missing" + )! + let asMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-authorization-server/auth")! + let authorizationEndpointURL = URL( + string: "https://localhost:8080/oauth/authorize")! + let tokenEndpointURL = URL(string: "https://localhost:8080/oauth/token")! + + requestHandler = { + [testEndpoint, resourceMetadataURL, asMetadataURL, authorizationEndpointURL, tokenEndpointURL] + request in + guard let url = request.url else { + throw MockResponses.mockError("Missing request URL") + } + + if url.scheme == authorizationEndpointURL.scheme, + url.host == authorizationEndpointURL.host, + url.port == authorizationEndpointURL.port, + url.path == authorizationEndpointURL.path + { + let response = HTTPURLResponse( + url: url, statusCode: 302, httpVersion: "HTTP/1.1", + headerFields: [ + "Location": "http://127.0.0.1:3000/callback?code=test" + ])! + return (response, Data()) + } + + switch url { + case testEndpoint: + return try await MockResponses.bearerChallenge( + resourceMetadataURL: resourceMetadataURL, scope: "files:read" + )(request) + + case resourceMetadataURL: + return try await MockResponses.resourceMetadata( + authorizationServers: ["https://localhost:8080/auth"], + scopesSupported: ["files:read"] + )(request) + + case asMetadataURL: + return try await MockResponses.asMetadata( + issuer: "https://localhost:8080/auth", + tokenEndpoint: "https://localhost:8080/oauth/token", + authorizationEndpoint: "https://localhost:8080/oauth/authorize", + codeChallengeMethodsSupported: ["plain"] + )(request) + + case tokenEndpointURL: + return try await MockResponses.tokenSuccess( + accessToken: "should-not-be-issued")(request) + + default: + throw MockResponses.mockError("Unexpected URL: \(url.absoluteString)") + } + } + + return OAuthScenarioContext( + testEndpoint: testEndpoint, + oauthConfiguration: .init( + grantType: .authorizationCode, + authentication: .none(clientID: "test-client") + ), + messageData: #"{"jsonrpc":"2.0","method":"ping","id":25}"#.data(using: .utf8)!, + expectedCallCounts: [tokenEndpointURL: 0], + expectedErrorSubstring: "must support PKCE S256" + ) + } + + // MARK: 7 - Rejects Authorization Response Redirect Mismatch + + func configureOAuthRejectsAuthorizationResponseRedirectMismatch() + -> OAuthScenarioContext + { + let testEndpoint = URL( + string: "https://localhost:8080/authorization-redirect-mismatch")! + let resourceMetadataURL = URL( + string: + "https://localhost:8080/.well-known/oauth-protected-resource/authorization-redirect-mismatch" + )! + let asMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-authorization-server/auth")! + let authorizationEndpointURL = URL( + string: "https://localhost:8080/oauth/authorize")! + let tokenEndpointURL = URL(string: "https://localhost:8080/oauth/token")! + + requestHandler = { + [testEndpoint, resourceMetadataURL, asMetadataURL, authorizationEndpointURL, tokenEndpointURL] + request in + guard let url = request.url else { + throw MockResponses.mockError("Missing request URL") + } + + if url.scheme == authorizationEndpointURL.scheme, + url.host == authorizationEndpointURL.host, + url.port == authorizationEndpointURL.port, + url.path == authorizationEndpointURL.path + { + let queryItems = + URLComponents(url: url, resolvingAgainstBaseURL: false)?.queryItems + ?? [] + let state = + queryItems.first(where: { $0.name == "state" })?.value ?? "" + let locationValue = + "https://evil.example.com/callback?code=test&state=\(state)" + let response = HTTPURLResponse( + url: url, statusCode: 302, httpVersion: "HTTP/1.1", + headerFields: ["Location": locationValue])! + return (response, Data()) + } + + switch url { + case testEndpoint: + return try await MockResponses.bearerChallenge( + resourceMetadataURL: resourceMetadataURL, scope: "files:read" + )(request) + + case resourceMetadataURL: + return try await MockResponses.resourceMetadata( + authorizationServers: ["https://localhost:8080/auth"], + scopesSupported: ["files:read"] + )(request) + + case asMetadataURL: + return try await MockResponses.asMetadata( + issuer: "https://localhost:8080/auth", + tokenEndpoint: "https://localhost:8080/oauth/token", + authorizationEndpoint: "https://localhost:8080/oauth/authorize", + codeChallengeMethodsSupported: ["S256"] + )(request) + + case tokenEndpointURL: + return try await MockResponses.tokenSuccess( + accessToken: "should-not-be-issued")(request) + + default: + throw MockResponses.mockError("Unexpected URL: \(url.absoluteString)") + } + } + + return OAuthScenarioContext( + testEndpoint: testEndpoint, + oauthConfiguration: .init( + grantType: .authorizationCode, + authentication: .none(clientID: "test-client") + ), + messageData: #"{"jsonrpc":"2.0","method":"ping","id":26}"#.data(using: .utf8)!, + expectedCallCounts: [tokenEndpointURL: 0], + expectedErrorSubstring: "redirect URI mismatch" + ) + } + + // MARK: 8 - Rejects Authorization Response State Mismatch + + func configureOAuthRejectsAuthorizationResponseStateMismatch() + -> OAuthScenarioContext + { + let testEndpoint = URL( + string: "https://localhost:8080/authorization-state-mismatch")! + let resourceMetadataURL = URL( + string: + "https://localhost:8080/.well-known/oauth-protected-resource/authorization-state-mismatch" + )! + let asMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-authorization-server/auth")! + let authorizationEndpointURL = URL( + string: "https://localhost:8080/oauth/authorize")! + let tokenEndpointURL = URL(string: "https://localhost:8080/oauth/token")! + + requestHandler = { + [testEndpoint, resourceMetadataURL, asMetadataURL, authorizationEndpointURL, tokenEndpointURL] + request in + guard let url = request.url else { + throw MockResponses.mockError("Missing request URL") + } + + if url.scheme == authorizationEndpointURL.scheme, + url.host == authorizationEndpointURL.host, + url.port == authorizationEndpointURL.port, + url.path == authorizationEndpointURL.path + { + let queryItems = + URLComponents(url: url, resolvingAgainstBaseURL: false)?.queryItems + ?? [] + let redirectURI = + queryItems.first(where: { $0.name == "redirect_uri" })?.value + var redirectComponents = URLComponents(string: redirectURI ?? "") + var redirectQueryItems = redirectComponents?.queryItems ?? [] + redirectQueryItems.append(.init(name: "code", value: "test")) + redirectQueryItems.append( + .init(name: "state", value: "unexpected-state")) + redirectComponents?.queryItems = redirectQueryItems + let locationValue = + redirectComponents?.url?.absoluteString + ?? "http://127.0.0.1:3000/callback?code=test&state=unexpected-state" + let response = HTTPURLResponse( + url: url, statusCode: 302, httpVersion: "HTTP/1.1", + headerFields: ["Location": locationValue])! + return (response, Data()) + } + + switch url { + case testEndpoint: + return try await MockResponses.bearerChallenge( + resourceMetadataURL: resourceMetadataURL, scope: "files:read" + )(request) + + case resourceMetadataURL: + return try await MockResponses.resourceMetadata( + authorizationServers: ["https://localhost:8080/auth"], + scopesSupported: ["files:read"] + )(request) + + case asMetadataURL: + return try await MockResponses.asMetadata( + issuer: "https://localhost:8080/auth", + tokenEndpoint: "https://localhost:8080/oauth/token", + authorizationEndpoint: "https://localhost:8080/oauth/authorize", + codeChallengeMethodsSupported: ["S256"] + )(request) + + case tokenEndpointURL: + return try await MockResponses.tokenSuccess( + accessToken: "should-not-be-issued")(request) + + default: + throw MockResponses.mockError("Unexpected URL: \(url.absoluteString)") + } + } + + return OAuthScenarioContext( + testEndpoint: testEndpoint, + oauthConfiguration: .init( + grantType: .authorizationCode, + authentication: .none(clientID: "test-client") + ), + messageData: #"{"jsonrpc":"2.0","method":"ping","id":27}"#.data(using: .utf8)!, + expectedCallCounts: [tokenEndpointURL: 0], + expectedErrorSubstring: "state mismatch" + ) + } + + // MARK: 9 - Access Token Only Via Authorization Header + + func configureOAuthAccessTokenOnlyViaAuthorizationHeader() -> OAuthScenarioContext { + let testEndpoint = URL(string: "https://localhost:8080/test?foo=bar")! + let resourceMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-protected-resource/test")! + let asMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-authorization-server/auth")! + let tokenEndpointURL = URL(string: "https://localhost:8080/oauth/token")! + let finalResponseData = MockResponses.jsonRPCResult(id: 12) + + requestHandler = MockResponses.routingHandler(routes: [ + testEndpoint: { request in + if request.value(forHTTPHeaderField: "Authorization") == nil { + return try await MockResponses.bearerChallenge( + resourceMetadataURL: resourceMetadataURL, scope: "files:read" + )(request) + } + #expect( + request.value(forHTTPHeaderField: "Authorization") + == "Bearer access-token-xyz") + #expect(request.url?.query == "foo=bar") + #expect( + !(request.url?.absoluteString.contains("access_token=") ?? false)) + let requestBody = + String(data: readRequestBody(request) ?? Data(), encoding: .utf8) ?? "" + #expect(!requestBody.contains("access_token=")) + #expect( + request.value(forHTTPHeaderField: "Content-Type") + == "application/json") + return try await MockResponses.jsonSuccess(body: finalResponseData)( + request) + }, + resourceMetadataURL: MockResponses.resourceMetadata( + authorizationServers: ["https://localhost:8080/auth"], + scopesSupported: ["files:read"] + ), + asMetadataURL: MockResponses.asMetadata( + issuer: "https://localhost:8080/auth", + tokenEndpoint: "https://localhost:8080/oauth/token" + ), + tokenEndpointURL: MockResponses.tokenSuccess( + accessToken: "access-token-xyz"), + ]) + + return OAuthScenarioContext( + testEndpoint: testEndpoint, + oauthConfiguration: .init(authentication: .none(clientID: "test-client")), + messageData: #"{"jsonrpc":"2.0","method":"ping","id":12}"#.data(using: .utf8)!, + expectedResponseData: finalResponseData, + expectedCallCounts: [ + testEndpoint: 2, resourceMetadataURL: 1, asMetadataURL: 1, + tokenEndpointURL: 1, + ] + ) + } + + // MARK: 10 - Authorization Header For Every Request In Session + + func configureOAuthAuthorizationHeaderForEveryRequestInSession() + -> OAuthScenarioContext + { + let testEndpoint = URL(string: "https://localhost:8080/session-auth")! + let resourceMetadataURL = URL( + string: + "https://localhost:8080/.well-known/oauth-protected-resource/session-auth")! + let asMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-authorization-server/auth")! + let tokenEndpointURL = URL(string: "https://localhost:8080/oauth/token")! + let firstResponseData = #"{"jsonrpc":"2.0","result":{"ok":true},"id":31}"#.data( + using: .utf8)! + let secondResponseData = #"{"jsonrpc":"2.0","result":{"ok":true},"id":32}"#.data( + using: .utf8)! + + requestHandler = MockResponses.routingHandler(routes: [ + testEndpoint: { + [firstResponseData, secondResponseData] request in + if request.value(forHTTPHeaderField: "Authorization") == nil { + return try await MockResponses.bearerChallenge( + resourceMetadataURL: resourceMetadataURL, scope: "files:read" + )(request) + } + #expect( + request.value(forHTTPHeaderField: "Authorization") + == "Bearer session-access-token") + #expect( + !(request.url?.absoluteString.contains("access_token=") ?? false)) + let requestBody = + String(data: readRequestBody(request) ?? Data(), encoding: .utf8) ?? "" + #expect(!requestBody.contains("access_token=")) + + let responseBody: Data + if requestBody.contains("\"id\":31") { + responseBody = firstResponseData + } else if requestBody.contains("\"id\":32") { + responseBody = secondResponseData + } else { + throw MockResponses.mockError( + "Unexpected JSON-RPC body: \(requestBody)") + } + + let response = HTTPURLResponse( + url: request.url!, statusCode: 200, httpVersion: "HTTP/1.1", + headerFields: [ + "Content-Type": "application/json", + "MCP-Session-Id": "session-123", + ])! + return (response, responseBody) + }, + resourceMetadataURL: MockResponses.resourceMetadata( + authorizationServers: ["https://localhost:8080/auth"], + scopesSupported: ["files:read"] + ), + asMetadataURL: MockResponses.asMetadata( + issuer: "https://localhost:8080/auth", + tokenEndpoint: "https://localhost:8080/oauth/token" + ), + tokenEndpointURL: { request in + let body = + String(data: readRequestBody(request) ?? Data(), encoding: .utf8) ?? "" + #expect(body.contains("grant_type=client_credentials")) + #expect( + body.contains( + "resource=https%3A%2F%2Flocalhost%3A8080%2Fsession-auth")) + #expect(body.contains("scope=files%3Aread")) + return try await MockResponses.tokenSuccess( + accessToken: "session-access-token")(request) + }, + ]) + + return OAuthScenarioContext( + testEndpoint: testEndpoint, + oauthConfiguration: .init(authentication: .none(clientID: "test-client")), + messageData: #"{"jsonrpc":"2.0","method":"ping","id":31}"#.data(using: .utf8)!, + expectedResponseData: firstResponseData, + expectedCallCounts: [ + testEndpoint: 3, resourceMetadataURL: 1, asMetadataURL: 1, + tokenEndpointURL: 1, + ], + secondMessageData: #"{"jsonrpc":"2.0","method":"ping","id":32}"#.data( + using: .utf8)!, + secondExpectedResponseData: secondResponseData + ) + } + + // MARK: 11 - Streaming GET Uses Authorization Header Only + + func configureOAuthStreamingGETUsesAuthorizationHeaderOnly() -> OAuthScenarioContext { + let testEndpoint = URL(string: "https://localhost:8080/stream-auth?foo=bar")! + let resourceMetadataURL = URL( + string: + "https://localhost:8080/.well-known/oauth-protected-resource/stream-auth")! + let asMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-authorization-server/auth")! + let tokenEndpointURL = URL(string: "https://localhost:8080/oauth/token")! + let sseEventData = "id: evt-1\ndata: {\"stream\":\"ok\"}\n\n".data(using: .utf8)! + + requestHandler = MockResponses.routingHandler(routes: [ + testEndpoint: { [sseEventData] request in + if request.httpMethod == "GET" { + #expect( + request.value(forHTTPHeaderField: "Authorization") + == "Bearer stream-access-token") + #expect(request.url?.query == "foo=bar") + #expect( + !(request.url?.absoluteString.contains("access_token=") ?? false)) + #expect( + request.value(forHTTPHeaderField: "MCP-Session-Id") + == "stream-session-id") + let response = HTTPURLResponse( + url: request.url!, statusCode: 200, httpVersion: "HTTP/1.1", + headerFields: ["Content-Type": "text/event-stream"])! + return (response, sseEventData) + } + + if request.value(forHTTPHeaderField: "Authorization") == nil { + return try await MockResponses.bearerChallenge( + resourceMetadataURL: resourceMetadataURL, scope: "files:read" + )(request) + } + + #expect( + request.value(forHTTPHeaderField: "Authorization") + == "Bearer stream-access-token") + #expect( + !(request.url?.absoluteString.contains("access_token=") ?? false)) + let requestBody = + String(data: readRequestBody(request) ?? Data(), encoding: .utf8) ?? "" + #expect(!requestBody.contains("access_token=")) + + let response = HTTPURLResponse( + url: request.url!, statusCode: 200, httpVersion: "HTTP/1.1", + headerFields: [ + "Content-Type": "text/plain", + "MCP-Session-Id": "stream-session-id", + ])! + return (response, Data()) + }, + resourceMetadataURL: MockResponses.resourceMetadata( + authorizationServers: ["https://localhost:8080/auth"], + scopesSupported: ["files:read"] + ), + asMetadataURL: MockResponses.asMetadata( + issuer: "https://localhost:8080/auth", + tokenEndpoint: "https://localhost:8080/oauth/token" + ), + tokenEndpointURL: { request in + let body = + String(data: readRequestBody(request) ?? Data(), encoding: .utf8) ?? "" + #expect(body.contains("grant_type=client_credentials")) + #expect( + body.contains( + "resource=https%3A%2F%2Flocalhost%3A8080%2Fstream-auth")) + #expect(body.contains("scope=files%3Aread")) + return try await MockResponses.tokenSuccess( + accessToken: "stream-access-token")(request) + }, + ]) + + let expectedEventPayload = #"{"stream":"ok"}"#.data(using: .utf8)! + return OAuthScenarioContext( + testEndpoint: testEndpoint, + oauthConfiguration: .init(authentication: .none(clientID: "test-client")), + messageData: #"{"jsonrpc":"2.0","method":"ping","id":41}"#.data( + using: .utf8)!, + expectedResponseData: expectedEventPayload, + streaming: true, + sseInitializationTimeout: 1 + ) + } + + // MARK: 12 - Rejects Non-Bearer Token Type + + func configureOAuthRejectsNonBearerTokenType() -> OAuthScenarioContext { + let testEndpoint = URL( + string: "https://localhost:8080/non-bearer-token-type")! + let resourceMetadataURL = URL( + string: + "https://localhost:8080/.well-known/oauth-protected-resource/non-bearer-token-type" + )! + let asMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-authorization-server/auth")! + let tokenEndpointURL = URL(string: "https://localhost:8080/oauth/token")! + + requestHandler = MockResponses.routingHandler(routes: [ + testEndpoint: { request in + return try await MockResponses.bearerChallenge( + resourceMetadataURL: resourceMetadataURL, scope: "files:read" + )(request) + }, + resourceMetadataURL: MockResponses.resourceMetadata( + authorizationServers: ["https://localhost:8080/auth"], + scopesSupported: ["files:read"] + ), + asMetadataURL: MockResponses.asMetadata( + issuer: "https://localhost:8080/auth", + tokenEndpoint: "https://localhost:8080/oauth/token" + ), + tokenEndpointURL: MockResponses.tokenResponse( + accessToken: "non-bearer-token", tokenType: "DPoP"), + ]) + + return OAuthScenarioContext( + testEndpoint: testEndpoint, + oauthConfiguration: .init(authentication: .none(clientID: "test-client")), + messageData: #"{"jsonrpc":"2.0","method":"ping","id":51}"#.data( + using: .utf8)!, + expectedErrorSubstring: "Token response is invalid" + ) + } + + // MARK: 13 - Token Endpoint Failure Redacts Response Body + + func configureOAuthTokenEndpointFailureRedactsResponseBody() -> OAuthScenarioContext { + let testEndpoint = URL( + string: "https://localhost:8080/token-error-redaction")! + let resourceMetadataURL = URL( + string: + "https://localhost:8080/.well-known/oauth-protected-resource/token-error-redaction" + )! + let asMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-authorization-server/auth")! + let tokenEndpointURL = URL(string: "https://localhost:8080/oauth/token")! + + requestHandler = MockResponses.routingHandler(routes: [ + testEndpoint: { request in + return try await MockResponses.bearerChallenge( + resourceMetadataURL: resourceMetadataURL, scope: "files:read" + )(request) + }, + resourceMetadataURL: MockResponses.resourceMetadata( + authorizationServers: ["https://localhost:8080/auth"], + scopesSupported: ["files:read"] + ), + asMetadataURL: MockResponses.asMetadata( + issuer: "https://localhost:8080/auth", + tokenEndpoint: "https://localhost:8080/oauth/token" + ), + tokenEndpointURL: MockResponses.tokenError( + error: "invalid_client", + errorDescription: "leaked-secret-value", + extraFields: ["access_token": "should-not-leak"] + ), + ]) + + return OAuthScenarioContext( + testEndpoint: testEndpoint, + oauthConfiguration: .init(authentication: .none(clientID: "test-client")), + messageData: #"{"jsonrpc":"2.0","method":"ping","id":52}"#.data( + using: .utf8)!, + expectedErrorSubstring: "oauth_error: invalid_client", + unexpectedErrorSubstrings: ["leaked-secret-value", "should-not-leak"] + ) + } + + // MARK: 14 - Rejects Non-HTTPS Token Endpoint + + func configureOAuthRejectsNonHTTPSTokenEndpoint() -> OAuthScenarioContext { + let testEndpoint = URL( + string: "https://localhost:8080/non-https-token-endpoint")! + let resourceMetadataURL = URL( + string: + "https://localhost:8080/.well-known/oauth-protected-resource/non-https-token-endpoint" + )! + let asMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-authorization-server/auth")! + + requestHandler = MockResponses.routingHandler(routes: [ + testEndpoint: { request in + return try await MockResponses.bearerChallenge( + resourceMetadataURL: resourceMetadataURL, scope: "files:read" + )(request) + }, + resourceMetadataURL: MockResponses.resourceMetadata( + authorizationServers: ["https://localhost:8080/auth"], + scopesSupported: ["files:read"] + ), + asMetadataURL: MockResponses.asMetadata( + issuer: "https://localhost:8080/auth", + tokenEndpoint: "http://localhost:8080/oauth/token" + ), + ]) + + var oauth = OAuthConfiguration( + authentication: .none(clientID: "test-client")) + oauth.allowLoopbackHTTPAuthorizationServerEndpoints = false + + return OAuthScenarioContext( + testEndpoint: testEndpoint, + oauthConfiguration: oauth, + messageData: #"{"jsonrpc":"2.0","method":"ping","id":53}"#.data( + using: .utf8)!, + expectedErrorSubstring: "Token endpoint must use https" + ) + } + + // MARK: 15 - Allows Loopback HTTP Authorization Server Endpoints + + func configureOAuthAllowsLoopbackHTTPAuthorizationServerEndpoints() + -> OAuthScenarioContext + { + let testEndpoint = URL( + string: "https://localhost:8080/loopback-http-auth-server-enabled")! + let resourceMetadataURL = URL( + string: + "https://localhost:8080/.well-known/oauth-protected-resource/loopback-http-auth-server-enabled" + )! + let asMetadataURL = URL( + string: "http://localhost:8080/.well-known/oauth-authorization-server/auth")! + let tokenEndpointURL = URL(string: "http://localhost:8080/oauth/token")! + let finalResponseData = MockResponses.jsonRPCResult(id: 54) + + requestHandler = MockResponses.routingHandler(routes: [ + testEndpoint: { request in + if request.value(forHTTPHeaderField: "Authorization") != nil { + return try await MockResponses.jsonSuccess(body: finalResponseData)( + request) + } + return try await MockResponses.bearerChallenge( + resourceMetadataURL: resourceMetadataURL, scope: "files:read" + )(request) + }, + resourceMetadataURL: MockResponses.resourceMetadata( + authorizationServers: ["http://localhost:8080/auth"], + scopesSupported: ["files:read"] + ), + asMetadataURL: MockResponses.asMetadata( + issuer: "http://localhost:8080/auth", + tokenEndpoint: "http://localhost:8080/oauth/token" + ), + tokenEndpointURL: { request in + let body = + String(data: readRequestBody(request) ?? Data(), encoding: .utf8) ?? "" + #expect(body.contains("grant_type=client_credentials")) + return try await MockResponses.tokenSuccess( + accessToken: "loopback-http-token")(request) + }, + ]) + + var oauth = OAuthConfiguration( + authentication: .none(clientID: "test-client")) + oauth.allowLoopbackHTTPAuthorizationServerEndpoints = true + + return OAuthScenarioContext( + testEndpoint: testEndpoint, + oauthConfiguration: oauth, + messageData: #"{"jsonrpc":"2.0","method":"ping","id":54}"#.data( + using: .utf8)!, + expectedResponseData: finalResponseData, + expectedCallCounts: [tokenEndpointURL: 1] + ) + } + + // MARK: 16 - Access Token Provider Receives Discovery Context + + func configureOAuthAccessTokenProviderReceivesDiscoveryContext() + -> (OAuthScenarioContext, ProviderTracker) + { + let testEndpoint = URL(string: "https://localhost:8080/test")! + let resourceMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-protected-resource/test")! + let asMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-authorization-server/auth")! + let tokenEndpointURL = URL(string: "https://localhost:8080/oauth/token")! + let finalResponseData = MockResponses.jsonRPCResult(id: 2) + let providerTracker = ProviderTracker() + + requestHandler = MockResponses.routingHandler(routes: [ + testEndpoint: { request in + if request.value(forHTTPHeaderField: "Authorization") == nil { + return try await MockResponses.bearerChallenge( + resourceMetadataURL: resourceMetadataURL, + scope: "files:read files:write" + )(request) + } + #expect( + request.value(forHTTPHeaderField: "Authorization") + == "Bearer provider-access-token") + return try await MockResponses.jsonSuccess(body: finalResponseData)( + request) + }, + resourceMetadataURL: MockResponses.resourceMetadata( + authorizationServers: ["https://localhost:8080/auth"], + scopesSupported: ["files:read", "files:write"] + ), + asMetadataURL: MockResponses.asMetadata( + issuer: "https://localhost:8080/auth", + tokenEndpoint: "https://localhost:8080/oauth/token", + authorizationEndpoint: "https://localhost:8080/oauth/authorize", + codeChallengeMethodsSupported: ["S256"] + ), + tokenEndpointURL: { request in + let response = HTTPURLResponse( + url: request.url!, statusCode: 200, httpVersion: "HTTP/1.1", + headerFields: ["Content-Type": "application/json"])! + return (response, Data()) + }, + ]) + + let oauthConfiguration = OAuthConfiguration( + authentication: .none(clientID: "test-client"), + accessTokenProvider: { [providerTracker] providerContext, _ in + await providerTracker.capture(providerContext) + return "provider-access-token" + } + ) + + let context = OAuthScenarioContext( + testEndpoint: testEndpoint, + oauthConfiguration: oauthConfiguration, + messageData: #"{"jsonrpc":"2.0","method":"ping","id":2}"#.data(using: .utf8)!, + expectedResponseData: finalResponseData, + expectedCallCounts: [ + testEndpoint: 2, resourceMetadataURL: 1, asMetadataURL: 1, + tokenEndpointURL: 0, + ] + ) + return (context, providerTracker) + } + + // MARK: 17 - Discovery Uses Header Resource Metadata + + func configureOAuthDiscoveryUsesHeaderResourceMetadata() -> OAuthScenarioContext { + let testEndpoint = URL(string: "https://localhost:8080/public/mcp")! + let headerMetadataURL = URL( + string: "https://localhost:8080/custom-metadata")! + let fallbackPathMetadataURL = URL( + string: + "https://localhost:8080/.well-known/oauth-protected-resource/public/mcp")! + let fallbackRootMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-protected-resource")! + let asMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-authorization-server/auth")! + let tokenEndpointURL = URL(string: "https://localhost:8080/oauth/token")! + let finalResponseData = MockResponses.jsonRPCResult(id: 3) + + requestHandler = MockResponses.routingHandler(routes: [ + testEndpoint: { request in + if request.value(forHTTPHeaderField: "Authorization") == nil { + return try await MockResponses.bearerChallenge( + resourceMetadataURL: headerMetadataURL, scope: "files:read" + )(request) + } + #expect( + request.value(forHTTPHeaderField: "Authorization") + == "Bearer access-token-123") + return try await MockResponses.jsonSuccess(body: finalResponseData)( + request) + }, + headerMetadataURL: MockResponses.resourceMetadata( + authorizationServers: ["https://localhost:8080/auth"], + scopesSupported: ["files:read"] + ), + fallbackPathMetadataURL: MockResponses.httpError(statusCode: 500), + fallbackRootMetadataURL: MockResponses.httpError(statusCode: 500), + asMetadataURL: MockResponses.asMetadata( + issuer: "https://localhost:8080/auth", + tokenEndpoint: "https://localhost:8080/oauth/token" + ), + tokenEndpointURL: MockResponses.tokenSuccess( + accessToken: "access-token-123"), + ]) + + return OAuthScenarioContext( + testEndpoint: testEndpoint, + oauthConfiguration: .init(authentication: .none(clientID: "test-client")), + messageData: #"{"jsonrpc":"2.0","method":"ping","id":3}"#.data(using: .utf8)!, + expectedResponseData: finalResponseData, + expectedCallCounts: [ + testEndpoint: 2, headerMetadataURL: 1, + fallbackPathMetadataURL: 0, fallbackRootMetadataURL: 0, + ] + ) + } + + // MARK: 18 - Discovery Fallback Well-Known Order + + func configureOAuthDiscoveryFallbackWellKnownOrder() + -> (OAuthScenarioContext, OrderTracker) + { + let testEndpoint = URL(string: "https://localhost:8080/public/mcp")! + let fallbackPathMetadataURL = URL( + string: + "https://localhost:8080/.well-known/oauth-protected-resource/public/mcp")! + let fallbackRootMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-protected-resource")! + let asMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-authorization-server/auth")! + let tokenEndpointURL = URL(string: "https://localhost:8080/oauth/token")! + let finalResponseData = MockResponses.jsonRPCResult(id: 4) + let tracker = OrderTracker() + + requestHandler = MockResponses.routingHandler(routes: [ + testEndpoint: { request in + if request.value(forHTTPHeaderField: "Authorization") == nil { + return try await MockResponses.bearerChallenge( + scope: "files:read")(request) + } + #expect( + request.value(forHTTPHeaderField: "Authorization") + == "Bearer access-token-456") + return try await MockResponses.jsonSuccess(body: finalResponseData)( + request) + }, + fallbackPathMetadataURL: { [tracker] request in + await tracker.append(request.url!) + return try await MockResponses.httpError(statusCode: 404)(request) + }, + fallbackRootMetadataURL: { [tracker] request in + await tracker.append(request.url!) + return try await MockResponses.resourceMetadata( + authorizationServers: ["https://localhost:8080/auth"], + scopesSupported: ["files:read"] + )(request) + }, + asMetadataURL: MockResponses.asMetadata( + issuer: "https://localhost:8080/auth", + tokenEndpoint: "https://localhost:8080/oauth/token" + ), + tokenEndpointURL: MockResponses.tokenSuccess( + accessToken: "access-token-456"), + ]) + + let context = OAuthScenarioContext( + testEndpoint: testEndpoint, + oauthConfiguration: .init(authentication: .none(clientID: "test-client")), + messageData: #"{"jsonrpc":"2.0","method":"ping","id":4}"#.data(using: .utf8)!, + expectedResponseData: finalResponseData + ) + return (context, tracker) + } + + // MARK: 19 - Discovery Fails When Metadata Unavailable + + func configureOAuthDiscoveryFailsWhenMetadataUnavailable() + -> (OAuthScenarioContext, OrderTracker) + { + let testEndpoint = URL(string: "https://localhost:8080/public/mcp")! + let fallbackPathMetadataURL = URL( + string: + "https://localhost:8080/.well-known/oauth-protected-resource/public/mcp")! + let fallbackRootMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-protected-resource")! + let tracker = OrderTracker() + + requestHandler = MockResponses.routingHandler(routes: [ + testEndpoint: { request in + return try await MockResponses.bearerChallenge( + scope: "files:read")(request) + }, + fallbackPathMetadataURL: { [tracker] request in + await tracker.append(request.url!) + return try await MockResponses.httpError(statusCode: 404)(request) + }, + fallbackRootMetadataURL: { [tracker] request in + await tracker.append(request.url!) + return try await MockResponses.httpError(statusCode: 404)(request) + }, + ]) + + let context = OAuthScenarioContext( + testEndpoint: testEndpoint, + oauthConfiguration: .init(authentication: .none(clientID: "test-client")), + messageData: #"{"jsonrpc":"2.0","method":"ping","id":7}"#.data(using: .utf8)!, + expectedErrorSubstring: "metadata" + ) + return (context, tracker) + } + + // MARK: 20 - AS Metadata Discovery Order For Path Issuer + + func configureOAuthASMetadataDiscoveryOrderForPathIssuer() + -> (OAuthScenarioContext, OrderTracker) + { + let testEndpoint = URL(string: "https://localhost:8080/public/mcp")! + let resourceMetadataURL = URL( + string: + "https://localhost:8080/.well-known/oauth-protected-resource/public/mcp")! + let authorizationServer = URL(string: "https://localhost:8080/tenant1")! + let asMetadataOAuthInsertedURL = URL( + string: + "https://localhost:8080/.well-known/oauth-authorization-server/tenant1")! + let asMetadataOIDCInsertedURL = URL( + string: "https://localhost:8080/.well-known/openid-configuration/tenant1")! + let asMetadataOIDCAppendedURL = URL( + string: "https://localhost:8080/tenant1/.well-known/openid-configuration")! + let tokenEndpointURL = URL(string: "https://localhost:8080/oauth/token")! + let finalResponseData = MockResponses.jsonRPCResult(id: 5) + let tracker = OrderTracker() + + requestHandler = MockResponses.routingHandler(routes: [ + testEndpoint: { request in + if request.value(forHTTPHeaderField: "Authorization") == nil { + return try await MockResponses.bearerChallenge( + resourceMetadataURL: resourceMetadataURL, scope: "files:read" + )(request) + } + #expect( + request.value(forHTTPHeaderField: "Authorization") + == "Bearer path-issuer-token") + return try await MockResponses.jsonSuccess(body: finalResponseData)( + request) + }, + resourceMetadataURL: MockResponses.resourceMetadata( + authorizationServers: ["https://localhost:8080/tenant1"], + scopesSupported: ["files:read"] + ), + asMetadataOAuthInsertedURL: { [tracker] request in + await tracker.append(request.url!) + return try await MockResponses.httpError(statusCode: 404)(request) + }, + asMetadataOIDCInsertedURL: { [tracker] request in + await tracker.append(request.url!) + return try await MockResponses.httpError(statusCode: 404)(request) + }, + asMetadataOIDCAppendedURL: { [tracker] request in + await tracker.append(request.url!) + return try await MockResponses.asMetadata( + issuer: "https://localhost:8080/tenant1", + tokenEndpoint: "https://localhost:8080/oauth/token" + )(request) + }, + tokenEndpointURL: { request in + let body = + String(data: readRequestBody(request) ?? Data(), encoding: .utf8) ?? "" + #expect(body.contains("grant_type=client_credentials")) + #expect(body.contains("scope=files%3Aread")) + return try await MockResponses.tokenSuccess( + accessToken: "path-issuer-token")(request) + }, + authorizationServer: { _ in + throw MockResponses.mockError( + "Unexpected direct issuer request") + }, + ]) + + let context = OAuthScenarioContext( + testEndpoint: testEndpoint, + oauthConfiguration: .init(authentication: .none(clientID: "test-client")), + messageData: #"{"jsonrpc":"2.0","method":"ping","id":5}"#.data(using: .utf8)!, + expectedResponseData: finalResponseData + ) + return (context, tracker) + } + + // MARK: 21 - AS Metadata Discovery Order For Root Issuer + + func configureOAuthASMetadataDiscoveryOrderForRootIssuer() + -> (OAuthScenarioContext, OrderTracker) + { + let testEndpoint = URL(string: "https://localhost:8080/public/mcp")! + let resourceMetadataURL = URL( + string: + "https://localhost:8080/.well-known/oauth-protected-resource/public/mcp")! + let authorizationServer = URL(string: "https://localhost:8080")! + let asMetadataOAuthURL = URL( + string: "https://localhost:8080/.well-known/oauth-authorization-server")! + let asMetadataOIDCURL = URL( + string: "https://localhost:8080/.well-known/openid-configuration")! + let tokenEndpointURL = URL(string: "https://localhost:8080/oauth/token")! + let finalResponseData = MockResponses.jsonRPCResult(id: 6) + let tracker = OrderTracker() + + requestHandler = MockResponses.routingHandler(routes: [ + testEndpoint: { request in + if request.value(forHTTPHeaderField: "Authorization") == nil { + return try await MockResponses.bearerChallenge( + resourceMetadataURL: resourceMetadataURL, scope: "files:read" + )(request) + } + #expect( + request.value(forHTTPHeaderField: "Authorization") + == "Bearer root-issuer-token") + return try await MockResponses.jsonSuccess(body: finalResponseData)( + request) + }, + resourceMetadataURL: MockResponses.resourceMetadata( + authorizationServers: ["https://localhost:8080"], + scopesSupported: ["files:read"] + ), + asMetadataOAuthURL: { [tracker] request in + await tracker.append(request.url!) + return try await MockResponses.httpError(statusCode: 404)(request) + }, + asMetadataOIDCURL: { [tracker] request in + await tracker.append(request.url!) + return try await MockResponses.asMetadata( + issuer: "https://localhost:8080", + tokenEndpoint: "https://localhost:8080/oauth/token" + )(request) + }, + tokenEndpointURL: MockResponses.tokenSuccess( + accessToken: "root-issuer-token"), + authorizationServer: { _ in + throw MockResponses.mockError( + "Unexpected direct issuer request") + }, + ]) + + let context = OAuthScenarioContext( + testEndpoint: testEndpoint, + oauthConfiguration: .init(authentication: .none(clientID: "test-client")), + messageData: #"{"jsonrpc":"2.0","method":"ping","id":6}"#.data(using: .utf8)!, + expectedResponseData: finalResponseData + ) + return (context, tracker) + } + + // MARK: 22 - Registration Prefers CIMD When Advertised + + func configureOAuthRegistrationPrefersCIMDWhenAdvertised() -> OAuthScenarioContext { + let testEndpoint = URL(string: "https://localhost:8080/public/mcp")! + let resourceMetadataURL = URL( + string: + "https://localhost:8080/.well-known/oauth-protected-resource/public/mcp")! + let asMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-authorization-server/auth")! + let registrationEndpointURL = URL( + string: "https://localhost:8080/register")! + let tokenEndpointURL = URL(string: "https://localhost:8080/oauth/token")! + let clientMetadataDocumentID = "https://client.example.com/metadata.json" + let finalResponseData = MockResponses.jsonRPCResult(id: 8) + + requestHandler = MockResponses.routingHandler(routes: [ + testEndpoint: { request in + if request.value(forHTTPHeaderField: "Authorization") == nil { + return try await MockResponses.bearerChallenge( + resourceMetadataURL: resourceMetadataURL, scope: "files:read" + )(request) + } + #expect( + request.value(forHTTPHeaderField: "Authorization") + == "Bearer cimd-token") + return try await MockResponses.jsonSuccess(body: finalResponseData)( + request) + }, + resourceMetadataURL: MockResponses.resourceMetadata( + authorizationServers: ["https://localhost:8080/auth"], + scopesSupported: ["files:read"] + ), + asMetadataURL: MockResponses.asMetadata( + issuer: "https://localhost:8080/auth", + tokenEndpoint: "https://localhost:8080/oauth/token", + registrationEndpoint: "https://localhost:8080/register", + clientIDMetadataDocumentSupported: true + ), + registrationEndpointURL: MockResponses.httpError(statusCode: 500), + tokenEndpointURL: { request in + let body = + String(data: readRequestBody(request) ?? Data(), encoding: .utf8) ?? "" + #expect(body.contains("grant_type=client_credentials")) + #expect( + body.contains( + "client_id=https%3A%2F%2Fclient.example.com%2Fmetadata.json")) + return try await MockResponses.tokenSuccess( + accessToken: "cimd-token")(request) + }, + ]) + + return OAuthScenarioContext( + testEndpoint: testEndpoint, + oauthConfiguration: .init( + authentication: .none(clientID: clientMetadataDocumentID)), + messageData: #"{"jsonrpc":"2.0","method":"ping","id":8}"#.data(using: .utf8)!, + expectedResponseData: finalResponseData, + expectedCallCounts: [registrationEndpointURL: 0] + ) + } + + // MARK: 23 - Pre-Registration Uses Static Credentials + + func configureOAuthPreRegistrationUsesStaticCredentials() -> OAuthScenarioContext { + let testEndpoint = URL(string: "https://localhost:8080/public/mcp")! + let resourceMetadataURL = URL( + string: + "https://localhost:8080/.well-known/oauth-protected-resource/public/mcp")! + let asMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-authorization-server/auth")! + let registrationEndpointURL = URL( + string: "https://localhost:8080/register")! + let tokenEndpointURL = URL(string: "https://localhost:8080/oauth/token")! + let finalResponseData = MockResponses.jsonRPCResult(id: 13) + + let expectedClientID = "pre-registered-client" + let expectedClientSecret = "pre-registered-secret" + let expectedBasic = Data( + "\(expectedClientID):\(expectedClientSecret)".utf8 + ).base64EncodedString() + + requestHandler = MockResponses.routingHandler(routes: [ + testEndpoint: { request in + if request.value(forHTTPHeaderField: "Authorization") == nil { + return try await MockResponses.bearerChallenge( + resourceMetadataURL: resourceMetadataURL, scope: "files:read" + )(request) + } + #expect( + request.value(forHTTPHeaderField: "Authorization") + == "Bearer pre-registered-token") + return try await MockResponses.jsonSuccess(body: finalResponseData)( + request) + }, + resourceMetadataURL: MockResponses.resourceMetadata( + authorizationServers: ["https://localhost:8080/auth"], + scopesSupported: ["files:read"] + ), + asMetadataURL: MockResponses.asMetadata( + issuer: "https://localhost:8080/auth", + tokenEndpoint: "https://localhost:8080/oauth/token", + registrationEndpoint: "https://localhost:8080/register" + ), + registrationEndpointURL: MockResponses.httpError(statusCode: 500), + tokenEndpointURL: { [expectedBasic] request in + #expect(request.httpMethod == "POST") + #expect( + request.value(forHTTPHeaderField: "Authorization") + == "Basic \(expectedBasic)") + let body = + String(data: readRequestBody(request) ?? Data(), encoding: .utf8) ?? "" + #expect(body.contains("grant_type=client_credentials")) + #expect( + body.contains( + "resource=https%3A%2F%2Flocalhost%3A8080%2Fpublic%2Fmcp")) + #expect(body.contains("scope=files%3Aread")) + #expect(!body.contains("client_id=")) + #expect(!body.contains("client_secret=")) + return try await MockResponses.tokenSuccess( + accessToken: "pre-registered-token")(request) + }, + ]) + + return OAuthScenarioContext( + testEndpoint: testEndpoint, + oauthConfiguration: .init( + authentication: .clientSecretBasic( + clientID: expectedClientID, + clientSecret: expectedClientSecret + )), + messageData: #"{"jsonrpc":"2.0","method":"ping","id":13}"#.data( + using: .utf8)!, + expectedResponseData: finalResponseData, + expectedCallCounts: [registrationEndpointURL: 0] + ) + } + + // MARK: 24 - Registration Falls Back to Dynamic Registration (CIMD Not Advertised) + + func configureOAuthRegistrationFallsBackToDynamicRegistrationCIMDNotAdvertised() + -> OAuthScenarioContext + { + let testEndpoint = URL(string: "https://localhost:8080/public/mcp")! + let resourceMetadataURL = URL( + string: + "https://localhost:8080/.well-known/oauth-protected-resource/public/mcp")! + let asMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-authorization-server/auth")! + let registrationEndpointURL = URL( + string: "https://localhost:8080/register")! + let tokenEndpointURL = URL(string: "https://localhost:8080/oauth/token")! + let clientMetadataDocumentID = "https://client.example.com/metadata.json" + let finalResponseData = MockResponses.jsonRPCResult(id: 9) + + requestHandler = MockResponses.routingHandler(routes: [ + testEndpoint: { request in + if request.value(forHTTPHeaderField: "Authorization") == nil { + return try await MockResponses.bearerChallenge( + resourceMetadataURL: resourceMetadataURL, scope: "files:read" + )(request) + } + #expect( + request.value(forHTTPHeaderField: "Authorization") + == "Bearer dynamic-registration-token") + return try await MockResponses.jsonSuccess(body: finalResponseData)( + request) + }, + resourceMetadataURL: MockResponses.resourceMetadata( + authorizationServers: ["https://localhost:8080/auth"], + scopesSupported: ["files:read"] + ), + asMetadataURL: MockResponses.asMetadata( + issuer: "https://localhost:8080/auth", + tokenEndpoint: "https://localhost:8080/oauth/token", + registrationEndpoint: "https://localhost:8080/register", + clientIDMetadataDocumentSupported: false + ), + registrationEndpointURL: { request in + let body = + String(data: readRequestBody(request) ?? Data(), encoding: .utf8) ?? "" + #expect(body.contains("\"grant_types\":[\"client_credentials\"]")) + #expect(body.contains("\"token_endpoint_auth_method\":\"none\"")) + #expect(!body.contains("\"response_types\"")) + #expect(!body.contains("redirect_uris")) + #expect(!body.contains("authorization_code")) + return try await MockResponses.registrationSuccess( + clientID: "registered-client")(request) + }, + tokenEndpointURL: { request in + let body = + String(data: readRequestBody(request) ?? Data(), encoding: .utf8) ?? "" + #expect(body.contains("grant_type=client_credentials")) + #expect(body.contains("client_id=registered-client")) + #expect( + !body.contains( + "client_id=https%3A%2F%2Fclient.example.com%2Fmetadata.json")) + return try await MockResponses.tokenSuccess( + accessToken: "dynamic-registration-token")(request) + }, + ]) + + return OAuthScenarioContext( + testEndpoint: testEndpoint, + oauthConfiguration: .init( + authentication: .none(clientID: clientMetadataDocumentID)), + messageData: #"{"jsonrpc":"2.0","method":"ping","id":9}"#.data(using: .utf8)!, + expectedResponseData: finalResponseData, + expectedCallCounts: [registrationEndpointURL: 1] + ) + } + + // MARK: 25 - Registration Falls Back to Dynamic Registration (CIMD Capability Missing) + + func configureOAuthRegistrationFallsBackToDynamicRegistrationCIMDCapabilityMissing() + -> OAuthScenarioContext + { + let testEndpoint = URL(string: "https://localhost:8080/public/mcp")! + let resourceMetadataURL = URL( + string: + "https://localhost:8080/.well-known/oauth-protected-resource/public/mcp")! + let asMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-authorization-server/auth")! + let registrationEndpointURL = URL( + string: "https://localhost:8080/register")! + let tokenEndpointURL = URL(string: "https://localhost:8080/oauth/token")! + let clientMetadataDocumentID = "https://client.example.com/metadata.json" + let finalResponseData = MockResponses.jsonRPCResult(id: 14) + + requestHandler = MockResponses.routingHandler(routes: [ + testEndpoint: { request in + if request.value(forHTTPHeaderField: "Authorization") == nil { + return try await MockResponses.bearerChallenge( + resourceMetadataURL: resourceMetadataURL, scope: "files:read" + )(request) + } + #expect( + request.value(forHTTPHeaderField: "Authorization") + == "Bearer dynamic-registration-token-missing-capability") + return try await MockResponses.jsonSuccess(body: finalResponseData)( + request) + }, + resourceMetadataURL: MockResponses.resourceMetadata( + authorizationServers: ["https://localhost:8080/auth"], + scopesSupported: ["files:read"] + ), + asMetadataURL: MockResponses.asMetadata( + issuer: "https://localhost:8080/auth", + tokenEndpoint: "https://localhost:8080/oauth/token", + registrationEndpoint: "https://localhost:8080/register" + ), + registrationEndpointURL: { request in + let body = + String(data: readRequestBody(request) ?? Data(), encoding: .utf8) ?? "" + #expect(body.contains("\"grant_types\":[\"client_credentials\"]")) + #expect(!body.contains("redirect_uris")) + return try await MockResponses.registrationSuccess( + clientID: "registered-client-2")(request) + }, + tokenEndpointURL: { request in + let body = + String(data: readRequestBody(request) ?? Data(), encoding: .utf8) ?? "" + #expect(body.contains("grant_type=client_credentials")) + #expect(body.contains("client_id=registered-client-2")) + return try await MockResponses.tokenSuccess( + accessToken: "dynamic-registration-token-missing-capability")( + request) + }, + ]) + + return OAuthScenarioContext( + testEndpoint: testEndpoint, + oauthConfiguration: .init( + authentication: .none(clientID: clientMetadataDocumentID)), + messageData: #"{"jsonrpc":"2.0","method":"ping","id":14}"#.data( + using: .utf8)!, + expectedResponseData: finalResponseData, + expectedCallCounts: [registrationEndpointURL: 1] + ) + } + + // MARK: 26 - Registration Missing Mechanism Returns Actionable Error + + func configureOAuthRegistrationMissingMechanismReturnsActionableError() + -> OAuthScenarioContext + { + let testEndpoint = URL(string: "https://localhost:8080/public/mcp")! + let resourceMetadataURL = URL( + string: + "https://localhost:8080/.well-known/oauth-protected-resource/public/mcp")! + let asMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-authorization-server/auth")! + let tokenEndpointURL = URL(string: "https://localhost:8080/oauth/token")! + let clientMetadataDocumentID = "https://client.example.com/metadata.json" + + requestHandler = MockResponses.routingHandler(routes: [ + testEndpoint: { request in + return try await MockResponses.bearerChallenge( + resourceMetadataURL: resourceMetadataURL, scope: "files:read" + )(request) + }, + resourceMetadataURL: MockResponses.resourceMetadata( + authorizationServers: ["https://localhost:8080/auth"], + scopesSupported: ["files:read"] + ), + asMetadataURL: MockResponses.asMetadata( + issuer: "https://localhost:8080/auth", + tokenEndpoint: "https://localhost:8080/oauth/token", + clientIDMetadataDocumentSupported: false + ), + tokenEndpointURL: MockResponses.httpError(statusCode: 500), + ]) + + return OAuthScenarioContext( + testEndpoint: testEndpoint, + oauthConfiguration: .init( + authentication: .none(clientID: clientMetadataDocumentID)), + messageData: #"{"jsonrpc":"2.0","method":"ping","id":10}"#.data( + using: .utf8)!, + expectedErrorSubstring: + "Authorization server does not support Client ID Metadata Documents" + ) + } + + // MARK: 27 - CIMD Rejects Non-HTTPS Client ID URL + + func configureOAuthCIMDRejectsNonHTTPSClientIDURL() -> OAuthScenarioContext { + let testEndpoint = URL(string: "https://localhost:8080/public/mcp")! + let resourceMetadataURL = URL( + string: + "https://localhost:8080/.well-known/oauth-protected-resource/public/mcp")! + let asMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-authorization-server/auth")! + let invalidClientID = "http://client.example.com/metadata.json" + + requestHandler = MockResponses.routingHandler(routes: [ + testEndpoint: { request in + return try await MockResponses.bearerChallenge( + resourceMetadataURL: resourceMetadataURL, scope: "files:read" + )(request) + }, + resourceMetadataURL: MockResponses.resourceMetadata( + authorizationServers: ["https://localhost:8080/auth"], + scopesSupported: ["files:read"] + ), + asMetadataURL: MockResponses.asMetadata( + issuer: "https://localhost:8080/auth", + tokenEndpoint: "https://localhost:8080/oauth/token", + clientIDMetadataDocumentSupported: true + ), + ]) + + return OAuthScenarioContext( + testEndpoint: testEndpoint, + oauthConfiguration: .init( + authentication: .none(clientID: invalidClientID)), + messageData: #"{"jsonrpc":"2.0","method":"ping","id":11}"#.data( + using: .utf8)!, + expectedErrorSubstring: + "Client ID metadata document URL must use https and include a path" + ) + } + + // MARK: 28 - Rejects Insecure MCP Endpoint URL + + func configureOAuthRejectsInsecureMCPEndpointURL() -> OAuthScenarioContext { + let insecureEndpoint = URL(string: "http://example.com/public/mcp")! + + return OAuthScenarioContext( + testEndpoint: insecureEndpoint, + oauthConfiguration: .init(authentication: .none(clientID: "test-client")), + messageData: #"{"jsonrpc":"2.0","method":"ping","id":12}"#.data( + using: .utf8)!, + expectedErrorSubstring: + "MCP endpoint must use https or loopback http" + ) + } + + // MARK: 30 - PRM Cache Invalidated When resource_metadata URL Changes + + func configureOAuthPRMCacheInvalidatedOnResourceMetadataURLChange() + -> OAuthScenarioContext + { + let testEndpoint = URL(string: "https://localhost:8080/cache-invalidation")! + let resourceMetadataURL_A = URL( + string: + "https://localhost:8080/.well-known/oauth-protected-resource/cache-a")! + let resourceMetadataURL_B = URL( + string: + "https://localhost:8080/.well-known/oauth-protected-resource/cache-b")! + let asMetadataURL_A = URL( + string: + "https://localhost:8080/.well-known/oauth-authorization-server/auth-a")! + let asMetadataURL_B = URL( + string: + "https://localhost:8080/.well-known/oauth-authorization-server/auth-b")! + let tokenEndpointURL_A = URL( + string: "https://localhost:8080/oauth/token-a")! + let tokenEndpointURL_B = URL( + string: "https://localhost:8080/oauth/token-b")! + let firstResponseData = MockResponses.jsonRPCResult(id: 71) + let secondResponseData = MockResponses.jsonRPCResult(id: 72) + + requestHandler = MockResponses.routingHandler(routes: [ + testEndpoint: { request in + let authHeader = request.value(forHTTPHeaderField: "Authorization") + let bodyStr = + String(data: readRequestBody(request) ?? Data(), encoding: .utf8) ?? "" + if authHeader == nil { + return try await MockResponses.bearerChallenge( + resourceMetadataURL: resourceMetadataURL_A, scope: "files:read" + )(request) + } else if authHeader == "Bearer token-a" { + if bodyStr.contains("\"id\":71") { + return try await MockResponses.jsonSuccess( + body: firstResponseData)(request) + } else { + // Second message with old token → return 401 with new URL_B + return try await MockResponses.bearerChallenge( + resourceMetadataURL: resourceMetadataURL_B, scope: "files:read" + )(request) + } + } else if authHeader == "Bearer token-b" { + return try await MockResponses.jsonSuccess( + body: secondResponseData)(request) + } + throw MockResponses.mockError( + "Unexpected Authorization header: \(String(describing: authHeader))") + }, + resourceMetadataURL_A: MockResponses.resourceMetadata( + authorizationServers: ["https://localhost:8080/auth-a"], + scopesSupported: ["files:read"] + ), + resourceMetadataURL_B: MockResponses.resourceMetadata( + authorizationServers: ["https://localhost:8080/auth-b"], + scopesSupported: ["files:read"] + ), + asMetadataURL_A: MockResponses.asMetadata( + issuer: "https://localhost:8080/auth-a", + tokenEndpoint: "https://localhost:8080/oauth/token-a" + ), + asMetadataURL_B: MockResponses.asMetadata( + issuer: "https://localhost:8080/auth-b", + tokenEndpoint: "https://localhost:8080/oauth/token-b" + ), + tokenEndpointURL_A: MockResponses.tokenSuccess(accessToken: "token-a"), + tokenEndpointURL_B: MockResponses.tokenSuccess(accessToken: "token-b"), + ]) + + return OAuthScenarioContext( + testEndpoint: testEndpoint, + oauthConfiguration: .init(authentication: .none(clientID: "test-client")), + messageData: #"{"jsonrpc":"2.0","method":"ping","id":71}"#.data(using: .utf8)!, + expectedResponseData: firstResponseData, + expectedCallCounts: [ + resourceMetadataURL_A: 1, + resourceMetadataURL_B: 1, + tokenEndpointURL_A: 1, + tokenEndpointURL_B: 1, + ], + secondMessageData: #"{"jsonrpc":"2.0","method":"ping","id":72}"#.data( + using: .utf8)!, + secondExpectedResponseData: secondResponseData + ) + } + + // MARK: 29 - Rejects Non-Loopback HTTP Redirect URI + + func configureOAuthRejectsNonLoopbackHTTPRedirectURI() -> OAuthScenarioContext { + let testEndpoint = URL(string: "https://localhost:8080/public/mcp")! + let resourceMetadataURL = URL( + string: + "https://localhost:8080/.well-known/oauth-protected-resource/public/mcp")! + let asMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-authorization-server/auth")! + + requestHandler = MockResponses.routingHandler(routes: [ + testEndpoint: { request in + return try await MockResponses.bearerChallenge( + resourceMetadataURL: resourceMetadataURL, scope: "files:read" + )(request) + }, + resourceMetadataURL: MockResponses.resourceMetadata( + authorizationServers: ["https://localhost:8080/auth"], + scopesSupported: ["files:read"] + ), + asMetadataURL: MockResponses.asMetadata( + issuer: "https://localhost:8080/auth", + tokenEndpoint: "https://localhost:8080/oauth/token", + authorizationEndpoint: "https://localhost:8080/oauth/authorize", + codeChallengeMethodsSupported: ["S256"] + ), + ]) + + return OAuthScenarioContext( + testEndpoint: testEndpoint, + oauthConfiguration: .init( + grantType: .authorizationCode, + authentication: .none(clientID: "test-client"), + authorizationRedirectURI: URL( + string: "http://evil.example.com/callback")! + ), + messageData: #"{"jsonrpc":"2.0","method":"ping","id":13}"#.data( + using: .utf8)!, + expectedErrorSubstring: + "Redirect URI must use https or loopback http and must not include fragments" + ) + } + + // MARK: 31 - Resource Uses PRM Resource Field (Gap 1) + + func configureOAuthResourceUsesPRMResourceField() -> OAuthScenarioContext { + let testEndpoint = URL(string: "https://localhost:8080/mcp/tools")! + let resourceMetadataURL = URL( + string: + "https://localhost:8080/.well-known/oauth-protected-resource/mcp/tools")! + let asMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-authorization-server/auth")! + let tokenEndpointURL = URL(string: "https://localhost:8080/oauth/token")! + let finalResponseData = MockResponses.jsonRPCResult(id: 91) + + requestHandler = MockResponses.routingHandler(routes: [ + testEndpoint: { request in + if request.value(forHTTPHeaderField: "Authorization") == nil { + return try await MockResponses.bearerChallenge( + resourceMetadataURL: resourceMetadataURL)(request) + } + return try await MockResponses.jsonSuccess(body: finalResponseData)(request) + }, + // PRM resource = "https://localhost:8080" (origin only, no path) + resourceMetadataURL: MockResponses.resourceMetadata( + authorizationServers: ["https://localhost:8080/auth"], + resource: "https://localhost:8080" + ), + asMetadataURL: MockResponses.asMetadata( + issuer: "https://localhost:8080/auth", + tokenEndpoint: "https://localhost:8080/oauth/token" + ), + tokenEndpointURL: { request in + let body = + String(data: readRequestBody(request) ?? Data(), encoding: .utf8) ?? "" + // Must use the PRM resource field, not the specific endpoint path + #expect(body.contains("resource=https%3A%2F%2Flocalhost%3A8080")) + #expect(!body.contains("resource=https%3A%2F%2Flocalhost%3A8080%2Fmcp")) + return try await MockResponses.tokenSuccess( + accessToken: "access-token-prm-resource")(request) + }, + ]) + + return OAuthScenarioContext( + testEndpoint: testEndpoint, + oauthConfiguration: .init(authentication: .none(clientID: "test-client")), + messageData: #"{"jsonrpc":"2.0","method":"ping","id":91}"#.data(using: .utf8)!, + expectedResponseData: finalResponseData, + expectedCallCounts: [ + testEndpoint: 2, resourceMetadataURL: 1, asMetadataURL: 1, + tokenEndpointURL: 1, + ] + ) + } + + // MARK: 32 - Second Authorization Server Tried When First Fails (Gap 2) + + func configureOAuthSecondAuthorizationServerTriedWhenFirstFails() -> OAuthScenarioContext + { + let testEndpoint = URL(string: "https://localhost:8080/test-as-fallback")! + let resourceMetadataURL = URL( + string: + "https://localhost:8080/.well-known/oauth-protected-resource/test-as-fallback" + )! + // AS1 discovery URLs — all return 404 + let as1MetadataURL1 = URL( + string: "https://localhost:8080/.well-known/oauth-authorization-server/as1")! + let as1MetadataURL2 = URL( + string: "https://localhost:8080/.well-known/openid-configuration/as1")! + let as1MetadataURL3 = URL( + string: "https://localhost:8080/as1/.well-known/openid-configuration")! + // AS2 first discovery URL — returns valid metadata + let as2MetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-authorization-server/as2")! + let tokenEndpointURL = URL(string: "https://localhost:8080/oauth/token-as2")! + let finalResponseData = MockResponses.jsonRPCResult(id: 92) + + requestHandler = MockResponses.routingHandler(routes: [ + testEndpoint: { request in + if request.value(forHTTPHeaderField: "Authorization") == nil { + return try await MockResponses.bearerChallenge( + resourceMetadataURL: resourceMetadataURL)(request) + } + return try await MockResponses.jsonSuccess(body: finalResponseData)(request) + }, + resourceMetadataURL: MockResponses.resourceMetadata( + authorizationServers: [ + "https://localhost:8080/as1", + "https://localhost:8080/as2", + ] + ), + // AS1: all well-known URLs return 404 + as1MetadataURL1: MockResponses.httpError(statusCode: 404), + as1MetadataURL2: MockResponses.httpError(statusCode: 404), + as1MetadataURL3: MockResponses.httpError(statusCode: 404), + // AS2: first well-known URL returns valid metadata + as2MetadataURL: MockResponses.asMetadata( + issuer: "https://localhost:8080/as2", + tokenEndpoint: "https://localhost:8080/oauth/token-as2" + ), + tokenEndpointURL: MockResponses.tokenSuccess( + accessToken: "access-token-from-as2"), + ]) + + return OAuthScenarioContext( + testEndpoint: testEndpoint, + oauthConfiguration: .init(authentication: .none(clientID: "test-client")), + messageData: #"{"jsonrpc":"2.0","method":"ping","id":92}"#.data(using: .utf8)!, + expectedResponseData: finalResponseData, + expectedCallCounts: [ + testEndpoint: 2, + resourceMetadataURL: 1, + as1MetadataURL1: 1, + as1MetadataURL2: 1, + as1MetadataURL3: 1, + as2MetadataURL: 1, + tokenEndpointURL: 1, + ] + ) + } + + // MARK: 33 - Re-Registration After Client Secret Expiry (Gap 3) + + func configureOAuthReRegistersAfterClientSecretExpiry() -> OAuthScenarioContext { + let testEndpoint = URL(string: "https://localhost:8080/test-secret-expiry")! + let resourceMetadataURL = URL( + string: + "https://localhost:8080/.well-known/oauth-protected-resource/test-secret-expiry" + )! + let asMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-authorization-server/auth")! + let registrationEndpointURL = URL(string: "https://localhost:8080/register")! + let tokenEndpointURL = URL(string: "https://localhost:8080/oauth/token")! + let firstResponseData = MockResponses.jsonRPCResult(id: 93) + let secondResponseData = MockResponses.jsonRPCResult(id: 94) + let regTracker = OrderTracker() + + requestHandler = MockResponses.routingHandler(routes: [ + testEndpoint: { request in + let auth = request.value(forHTTPHeaderField: "Authorization") + let body = + String(data: readRequestBody(request) ?? Data(), encoding: .utf8) ?? "" + if auth == nil { + return try await MockResponses.bearerChallenge( + resourceMetadataURL: resourceMetadataURL)(request) + } + if auth == "Bearer token-v1" && body.contains("\"id\":93") { + return try await MockResponses.jsonSuccess(body: firstResponseData)( + request) + } + // Second request with old token → trigger re-auth + if auth == "Bearer token-v1" && body.contains("\"id\":94") { + return try await MockResponses.bearerChallenge()(request) + } + if auth == "Bearer token-v2" { + return try await MockResponses.jsonSuccess(body: secondResponseData)( + request) + } + throw MockResponses.mockError( + "Unexpected auth: \(String(describing: auth))") + }, + resourceMetadataURL: MockResponses.resourceMetadata( + authorizationServers: ["https://localhost:8080/auth"] + ), + asMetadataURL: MockResponses.asMetadata( + issuer: "https://localhost:8080/auth", + tokenEndpoint: "https://localhost:8080/oauth/token", + registrationEndpoint: "https://localhost:8080/register" + ), + registrationEndpointURL: { [regTracker] request in + await regTracker.append(request.url!) + let count = await regTracker.count() + // Return different client IDs per registration call + let clientID = count == 1 ? "client-v1" : "client-v2" + // client_secret_expires_at = 1 (Unix epoch — always in the past) + let dict: [String: Any] = [ + "client_id": clientID, + "client_secret_expires_at": 1, + ] + let data = try JSONSerialization.data(withJSONObject: dict) + let response = HTTPURLResponse( + url: request.url!, statusCode: 201, httpVersion: "HTTP/1.1", + headerFields: ["Content-Type": "application/json"])! + return (response, data) + }, + tokenEndpointURL: { request in + let body = + String(data: readRequestBody(request) ?? Data(), encoding: .utf8) ?? "" + if body.contains("client_id=client-v1") { + return try await MockResponses.tokenSuccess(accessToken: "token-v1")( + request) + } + if body.contains("client_id=client-v2") { + return try await MockResponses.tokenSuccess(accessToken: "token-v2")( + request) + } + throw MockResponses.mockError("Unexpected client_id in token request") + }, + ]) + + return OAuthScenarioContext( + testEndpoint: testEndpoint, + oauthConfiguration: .init(authentication: .none(clientID: "anon")), + messageData: #"{"jsonrpc":"2.0","method":"ping","id":93}"#.data(using: .utf8)!, + expectedResponseData: firstResponseData, + expectedCallCounts: [ + registrationEndpointURL: 2, + tokenEndpointURL: 2, + ], + secondMessageData: #"{"jsonrpc":"2.0","method":"ping","id":94}"#.data( + using: .utf8)!, + secondExpectedResponseData: secondResponseData + ) + } + + // MARK: 34 - Issuer Mismatch Causes Next Discovery URL Variant to Be Tried (Gap 4) + + func configureOAuthIssuerMismatchTriesNextURLVariant() -> OAuthScenarioContext { + let testEndpoint = URL(string: "https://localhost:8080/test-issuer-check")! + let resourceMetadataURL = URL( + string: + "https://localhost:8080/.well-known/oauth-protected-resource/test-issuer-check" + )! + // Discovery URL 1 returns wrong issuer → skipped + let wrongIssuerURL = URL( + string: "https://localhost:8080/.well-known/oauth-authorization-server/auth")! + // Discovery URL 2 returns correct issuer → used + let correctIssuerURL = URL( + string: "https://localhost:8080/.well-known/openid-configuration/auth")! + let tokenEndpointURL = URL(string: "https://localhost:8080/oauth/token")! + let finalResponseData = MockResponses.jsonRPCResult(id: 95) + + requestHandler = MockResponses.routingHandler(routes: [ + testEndpoint: { request in + if request.value(forHTTPHeaderField: "Authorization") == nil { + return try await MockResponses.bearerChallenge( + resourceMetadataURL: resourceMetadataURL)(request) + } + return try await MockResponses.jsonSuccess(body: finalResponseData)(request) + }, + resourceMetadataURL: MockResponses.resourceMetadata( + authorizationServers: ["https://localhost:8080/auth"] + ), + // URL 1: returns metadata with wrong issuer field → SDK skips it + wrongIssuerURL: MockResponses.asMetadata( + issuer: "https://evil.example.com", + tokenEndpoint: "https://evil.example.com/token" + ), + // URL 2: returns metadata with correct issuer → SDK accepts it + correctIssuerURL: MockResponses.asMetadata( + issuer: "https://localhost:8080/auth", + tokenEndpoint: "https://localhost:8080/oauth/token" + ), + tokenEndpointURL: MockResponses.tokenSuccess( + accessToken: "access-token-after-issuer-check"), + ]) + + return OAuthScenarioContext( + testEndpoint: testEndpoint, + oauthConfiguration: .init(authentication: .none(clientID: "test-client")), + messageData: #"{"jsonrpc":"2.0","method":"ping","id":95}"#.data(using: .utf8)!, + expectedResponseData: finalResponseData, + expectedCallCounts: [ + testEndpoint: 2, + resourceMetadataURL: 1, + wrongIssuerURL: 1, + correctIssuerURL: 1, + tokenEndpointURL: 1, + ] + ) + } + + // MARK: 35 - Proactive Token Refresh Within Window (Gap 5) + + func configureOAuthProactiveTokenRefreshWithinWindow() -> OAuthScenarioContext { + let testEndpoint = URL(string: "https://localhost:8080/test-proactive-refresh")! + let resourceMetadataURL = URL( + string: + "https://localhost:8080/.well-known/oauth-protected-resource/test-proactive-refresh" + )! + let asMetadataURL = URL( + string: "https://localhost:8080/.well-known/oauth-authorization-server/auth")! + let tokenEndpointURL = URL(string: "https://localhost:8080/oauth/token")! + let firstResponseData = MockResponses.jsonRPCResult(id: 96) + let secondResponseData = MockResponses.jsonRPCResult(id: 97) + + requestHandler = MockResponses.routingHandler(routes: [ + testEndpoint: { request in + let auth = request.value(forHTTPHeaderField: "Authorization") + let body = + String(data: readRequestBody(request) ?? Data(), encoding: .utf8) ?? "" + if auth == nil { + return try await MockResponses.bearerChallenge( + resourceMetadataURL: resourceMetadataURL)(request) + } + if body.contains("\"id\":96") { + return try await MockResponses.jsonSuccess(body: firstResponseData)( + request) + } + // Second request: must be sent with the proactively refreshed token + if auth == "Bearer refreshed-token" { + return try await MockResponses.jsonSuccess(body: secondResponseData)( + request) + } + throw MockResponses.mockError( + "Expected refreshed token on second request, got: \(String(describing: auth))" + ) + }, + resourceMetadataURL: MockResponses.resourceMetadata( + authorizationServers: ["https://localhost:8080/auth"] + ), + asMetadataURL: MockResponses.asMetadata( + issuer: "https://localhost:8080/auth", + tokenEndpoint: "https://localhost:8080/oauth/token" + ), + tokenEndpointURL: { request in + let body = + String(data: readRequestBody(request) ?? Data(), encoding: .utf8) ?? "" + if body.contains("grant_type=client_credentials") { + // Initial token: expires in 300s (within 400s proactive window) + return try await MockResponses.tokenSuccess( + accessToken: "initial-token", + expiresIn: 300, + refreshToken: "rt-1" + )(request) + } + if body.contains("grant_type=refresh_token") { + #expect(body.contains("refresh_token=rt-1")) + return try await MockResponses.tokenSuccess( + accessToken: "refreshed-token", + expiresIn: 3600 + )(request) + } + throw MockResponses.mockError("Unexpected grant type in token request") + }, + ]) + + return OAuthScenarioContext( + testEndpoint: testEndpoint, + // proactiveRefreshWindowSeconds = 400 > expires_in = 300 + oauthConfiguration: .init( + authentication: .none(clientID: "test-client"), + proactiveRefreshWindowSeconds: 400 + ), + messageData: #"{"jsonrpc":"2.0","method":"ping","id":96}"#.data(using: .utf8)!, + expectedResponseData: firstResponseData, + expectedCallCounts: [ + testEndpoint: 3, // initial 401 + first send retry + second send + tokenEndpointURL: 2, // initial client_credentials + refresh_token + ], + secondMessageData: #"{"jsonrpc":"2.0","method":"ping","id":97}"#.data( + using: .utf8)!, + secondExpectedResponseData: secondResponseData + ) + } + + } + + #endif // !canImport(FoundationNetworking) + +#endif // swift(>=6.1) diff --git a/Tests/MCPTests/OAuthTokenEndpointClientTests.swift b/Tests/MCPTests/OAuthTokenEndpointClientTests.swift new file mode 100644 index 00000000..be9704c6 --- /dev/null +++ b/Tests/MCPTests/OAuthTokenEndpointClientTests.swift @@ -0,0 +1,194 @@ +@preconcurrency import Foundation +import Testing + +@testable import MCP + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + +#if swift(>=6.1) && !os(Linux) + + @Suite("OAuthTokenEndpointClient", .serialized) + struct OAuthTokenEndpointClientTests { + + let client = OAuthTokenEndpointClient(urlValidator: OAuthURLValidator()) + let tokenEndpoint = URL(string: "https://auth.example.com/token")! + + func successBody( + accessToken: String = "access-token", + tokenType: String = "Bearer", + scope: String? = nil, + refreshToken: String? = nil + ) throws -> Data { + var dict: [String: Any] = [ + "access_token": accessToken, + "token_type": tokenType, + "expires_in": 3600, + ] + if let scope { dict["scope"] = scope } + if let refreshToken { dict["refresh_token"] = refreshToken } + return try JSONSerialization.data(withJSONObject: dict) + } + + // MARK: - Success + + @Test("Returns decoded token response on success") + func testRequestReturnsToken() async throws { + let body = try successBody() + let (session, key) = makeIsolatedSession() + await IsolatedMockURLProtocol.setHandler(key: key) { _ in + let response = HTTPURLResponse( + url: self.tokenEndpoint, statusCode: 200, + httpVersion: nil, headerFields: nil)! + return (response, body) + } + + var params = ["grant_type": "client_credentials"] + let result = try await client.request( + parameters: ¶ms, + endpoint: tokenEndpoint, + authentication: .none(clientID: "client-id"), + session: session + ) + let expected = OAuthTokenResponse( + accessToken: "access-token", tokenType: "Bearer", + expiresIn: 3600, scope: nil, refreshToken: nil) + #expect(result == expected) + } + + @Test("Parses optional scope and refresh token") + func testRequestParsesOptionalFields() async throws { + let body = try successBody(scope: "read write", refreshToken: "refresh-xyz") + let (session, key) = makeIsolatedSession() + await IsolatedMockURLProtocol.setHandler(key: key) { _ in + let response = HTTPURLResponse( + url: self.tokenEndpoint, statusCode: 200, + httpVersion: nil, headerFields: nil)! + return (response, body) + } + + var params = ["grant_type": "client_credentials"] + let result = try await client.request( + parameters: ¶ms, + endpoint: tokenEndpoint, + authentication: .none(clientID: "client-id"), + session: session + ) + let expected = OAuthTokenResponse( + accessToken: "access-token", tokenType: "Bearer", + expiresIn: 3600, scope: "read write", refreshToken: "refresh-xyz") + #expect(result == expected) + } + + // MARK: - Error Responses + + @Test("Throws for non-2xx status") + func testRequestThrowsForNon2xx() async throws { + let errorBody = try JSONSerialization.data(withJSONObject: ["error": "invalid_client"]) + let (session, key) = makeIsolatedSession() + await IsolatedMockURLProtocol.setHandler(key: key) { _ in + let response = HTTPURLResponse( + url: self.tokenEndpoint, statusCode: 401, + httpVersion: nil, headerFields: nil)! + return (response, errorBody) + } + + var params = ["grant_type": "client_credentials"] + await #expect(throws: OAuthAuthorizationError.self) { + try await client.request( + parameters: ¶ms, + endpoint: tokenEndpoint, + authentication: .none(clientID: "client-id"), + session: session + ) + } + } + + @Test("Throws for empty access_token") + func testRequestThrowsForEmptyAccessToken() async throws { + let body = try successBody(accessToken: "") + let (session, key) = makeIsolatedSession() + await IsolatedMockURLProtocol.setHandler(key: key) { _ in + let response = HTTPURLResponse( + url: self.tokenEndpoint, statusCode: 200, + httpVersion: nil, headerFields: nil)! + return (response, body) + } + + var params = ["grant_type": "client_credentials"] + await #expect(throws: OAuthAuthorizationError.self) { + try await client.request( + parameters: ¶ms, + endpoint: tokenEndpoint, + authentication: .none(clientID: "client-id"), + session: session + ) + } + } + + @Test("Throws for non-Bearer token_type") + func testRequestThrowsForNonBearerTokenType() async throws { + let body = try successBody(tokenType: "MAC") + let (session, key) = makeIsolatedSession() + await IsolatedMockURLProtocol.setHandler(key: key) { _ in + let response = HTTPURLResponse( + url: self.tokenEndpoint, statusCode: 200, + httpVersion: nil, headerFields: nil)! + return (response, body) + } + + var params = ["grant_type": "client_credentials"] + await #expect(throws: OAuthAuthorizationError.self) { + try await client.request( + parameters: ¶ms, + endpoint: tokenEndpoint, + authentication: .none(clientID: "client-id"), + session: session + ) + } + } + + // MARK: - Form Encoding + + @Test("Sends parameters as form-encoded POST body") + func testRequestSendsFormEncodedBody() async throws { + let body = try successBody() + let (session, key) = makeIsolatedSession() + actor RequestCapture { var value: URLRequest?; func set(_ r: URLRequest) { value = r } } + let capture = RequestCapture() + await IsolatedMockURLProtocol.setHandler(key: key) { request in + await capture.set(request) + let response = HTTPURLResponse( + url: self.tokenEndpoint, statusCode: 200, + httpVersion: nil, headerFields: nil)! + return (response, body) + } + + var params = ["grant_type": "client_credentials", "resource": "https://api.example.com"] + _ = try await client.request( + parameters: ¶ms, + endpoint: tokenEndpoint, + authentication: .none(clientID: "client-id"), + session: session + ) + + let capturedRequest = await capture.value + let bodyData: Data = { + if let data = capturedRequest?.httpBody { return data } + guard let stream = capturedRequest?.httpBodyStream else { return Data() } + stream.open(); defer { stream.close() } + var data = Data() + let buf = UnsafeMutablePointer.allocate(capacity: 4096) + defer { buf.deallocate() } + while stream.hasBytesAvailable { data.append(buf, count: stream.read(buf, maxLength: 4096)) } + return data + }() + let bodyString = String(data: bodyData, encoding: .utf8) ?? "" + #expect(bodyString.contains("grant_type=client_credentials")) + #expect(bodyString.contains("resource=")) + #expect(capturedRequest?.httpMethod == "POST") + } + } + +#endif diff --git a/Tests/MCPTests/OAuthURLValidatorTests.swift b/Tests/MCPTests/OAuthURLValidatorTests.swift new file mode 100644 index 00000000..0071c042 --- /dev/null +++ b/Tests/MCPTests/OAuthURLValidatorTests.swift @@ -0,0 +1,141 @@ +import Foundation +import Testing + +@testable import MCP + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + +@Suite("OAuthURLValidator") +struct OAuthURLValidatorTests { + + // MARK: - validateHTTPSOrLoopback + + @Test("Accepts HTTPS URL") + func testValidateHTTPSOrLoopbackAcceptsHTTPS() throws { + try OAuthURLValidator().validateHTTPSOrLoopback( + URL(string: "https://example.com/mcp")!, context: "test") + } + + @Test("Accepts loopback HTTP") + func testValidateHTTPSOrLoopbackAcceptsLoopback() throws { + try OAuthURLValidator().validateHTTPSOrLoopback( + URL(string: "http://localhost:8080/mcp")!, context: "test") + try OAuthURLValidator().validateHTTPSOrLoopback( + URL(string: "http://127.0.0.1:9000/mcp")!, context: "test") + } + + @Test("Rejects remote HTTP") + func testValidateHTTPSOrLoopbackRejectsRemoteHTTP() { + #expect(throws: OAuthAuthorizationError.self) { + try OAuthURLValidator().validateHTTPSOrLoopback( + URL(string: "http://example.com/mcp")!, context: "test") + } + } + + @Test("Rejects URL with fragment") + func testValidateHTTPSOrLoopbackRejectsFragment() { + #expect(throws: OAuthAuthorizationError.self) { + try OAuthURLValidator().validateHTTPSOrLoopback( + URL(string: "https://example.com/mcp#frag")!, context: "test") + } + } + + // MARK: - validateAuthorizationServer + + @Test("Accepts HTTPS authorization server") + func testValidateAuthorizationServerAcceptsHTTPS() throws { + try OAuthURLValidator().validateAuthorizationServer( + URL(string: "https://auth.example.com")!, context: "test") + } + + @Test("Rejects HTTP authorization server by default") + func testValidateAuthorizationServerRejectsHTTP() { + #expect(throws: OAuthAuthorizationError.self) { + try OAuthURLValidator().validateAuthorizationServer( + URL(string: "http://auth.example.com")!, context: "test") + } + } + + @Test("Accepts loopback HTTP when flag is set") + func testValidateAuthorizationServerAcceptsLoopbackWhenAllowed() throws { + let v = OAuthURLValidator(allowLoopbackHTTPForAuthorizationServer: true) + try v.validateAuthorizationServer( + URL(string: "http://localhost:8080")!, context: "test") + } + + @Test("Rejects loopback HTTP when flag is not set") + func testValidateAuthorizationServerRejectsLoopbackWhenNotAllowed() { + #expect(throws: OAuthAuthorizationError.self) { + try OAuthURLValidator().validateAuthorizationServer( + URL(string: "http://localhost:8080")!, context: "test") + } + } + + // MARK: - validateRedirectURI + + @Test("Accepts HTTPS redirect URI") + func testValidateRedirectURIAcceptsHTTPS() throws { + try OAuthURLValidator().validateRedirectURI( + URL(string: "https://app.example.com/callback")!) + } + + @Test("Accepts loopback HTTP redirect URI") + func testValidateRedirectURIAcceptsLoopback() throws { + try OAuthURLValidator().validateRedirectURI( + URL(string: "http://localhost:8080/callback")!) + } + + @Test("Rejects remote HTTP redirect URI") + func testValidateRedirectURIRejectsRemoteHTTP() { + #expect(throws: OAuthAuthorizationError.self) { + try OAuthURLValidator().validateRedirectURI( + URL(string: "http://app.example.com/callback")!) + } + } + + @Test("Rejects redirect URI with fragment") + func testValidateRedirectURIRejectsFragment() { + #expect(throws: OAuthAuthorizationError.self) { + try OAuthURLValidator().validateRedirectURI( + URL(string: "https://app.example.com/callback#section")!) + } + } + + // MARK: - isPrivateIPHost + + @Test("Identifies private IPv4 ranges") + func testIsPrivateIPHostIPv4() { + let v = OAuthURLValidator() + #expect(v.isPrivateIPHost("10.0.0.1")) + #expect(v.isPrivateIPHost("172.16.0.1")) + #expect(v.isPrivateIPHost("172.31.255.255")) + #expect(v.isPrivateIPHost("192.168.1.1")) + #expect(v.isPrivateIPHost("169.254.169.254")) + #expect(v.isPrivateIPHost("100.64.0.1")) + } + + @Test("Does not block public IPv4 addresses") + func testIsPrivateIPHostPublicIPv4() { + let v = OAuthURLValidator() + #expect(!v.isPrivateIPHost("1.2.3.4")) + #expect(!v.isPrivateIPHost("8.8.8.8")) + #expect(!v.isPrivateIPHost("203.0.113.1")) + } + + @Test("Identifies private IPv6 ULA and link-local addresses") + func testIsPrivateIPHostIPv6() { + let v = OAuthURLValidator() + #expect(v.isPrivateIPHost("fc00::1")) + #expect(v.isPrivateIPHost("fd12:3456::1")) + #expect(v.isPrivateIPHost("fe80::1")) + } + + @Test("Does not block public hostnames") + func testIsPrivateIPHostPublicHostname() { + let v = OAuthURLValidator() + #expect(!v.isPrivateIPHost("example.com")) + #expect(!v.isPrivateIPHost("localhost")) + } +} diff --git a/conformance-baseline.yml b/conformance-baseline.yml index 94480f9b..7127a6b9 100644 --- a/conformance-baseline.yml +++ b/conformance-baseline.yml @@ -1,19 +1 @@ -client: - - auth/metadata-default - - auth/metadata-var1 - - auth/metadata-var2 - - auth/metadata-var3 - - auth/basic-cimd - - auth/scope-from-www-authenticate - - auth/scope-from-scopes-supported - - auth/scope-omitted-when-undefined - - auth/scope-step-up - - auth/scope-retry-limit - - auth/token-endpoint-auth-basic - - auth/token-endpoint-auth-post - - auth/token-endpoint-auth-none - - auth/pre-registration - - auth/2025-03-26-oauth-metadata-backcompat - - auth/2025-03-26-oauth-endpoint-fallback - - auth/client-credentials-jwt - - auth/client-credentials-basic +client: []