Skip to content

Commit 28bb315

Browse files
committed
- Added typing for request payload structures TokenExchangeRequestData and JWTBearerGrantRequestData.
- Added snippet file for adding code to the README.md file. - Added new section in README.md file to add information regarding: "how to use the access token once you get it" and "How does this work when the client ID is expired?".
1 parent 09c05aa commit 28bb315

File tree

4 files changed

+417
-52
lines changed

4 files changed

+417
-52
lines changed

README.md

Lines changed: 184 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2376,56 +2376,153 @@ The `EnterpriseAuthOAuthClientProvider` class extends the standard OAuth provide
23762376
3. **Exchange ID-JAG for Access Token** using RFC 7523 JWT Bearer Grant
23772377
4. **Use Access Token** to call protected MCP server tools
23782378

2379+
**Using the Access Token with MCP Server:**
2380+
2381+
1. Once you have obtained the access token, you can use it to authenticate requests to the MCP server
2382+
2. The access token is automatically included in all subsequent requests to the MCP server, allowing you to access protected tools and resources based on your enterprise identity and permissions.
2383+
2384+
**Handling Token Expiration and Refresh:**
2385+
2386+
Access tokens have a limited lifetime and will expire. When tokens expire:
2387+
2388+
- **Check Token Expiration**: Use the `expires_in` field to determine when the token expires
2389+
- **Refresh Flow**: When expired, repeat the token exchange flow with a fresh ID token from your IdP
2390+
- **Automatic Refresh**: Implement automatic token refresh before expiration (recommended for production)
2391+
- **Error Handling**: Catch authentication errors and retry with refreshed tokens
2392+
2393+
**Important Notes:**
2394+
2395+
- **ID Token Expiration**: If the ID token from your IdP expires, you must re-authenticate with the IdP to obtain a new ID token before performing token exchange
2396+
- **Token Storage**: Store tokens securely and implement the `TokenStorage` interface to persist tokens between application restarts
2397+
- **Scope Changes**: If you need different scopes, you must obtain a new ID token from the IdP with the required scopes
2398+
- **Security**: Never log or expose access tokens or ID tokens in production environments
2399+
23792400
**Example Usage:**
23802401

