Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 105 additions & 1 deletion crates/rmcp/src/transport/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1693,14 +1693,21 @@ impl AuthorizationManager {
for scope in refresh_scopes {
refresh_request = refresh_request.add_scope(Scope::new(scope));
}
let token_result = refresh_request
let mut token_result = refresh_request
.request_async(&OAuth2HttpClient {
client: self.http_client.as_ref(),
redirect_policy: self.refresh_redirect_policy,
})
.await
.map_err(|e| AuthError::TokenRefreshFailed(e.to_string()))?;

// RFC 6749 section 6: issuing a new refresh token on refresh is optional.
// When the response omits one, keep the existing refresh token rather than
// dropping it. When a new one is present, the response value is used as-is.
if token_result.refresh_token().is_none() {
token_result.set_refresh_token(Some(refresh_token_value));
}

let granted_scopes: Vec<String> = match token_result.scopes() {
Some(scopes) => scopes.iter().map(|s| s.to_string()).collect(),
None => self.current_scopes.read().await.clone(),
Expand Down Expand Up @@ -5554,4 +5561,101 @@ mod tests {
"scope should be absent when granted_scopes is empty, body: {body}"
);
}

#[tokio::test]
async fn refresh_token_preserves_existing_refresh_token_when_response_omits_it() {
use oauth2::TokenResponse;
// start_token_server returns a response without a refresh_token, matching
// an authorization server that does not rotate refresh tokens on refresh.
let (base_url, _captured) = start_token_server().await;

let mut manager = manager_with_metadata(Some(AuthorizationMetadata {
authorization_endpoint: format!("{}/authorize", base_url),
token_endpoint: format!("{}/token", base_url),
..Default::default()
}))
.await;
manager.configure_client(test_client_config()).unwrap();

let stored = StoredCredentials {
client_id: "my-client".to_string(),
token_response: Some(make_token_response_with_refresh(
"old-token",
"my-refresh-token",
)),
granted_scopes: vec![],
token_received_at: Some(AuthorizationManager::now_epoch_secs()),
};
manager.credential_store.save(stored).await.unwrap();

let result = manager.refresh_token().await.unwrap();
assert_eq!(
result.refresh_token().map(|t| t.secret().as_str()),
Some("my-refresh-token"),
"returned response should keep the previous refresh token"
);

let reloaded = manager
.credential_store
.load()
.await
.unwrap()
.unwrap()
.token_response
.unwrap();
assert_eq!(
reloaded.refresh_token().map(|t| t.secret().as_str()),
Some("my-refresh-token"),
"stored credentials should keep the previous refresh token"
);
}

#[tokio::test]
async fn refresh_token_replaces_refresh_token_when_response_includes_new_one() {
use axum::{Router, body::Body, http::Response, routing::post};
use oauth2::TokenResponse;

let app = Router::new().route(
"/token",
post(|| async {
Response::builder()
.status(200)
.header("content-type", "application/json")
.body(Body::from(
r#"{"access_token":"new-token","token_type":"Bearer","expires_in":3600,"refresh_token":"rotated-refresh-token"}"#,
))
.unwrap()
}),
);
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move { axum::serve(listener, app).await.unwrap() });
let base_url = format!("http://{}", addr);

let mut manager = manager_with_metadata(Some(AuthorizationMetadata {
authorization_endpoint: format!("{}/authorize", base_url),
token_endpoint: format!("{}/token", base_url),
..Default::default()
}))
.await;
manager.configure_client(test_client_config()).unwrap();

let stored = StoredCredentials {
client_id: "my-client".to_string(),
token_response: Some(make_token_response_with_refresh(
"old-token",
"my-refresh-token",
)),
granted_scopes: vec![],
token_received_at: Some(AuthorizationManager::now_epoch_secs()),
};
manager.credential_store.save(stored).await.unwrap();

let result = manager.refresh_token().await.unwrap();
assert_eq!(
result.refresh_token().map(|t| t.secret().as_str()),
Some("rotated-refresh-token"),
"a rotated refresh token from the response should replace the old one"
);
}
}