Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 156 additions & 1 deletion crates/rmcp/src/transport/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -944,8 +944,20 @@ impl AuthorizationManager {
Ok(false)
}

/// Use a caller-configured `reqwest::Client` for every OAuth HTTP operation,
/// preserving all of its settings (proxy, TLS, timeout, default headers).
///
/// The same client is reused for all requests, so its own redirect policy applies
/// and [`OAuthHttpRedirectPolicy::Stop`] is not enforced for token operations.
/// Callers needing strict no-redirect handling should pass a custom
/// [`OAuthHttpClient`] to [`AuthorizationManager::new_with_oauth_http_client`].
pub fn with_client(&mut self, http_client: ReqwestClient) -> Result<(), AuthError> {
self.http_client = Arc::new(ReqwestOAuthHttpClient::new(http_client)?);
// One client for both modes: a built reqwest::Client can't be rebuilt as a
// no-redirect variant without dropping the caller's configuration.
self.http_client = Arc::new(ReqwestOAuthHttpClient {
follow_redirects: http_client.clone(),
stop_redirects: http_client,
});
self.refresh_redirect_policy = OAuthHttpRedirectPolicy::Follow;
Ok(())
}
Expand Down Expand Up @@ -4871,6 +4883,149 @@ mod tests {
);
}

#[tokio::test]
async fn exchange_code_uses_client_configured_by_with_client() {
use axum::{Router, body::Body, http::Response, routing::post};

let received_header = Arc::new(std::sync::Mutex::new(None));
let received_header_clone = Arc::clone(&received_header);
let app = Router::new().route(
"/token",
post(move |headers: axum::http::HeaderMap| {
let received_header = Arc::clone(&received_header_clone);
async move {
*received_header.lock().unwrap() = headers
.get("x-custom-client")
.and_then(|value| value.to_str().ok())
.map(str::to_string);
Response::builder()
.status(200)
.header("content-type", "application/json")
.body(Body::from(
r#"{"access_token":"new-token","token_type":"Bearer","expires_in":3600}"#,
))
.unwrap()
}
}),
);
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move { axum::serve(listener, app).await.unwrap() });

let mut manager = manager_with_metadata(Some(AuthorizationMetadata {
authorization_endpoint: format!("http://{addr}/authorize"),
token_endpoint: format!("http://{addr}/token"),
..Default::default()
}))
.await;
let mut default_headers = reqwest::header::HeaderMap::new();
default_headers.insert("x-custom-client", "configured".parse().unwrap());
manager
.with_client(
reqwest::Client::builder()
.default_headers(default_headers)
.build()
.unwrap(),
)
.unwrap();
manager.configure_client(test_client_config()).unwrap();
let authorization_url = manager.get_authorization_url(&[]).await.unwrap();
let state = Url::parse(&authorization_url)
.unwrap()
.query_pairs()
.find(|(name, _)| name == "state")
.unwrap()
.1
.into_owned();

manager
.exchange_code_for_token("authorization-code", &state)
.await
.unwrap();

assert_eq!(
received_header.lock().unwrap().as_deref(),
Some("configured")
);
}

#[tokio::test]
async fn exchange_code_follows_redirects_with_with_client() {
use std::sync::atomic::{AtomicBool, Ordering};

use axum::{
Router,
body::Body,
http::{Response, StatusCode},
routing::post,
};

// The token endpoint replies with a 307 redirect; the with_client path reuses
// the caller's redirect-following client, so the request is expected to follow
// it to the final endpoint that returns the token.
let final_endpoint_hit = Arc::new(AtomicBool::new(false));
let final_endpoint_hit_clone = Arc::clone(&final_endpoint_hit);
let app = Router::new()
.route(
"/token",
post(|| async {
Response::builder()
.status(StatusCode::TEMPORARY_REDIRECT)
.header("location", "/token-final")
.body(Body::empty())
.unwrap()
}),
)
.route(
"/token-final",
post(move || {
let final_endpoint_hit = Arc::clone(&final_endpoint_hit_clone);
async move {
final_endpoint_hit.store(true, Ordering::SeqCst);
Response::builder()
.status(200)
.header("content-type", "application/json")
.body(Body::from(
r#"{"access_token":"redirected-token","token_type":"Bearer","expires_in":3600}"#,
))
.unwrap()
}
}),
);
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move { axum::serve(listener, app).await.unwrap() });

let mut manager = manager_with_metadata(Some(AuthorizationMetadata {
authorization_endpoint: format!("http://{addr}/authorize"),
token_endpoint: format!("http://{addr}/token"),
..Default::default()
}))
.await;
manager
.with_client(reqwest::Client::builder().build().unwrap())
.unwrap();
manager.configure_client(test_client_config()).unwrap();
let authorization_url = manager.get_authorization_url(&[]).await.unwrap();
let state = Url::parse(&authorization_url)
.unwrap()
.query_pairs()
.find(|(name, _)| name == "state")
.unwrap()
.1
.into_owned();

manager
.exchange_code_for_token("authorization-code", &state)
.await
.unwrap();

assert!(
final_endpoint_hit.load(Ordering::SeqCst),
"with_client path should follow redirects on token exchange"
);
}

async fn start_token_server() -> (String, Arc<std::sync::Mutex<Option<String>>>) {
use axum::{Router, body::Body, http::Response, routing::post};
let captured: Arc<std::sync::Mutex<Option<String>>> = Arc::new(std::sync::Mutex::new(None));
Expand Down