2402+
<!-- snippet-source examples/snippets/clients/enterprise_managed_auth_client.py -->
23812403
```python
23822404
import asyncio
2405+
from datetime import datetime, timedelta, timezone
2406+
from typing import Any
2407+
23832408
import httpx
23842409
from pydantic import AnyUrl
23852410

2411+
from mcp import ClientSession
2412+
from mcp.client.auth import OAuthTokenError, TokenStorage
23862413
from mcp.client.auth.extensions import (
23872414
EnterpriseAuthOAuthClientProvider,
23882415
TokenExchangeParameters,
23892416
)
2390-
from mcp.shared.auth import OAuthClientMetadata
2391-
from mcp.client.auth import TokenStorage
2417+
from mcp.client.sse import sse_client
2418+
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
2419+
from mcp.types import CallToolResult
2420+
2421+
2422+
# Placeholder function for IdP authentication
2423+
async def get_id_token_from_idp() -> str:
2424+
"""
2425+
Placeholder function to get ID token from your IdP.
2426+
In production, implement actual IdP authentication flow.
2427+
"""
2428+
raise NotImplementedError("Implement your IdP authentication flow here")
2429+
23922430

23932431
# Define token storage implementation
23942432
class SimpleTokenStorage(TokenStorage):
2395-
def __init__(self):
2396-
self._tokens = None
2397-
self._client_info = None
2398-
2399-
async def get_tokens(self):
2433+
def __init__(self) -> None:
2434+
self._tokens: OAuthToken | None = None
2435+
self._client_info: OAuthClientInformationFull | None = None
2436+
2437+
async def get_tokens(self) -> OAuthToken | None:
24002438
return self._tokens
2401-
2402-
async def set_tokens(self, tokens):
2439+
2440+
async def set_tokens(self, tokens: OAuthToken) -> None:
24032441
self._tokens = tokens
2404-
2405-
async def get_client_info(self):
2442+
2443+
async def get_client_info(self) -> OAuthClientInformationFull | None:
24062444
return self._client_info
2407-
2408-
async def set_client_info(self, client_info):
2445+
2446+
async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
24092447
self._client_info = client_info
24102448

2411-
async def main():
2449+
2450+
def is_token_expired(access_token: OAuthToken) -> bool:
2451+
"""Check if the access token has expired."""
2452+
if access_token.expires_in:
2453+
# Calculate expiration time
2454+
issued_at = datetime.now(timezone.utc)
2455+
expiration_time = issued_at + timedelta(seconds=access_token.expires_in)
2456+
return datetime.now(timezone.utc) >= expiration_time
2457+
return False
2458+
2459+
2460+
async def refresh_access_token(
2461+
enterprise_auth: EnterpriseAuthOAuthClientProvider,
2462+
client: httpx.AsyncClient,
2463+
id_token: str,
2464+
) -> OAuthToken:
2465+
"""Refresh the access token when it expires."""
2466+
try:
2467+
# Update token exchange parameters with fresh ID token
2468+
enterprise_auth.token_exchange_params.subject_token = id_token
2469+
2470+
# Re-exchange for new ID-JAG
2471+
id_jag = await enterprise_auth.exchange_token_for_id_jag(client)
2472+
2473+
# Get new access token
2474+
access_token = await enterprise_auth.exchange_id_jag_for_access_token(client, id_jag)
2475+
return access_token
2476+
except Exception as e:
2477+
print(f"Token refresh failed: {e}")
2478+
# Re-authenticate with IdP if ID token is also expired
2479+
id_token = await get_id_token_from_idp()
2480+
return await refresh_access_token(enterprise_auth, client, id_token)
2481+
2482+
2483+
async def call_tool_with_retry(
2484+
session: ClientSession,
2485+
tool_name: str,
2486+
arguments: dict[str, Any],
2487+
enterprise_auth: EnterpriseAuthOAuthClientProvider,
2488+
client: httpx.AsyncClient,
2489+
id_token: str,
2490+
) -> CallToolResult | None:
2491+
"""Call a tool with automatic retry on token expiration."""
2492+
max_retries = 1
2493+
2494+
for attempt in range(max_retries + 1):
2495+
try:
2496+
result = await session.call_tool(tool_name, arguments)
2497+
return result
2498+
except OAuthTokenError:
2499+
if attempt < max_retries:
2500+
print("Token expired, refreshing...")
2501+
# Refresh token and reconnect
2502+
_access_token = await refresh_access_token(enterprise_auth, client, id_token)
2503+
# Note: In production, you'd need to reconnect the session here
2504+
else:
2505+
raise
2506+
return None
2507+
2508+
2509+
async def main() -> None:
24122510
# Step 1: Get ID token from your IdP (example with Okta)
24132511
id_token = await get_id_token_from_idp() # Your IdP authentication
2414-
2512+
24152513
# Step 2: Configure token exchange parameters
24162514
token_exchange_params = TokenExchangeParameters.from_id_token(
24172515
id_token=id_token,
24182516
mcp_server_auth_issuer="https://your-idp.com", # IdP issuer URL
24192517
mcp_server_resource_id="https://mcp-server.example.com", # MCP server resource ID
24202518
scope="mcp:tools mcp:resources", # Optional scopes
24212519
)
2422-
2520+
24232521
# Step 3: Create enterprise auth provider
24242522
enterprise_auth = EnterpriseAuthOAuthClientProvider(
24252523
server_url="https://mcp-server.example.com",
24262524
client_metadata=OAuthClientMetadata(
24272525
client_name="Enterprise MCP Client",
2428-
client_id="your-client-id",
24292526
redirect_uris=[AnyUrl("http://localhost:3000/callback")],
24302527
grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"],
24312528
response_types=["token"],
@@ -2434,23 +2531,85 @@ async def main():
24342531
idp_token_endpoint="https://your-idp.com/oauth2/v1/token",
24352532
token_exchange_params=token_exchange_params,
24362533
)
2437-
2438-
# Step 4: Perform token exchange and get access token
2534+
24392535
async with httpx.AsyncClient() as client:
2440-
# Exchange ID token for ID-JAG
2536+
# Step 4: Exchange ID token for ID-JAG
24412537
id_jag = await enterprise_auth.exchange_token_for_id_jag(client)
24422538
print(f"Obtained ID-JAG: {id_jag[:50]}...")
2443-
2444-
# Exchange ID-JAG for access token
2445-
access_token = await enterprise_auth.exchange_id_jag_for_access_token(
2446-
client, id_jag
2447-
)
2539+
2540+
# Step 5: Exchange ID-JAG for access token
2541+
access_token = await enterprise_auth.exchange_id_jag_for_access_token(client, id_jag)
24482542
print(f"Access token obtained, expires in: {access_token.expires_in}s")
24492543

2544+
# Step 6: Check if token is expired (for demonstration)
2545+
if is_token_expired(access_token):
2546+
print("Token is expired, refreshing...")
2547+
access_token = await refresh_access_token(enterprise_auth, client, id_token)
2548+
2549+
# Step 7: Use the access token to connect to MCP server
2550+
headers = {"Authorization": f"Bearer {access_token.access_token}"}
2551+
2552+
async with sse_client(url="https://mcp-server.example.com", headers=headers) as (read, write):
2553+
async with ClientSession(read, write) as session:
2554+
await session.initialize()
2555+
2556+
# Call tools with automatic retry on token expiration
2557+
result = await call_tool_with_retry(
2558+
session, "enterprise_tool", {"param": "value"}, enterprise_auth, client, id_token
2559+
)
2560+
if result:
2561+
print(f"Tool result: {result.content}")
2562+
2563+
# List available resources
2564+
resources = await session.list_resources()
2565+
for resource in resources.resources:
2566+
print(f"Resource: {resource.uri}")
2567+
2568+
2569+
async def maintain_active_session(
2570+
enterprise_auth: EnterpriseAuthOAuthClientProvider,
2571+
mcp_server_url: str,
2572+
) -> None:
2573+
"""Maintain an active session with automatic token refresh."""
2574+
id_token_var = await get_id_token_from_idp()
2575+
2576+
async with httpx.AsyncClient() as client:
2577+
while True:
2578+
try:
2579+
# Update token exchange params with current ID token
2580+
enterprise_auth.token_exchange_params.subject_token = id_token_var
2581+
2582+
# Get access token
2583+
id_jag = await enterprise_auth.exchange_token_for_id_jag(client)
2584+
access_token = await enterprise_auth.exchange_id_jag_for_access_token(client, id_jag)
2585+
2586+
# Calculate refresh time (refresh before expiration)
2587+
refresh_in = access_token.expires_in - 60 if access_token.expires_in else 300
2588+
2589+
# Use the token for MCP operations
2590+
headers = {"Authorization": f"Bearer {access_token.access_token}"}
2591+
async with sse_client(mcp_server_url, headers=headers) as (read, write):
2592+
async with ClientSession(read, write) as session:
2593+
await session.initialize()
2594+
2595+
# Perform operations...
2596+
# Schedule refresh before token expires
2597+
await asyncio.sleep(refresh_in)
2598+
2599+
except Exception as e:
2600+
print(f"Session error: {e}")
2601+
# Re-authenticate with IdP
2602+
id_token_var = await get_id_token_from_idp()
2603+
await asyncio.sleep(5) # Wait before retry
2604+
2605+
24502606
if __name__ == "__main__":
24512607
asyncio.run(main())
24522608
```
24532609

2610+
_Full example: [examples/snippets/clients/enterprise_managed_auth_client.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/clients/enterprise_managed_auth_client.py)_
2611+
<!-- /snippet-source -->
2612+
24542613
**Working with SAML Assertions:**
24552614

24562615
If your enterprise uses SAML instead of OIDC, you can exchange SAML assertions:

0 commit comments

Comments
 (0)