diff --git a/crates/rmcp/src/transport/auth.rs b/crates/rmcp/src/transport/auth.rs index 5ccc9bc1..e2a8541e 100644 --- a/crates/rmcp/src/transport/auth.rs +++ b/crates/rmcp/src/transport/auth.rs @@ -944,8 +944,20 @@ impl AuthorizationManager { Ok(false) } + /// Use a caller-configured `reqwest::Client` for every OAuth HTTP operation, + /// preserving all of its settings (proxy, TLS, timeout, default headers). + /// + /// The same client is reused for all requests, so its own redirect policy applies + /// and [`OAuthHttpRedirectPolicy::Stop`] is not enforced for token operations. + /// Callers needing strict no-redirect handling should pass a custom + /// [`OAuthHttpClient`] to [`AuthorizationManager::new_with_oauth_http_client`]. pub fn with_client(&mut self, http_client: ReqwestClient) -> Result<(), AuthError> { - self.http_client = Arc::new(ReqwestOAuthHttpClient::new(http_client)?); + // One client for both modes: a built reqwest::Client can't be rebuilt as a + // no-redirect variant without dropping the caller's configuration. + self.http_client = Arc::new(ReqwestOAuthHttpClient { + follow_redirects: http_client.clone(), + stop_redirects: http_client, + }); self.refresh_redirect_policy = OAuthHttpRedirectPolicy::Follow; Ok(()) } @@ -4871,6 +4883,149 @@ mod tests { ); } + #[tokio::test] + async fn exchange_code_uses_client_configured_by_with_client() { + use axum::{Router, body::Body, http::Response, routing::post}; + + let received_header = Arc::new(std::sync::Mutex::new(None)); + let received_header_clone = Arc::clone(&received_header); + let app = Router::new().route( + "/token", + post(move |headers: axum::http::HeaderMap| { + let received_header = Arc::clone(&received_header_clone); + async move { + *received_header.lock().unwrap() = headers + .get("x-custom-client") + .and_then(|value| value.to_str().ok()) + .map(str::to_string); + Response::builder() + .status(200) + .header("content-type", "application/json") + .body(Body::from( + r#"{"access_token":"new-token","token_type":"Bearer","expires_in":3600}"#, + )) + .unwrap() + } + }), + ); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + tokio::spawn(async move { axum::serve(listener, app).await.unwrap() }); + + let mut manager = manager_with_metadata(Some(AuthorizationMetadata { + authorization_endpoint: format!("http://{addr}/authorize"), + token_endpoint: format!("http://{addr}/token"), + ..Default::default() + })) + .await; + let mut default_headers = reqwest::header::HeaderMap::new(); + default_headers.insert("x-custom-client", "configured".parse().unwrap()); + manager + .with_client( + reqwest::Client::builder() + .default_headers(default_headers) + .build() + .unwrap(), + ) + .unwrap(); + manager.configure_client(test_client_config()).unwrap(); + let authorization_url = manager.get_authorization_url(&[]).await.unwrap(); + let state = Url::parse(&authorization_url) + .unwrap() + .query_pairs() + .find(|(name, _)| name == "state") + .unwrap() + .1 + .into_owned(); + + manager + .exchange_code_for_token("authorization-code", &state) + .await + .unwrap(); + + assert_eq!( + received_header.lock().unwrap().as_deref(), + Some("configured") + ); + } + + #[tokio::test] + async fn exchange_code_follows_redirects_with_with_client() { + use std::sync::atomic::{AtomicBool, Ordering}; + + use axum::{ + Router, + body::Body, + http::{Response, StatusCode}, + routing::post, + }; + + // The token endpoint replies with a 307 redirect; the with_client path reuses + // the caller's redirect-following client, so the request is expected to follow + // it to the final endpoint that returns the token. + let final_endpoint_hit = Arc::new(AtomicBool::new(false)); + let final_endpoint_hit_clone = Arc::clone(&final_endpoint_hit); + let app = Router::new() + .route( + "/token", + post(|| async { + Response::builder() + .status(StatusCode::TEMPORARY_REDIRECT) + .header("location", "/token-final") + .body(Body::empty()) + .unwrap() + }), + ) + .route( + "/token-final", + post(move || { + let final_endpoint_hit = Arc::clone(&final_endpoint_hit_clone); + async move { + final_endpoint_hit.store(true, Ordering::SeqCst); + Response::builder() + .status(200) + .header("content-type", "application/json") + .body(Body::from( + r#"{"access_token":"redirected-token","token_type":"Bearer","expires_in":3600}"#, + )) + .unwrap() + } + }), + ); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + tokio::spawn(async move { axum::serve(listener, app).await.unwrap() }); + + let mut manager = manager_with_metadata(Some(AuthorizationMetadata { + authorization_endpoint: format!("http://{addr}/authorize"), + token_endpoint: format!("http://{addr}/token"), + ..Default::default() + })) + .await; + manager + .with_client(reqwest::Client::builder().build().unwrap()) + .unwrap(); + manager.configure_client(test_client_config()).unwrap(); + let authorization_url = manager.get_authorization_url(&[]).await.unwrap(); + let state = Url::parse(&authorization_url) + .unwrap() + .query_pairs() + .find(|(name, _)| name == "state") + .unwrap() + .1 + .into_owned(); + + manager + .exchange_code_for_token("authorization-code", &state) + .await + .unwrap(); + + assert!( + final_endpoint_hit.load(Ordering::SeqCst), + "with_client path should follow redirects on token exchange" + ); + } + async fn start_token_server() -> (String, Arc>>) { use axum::{Router, body::Body, http::Response, routing::post}; let captured: Arc>> = Arc::new(std::sync::Mutex::new(None));