@@ -2376,56 +2376,153 @@ The `EnterpriseAuthOAuthClientProvider` class extends the standard OAuth provide
237623763 . ** Exchange ID-JAG for Access Token** using RFC 7523 JWT Bearer Grant
237723774 . ** Use Access Token** to call protected MCP server tools
23782378
2379+ ** Using the Access Token with MCP Server:**
2380+
2381+ 1 . Once you have obtained the access token, you can use it to authenticate requests to the MCP server
2382+ 2 . The access token is automatically included in all subsequent requests to the MCP server, allowing you to access protected tools and resources based on your enterprise identity and permissions.
2383+
2384+ ** Handling Token Expiration and Refresh:**
2385+
2386+ Access tokens have a limited lifetime and will expire. When tokens expire:
2387+
2388+ - ** Check Token Expiration** : Use the ` expires_in ` field to determine when the token expires
2389+ - ** Refresh Flow** : When expired, repeat the token exchange flow with a fresh ID token from your IdP
2390+ - ** Automatic Refresh** : Implement automatic token refresh before expiration (recommended for production)
2391+ - ** Error Handling** : Catch authentication errors and retry with refreshed tokens
2392+
2393+ ** Important Notes:**
2394+
2395+ - ** ID Token Expiration** : If the ID token from your IdP expires, you must re-authenticate with the IdP to obtain a new ID token before performing token exchange
2396+ - ** Token Storage** : Store tokens securely and implement the ` TokenStorage ` interface to persist tokens between application restarts
2397+ - ** Scope Changes** : If you need different scopes, you must obtain a new ID token from the IdP with the required scopes
2398+ - ** Security** : Never log or expose access tokens or ID tokens in production environments
2399+
23792400** Example Usage:**
23802401
2402+ <!-- snippet-source examples/snippets/clients/enterprise_managed_auth_client.py -->
23812403``` python
23822404import asyncio
2405+ from datetime import datetime, timedelta, timezone
2406+ from typing import Any
2407+
23832408import httpx
23842409from pydantic import AnyUrl
23852410
2411+ from mcp import ClientSession
2412+ from mcp.client.auth import OAuthTokenError, TokenStorage
23862413from mcp.client.auth.extensions import (
23872414 EnterpriseAuthOAuthClientProvider,
23882415 TokenExchangeParameters,
23892416)
2390- from mcp.shared.auth import OAuthClientMetadata
2391- from mcp.client.auth import TokenStorage
2417+ from mcp.client.sse import sse_client
2418+ from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
2419+ from mcp.types import CallToolResult
2420+
2421+
2422+ # Placeholder function for IdP authentication
2423+ async def get_id_token_from_idp () -> str :
2424+ """
2425+ Placeholder function to get ID token from your IdP.
2426+ In production, implement actual IdP authentication flow.
2427+ """
2428+ raise NotImplementedError (" Implement your IdP authentication flow here" )
2429+
23922430
23932431# Define token storage implementation
23942432class SimpleTokenStorage (TokenStorage ):
2395- def __init__ (self ):
2396- self ._tokens = None
2397- self ._client_info = None
2398-
2399- async def get_tokens (self ):
2433+ def __init__ (self ) -> None :
2434+ self ._tokens: OAuthToken | None = None
2435+ self ._client_info: OAuthClientInformationFull | None = None
2436+
2437+ async def get_tokens (self ) -> OAuthToken | None :
24002438 return self ._tokens
2401-
2402- async def set_tokens (self , tokens ) :
2439+
2440+ async def set_tokens (self , tokens : OAuthToken) -> None :
24032441 self ._tokens = tokens
2404-
2405- async def get_client_info (self ):
2442+
2443+ async def get_client_info (self ) -> OAuthClientInformationFull | None :
24062444 return self ._client_info
2407-
2408- async def set_client_info (self , client_info ) :
2445+
2446+ async def set_client_info (self , client_info : OAuthClientInformationFull) -> None :
24092447 self ._client_info = client_info
24102448
2411- async def main ():
2449+
2450+ def is_token_expired (access_token : OAuthToken) -> bool :
2451+ """ Check if the access token has expired."""
2452+ if access_token.expires_in:
2453+ # Calculate expiration time
2454+ issued_at = datetime.now(timezone.utc)
2455+ expiration_time = issued_at + timedelta(seconds = access_token.expires_in)
2456+ return datetime.now(timezone.utc) >= expiration_time
2457+ return False
2458+
2459+
2460+ async def refresh_access_token (
2461+ enterprise_auth : EnterpriseAuthOAuthClientProvider,
2462+ client : httpx.AsyncClient,
2463+ id_token : str ,
2464+ ) -> OAuthToken:
2465+ """ Refresh the access token when it expires."""
2466+ try :
2467+ # Update token exchange parameters with fresh ID token
2468+ enterprise_auth.token_exchange_params.subject_token = id_token
2469+
2470+ # Re-exchange for new ID-JAG
2471+ id_jag = await enterprise_auth.exchange_token_for_id_jag(client)
2472+
2473+ # Get new access token
2474+ access_token = await enterprise_auth.exchange_id_jag_for_access_token(client, id_jag)
2475+ return access_token
2476+ except Exception as e:
2477+ print (f " Token refresh failed: { e} " )
2478+ # Re-authenticate with IdP if ID token is also expired
2479+ id_token = await get_id_token_from_idp()
2480+ return await refresh_access_token(enterprise_auth, client, id_token)
2481+
2482+
2483+ async def call_tool_with_retry (
2484+ session : ClientSession,
2485+ tool_name : str ,
2486+ arguments : dict[str , Any],
2487+ enterprise_auth : EnterpriseAuthOAuthClientProvider,
2488+ client : httpx.AsyncClient,
2489+ id_token : str ,
2490+ ) -> CallToolResult | None :
2491+ """ Call a tool with automatic retry on token expiration."""
2492+ max_retries = 1
2493+
2494+ for attempt in range (max_retries + 1 ):
2495+ try :
2496+ result = await session.call_tool(tool_name, arguments)
2497+ return result
2498+ except OAuthTokenError:
2499+ if attempt < max_retries:
2500+ print (" Token expired, refreshing..." )
2501+ # Refresh token and reconnect
2502+ _access_token = await refresh_access_token(enterprise_auth, client, id_token)
2503+ # Note: In production, you'd need to reconnect the session here
2504+ else :
2505+ raise
2506+ return None
2507+
2508+
2509+ async def main () -> None :
24122510 # Step 1: Get ID token from your IdP (example with Okta)
24132511 id_token = await get_id_token_from_idp() # Your IdP authentication
2414-
2512+
24152513 # Step 2: Configure token exchange parameters
24162514 token_exchange_params = TokenExchangeParameters.from_id_token(
24172515 id_token = id_token,
24182516 mcp_server_auth_issuer = " https://your-idp.com" , # IdP issuer URL
24192517 mcp_server_resource_id = " https://mcp-server.example.com" , # MCP server resource ID
24202518 scope = " mcp:tools mcp:resources" , # Optional scopes
24212519 )
2422-
2520+
24232521 # Step 3: Create enterprise auth provider
24242522 enterprise_auth = EnterpriseAuthOAuthClientProvider(
24252523 server_url = " https://mcp-server.example.com" ,
24262524 client_metadata = OAuthClientMetadata(
24272525 client_name = " Enterprise MCP Client" ,
2428- client_id = " your-client-id" ,
24292526 redirect_uris = [AnyUrl(" http://localhost:3000/callback" )],
24302527 grant_types = [" urn:ietf:params:oauth:grant-type:jwt-bearer" ],
24312528 response_types = [" token" ],
@@ -2434,23 +2531,85 @@ async def main():
24342531 idp_token_endpoint = " https://your-idp.com/oauth2/v1/token" ,
24352532 token_exchange_params = token_exchange_params,
24362533 )
2437-
2438- # Step 4: Perform token exchange and get access token
2534+
24392535 async with httpx.AsyncClient() as client:
2440- # Exchange ID token for ID-JAG
2536+ # Step 4: Exchange ID token for ID-JAG
24412537 id_jag = await enterprise_auth.exchange_token_for_id_jag(client)
24422538 print (f " Obtained ID-JAG: { id_jag[:50 ]} ... " )
2443-
2444- # Exchange ID-JAG for access token
2445- access_token = await enterprise_auth.exchange_id_jag_for_access_token(
2446- client, id_jag
2447- )
2539+
2540+ # Step 5: Exchange ID-JAG for access token
2541+ access_token = await enterprise_auth.exchange_id_jag_for_access_token(client, id_jag)
24482542 print (f " Access token obtained, expires in: { access_token.expires_in} s " )
24492543
2544+ # Step 6: Check if token is expired (for demonstration)
2545+ if is_token_expired(access_token):
2546+ print (" Token is expired, refreshing..." )
2547+ access_token = await refresh_access_token(enterprise_auth, client, id_token)
2548+
2549+ # Step 7: Use the access token to connect to MCP server
2550+ headers = {" Authorization" : f " Bearer { access_token.access_token} " }
2551+
2552+ async with sse_client(url = " https://mcp-server.example.com" , headers = headers) as (read, write):
2553+ async with ClientSession(read, write) as session:
2554+ await session.initialize()
2555+
2556+ # Call tools with automatic retry on token expiration
2557+ result = await call_tool_with_retry(
2558+ session, " enterprise_tool" , {" param" : " value" }, enterprise_auth, client, id_token
2559+ )
2560+ if result:
2561+ print (f " Tool result: { result.content} " )
2562+
2563+ # List available resources
2564+ resources = await session.list_resources()
2565+ for resource in resources.resources:
2566+ print (f " Resource: { resource.uri} " )
2567+
2568+
2569+ async def maintain_active_session (
2570+ enterprise_auth : EnterpriseAuthOAuthClientProvider,
2571+ mcp_server_url : str ,
2572+ ) -> None :
2573+ """ Maintain an active session with automatic token refresh."""
2574+ id_token_var = await get_id_token_from_idp()
2575+
2576+ async with httpx.AsyncClient() as client:
2577+ while True :
2578+ try :
2579+ # Update token exchange params with current ID token
2580+ enterprise_auth.token_exchange_params.subject_token = id_token_var
2581+
2582+ # Get access token
2583+ id_jag = await enterprise_auth.exchange_token_for_id_jag(client)
2584+ access_token = await enterprise_auth.exchange_id_jag_for_access_token(client, id_jag)
2585+
2586+ # Calculate refresh time (refresh before expiration)
2587+ refresh_in = access_token.expires_in - 60 if access_token.expires_in else 300
2588+
2589+ # Use the token for MCP operations
2590+ headers = {" Authorization" : f " Bearer { access_token.access_token} " }
2591+ async with sse_client(mcp_server_url, headers = headers) as (read, write):
2592+ async with ClientSession(read, write) as session:
2593+ await session.initialize()
2594+
2595+ # Perform operations...
2596+ # Schedule refresh before token expires
2597+ await asyncio.sleep(refresh_in)
2598+
2599+ except Exception as e:
2600+ print (f " Session error: { e} " )
2601+ # Re-authenticate with IdP
2602+ id_token_var = await get_id_token_from_idp()
2603+ await asyncio.sleep(5 ) # Wait before retry
2604+
2605+
24502606if __name__ == " __main__" :
24512607 asyncio.run(main())
24522608```
24532609
2610+ _ Full example: [ examples/snippets/clients/enterprise_managed_auth_client.py] ( https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/clients/enterprise_managed_auth_client.py ) _
2611+ <!-- /snippet-source -->
2612+
24542613** Working with SAML Assertions:**
24552614
24562615If your enterprise uses SAML instead of OIDC, you can exchange SAML assertions:
0 commit comments