Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/mcp/client/auth/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,10 @@ async def _initialize(self) -> None: # pragma: no cover
"""Load stored tokens and client info."""
self.context.current_tokens = await self.context.storage.get_tokens()
self.context.client_info = await self.context.storage.get_client_info()

if self.context.current_tokens and self.context.current_tokens.expires_in is not None:
self.context.update_token_expiry(self.context.current_tokens)

self._initialized = True

def _add_auth_header(self, request: httpx.Request) -> None:
Expand Down
105 changes: 105 additions & 0 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,17 @@ def valid_tokens():
)


@pytest.fixture
def expired_tokens():
return OAuthToken(
access_token="test_access_token",
token_type="Bearer",
expires_in=-100, # Expired 100 seconds ago
refresh_token="test_refresh_token",
scope="read write",
)


@pytest.fixture
def oauth_provider(client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage):
async def redirect_handler(url: str) -> None:
Expand Down Expand Up @@ -259,6 +270,100 @@ def test_clear_tokens(self, oauth_provider: OAuthClientProvider, valid_tokens: O
assert context.token_expiry_time is None


class TestTokenInitialization:
"""Test token loading from storage during initialization."""

@pytest.mark.anyio
async def test_initialize_sets_token_expiry_from_stored_tokens(
self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken
):
"""Test _initialize() sets token_expiry_time when loading tokens from storage."""
context = oauth_provider.context
await context.storage.set_tokens(valid_tokens)

# Before initialization
assert oauth_provider._initialized is False
assert context.current_tokens is None
assert context.token_expiry_time is None

# Trigger initialization by starting auth flow
test_request = httpx.Request("GET", "https://api.example.com/v1/mcp")
auth_flow = oauth_provider.async_auth_flow(test_request)

# First request calls _initialize()
request = await auth_flow.__anext__()

# After first request, verify tokens were loaded
assert oauth_provider._initialized is True
assert oauth_provider.context.current_tokens is not None
assert oauth_provider.context.current_tokens.access_token == "test_access_token"

# token_expiry_time should be set by update_token_expiry()
assert oauth_provider.context.token_expiry_time is not None

# Verify token is considered valid
assert oauth_provider.context.is_token_valid() is True

# Request should have auth header added
assert request.headers["Authorization"] == "Bearer test_access_token"

# Complete the flow
response = httpx.Response(200, request=request)
try:
await auth_flow.asend(response)
except StopAsyncIteration:
pass

@pytest.mark.anyio
async def test_initialize_with_expired_tokens_detects_expiry(
self, oauth_provider: OAuthClientProvider, expired_tokens: OAuthToken
):
"""Test that expired tokens loaded from storage are detected as invalid."""
context = oauth_provider.context
await context.storage.set_tokens(expired_tokens)
await context.storage.set_client_info(
OAuthClientInformationFull(
client_id="test_client_id",
client_secret="test_client_secret",
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
)
)

# First request
test_request = httpx.Request("GET", "https://api.example.com/v1/mcp")
auth_flow = oauth_provider.async_auth_flow(test_request)

# This should trigger a refresh attempt, not the original request
refresh_request = await auth_flow.__anext__()

# Verify tokens were loaded
assert context.current_tokens is not None

# token_expiry_time should be set by update_token_expiry()
assert context.token_expiry_time is not None

# Token should be detected as invalid (expired)
assert context.is_token_valid() is False

# Should be able to refresh
assert context.can_refresh_token() is True

# Complete the flow
refresh_response = httpx.Response(
200,
content=b'{"access_token": "new_token", "token_type": "Bearer", "expires_in": 3600}',
request=refresh_request,
)
try:
original_request = await auth_flow.asend(refresh_response)
# Should retry original request with new token
assert original_request.headers["Authorization"] == "Bearer new_token"
final_response = httpx.Response(200, request=original_request)
await auth_flow.asend(final_response)
except StopAsyncIteration:
pass


class TestOAuthFlow:
"""Test OAuth flow methods."""

Expand Down