diff --git a/Cargo.lock b/Cargo.lock index a169090..f04e874 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -590,6 +590,38 @@ dependencies = [ "x509-parser 0.18.1", ] +[[package]] +name = "attestation" +version = "0.0.1" +source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier#2e4273cd93670e705e789555c00d43ca9c1e4af2" +dependencies = [ + "anyhow", + "az-tdx-vtpm", + "base64 0.22.1", + "configfs-tsm", + "dcap-qvl 0.3.12 (git+https://github.com/flashbots/dcap-qvl.git?branch=peg%2Fazure-outdated-tcp-override)", + "hex", + "http", + "num-bigint", + "once_cell", + "openssl", + "parity-scale-codec", + "pem-rfc7468", + "rand_core 0.6.4", + "reqwest", + "rustls-webpki", + "serde", + "serde_json", + "tdx-quote", + "thiserror 2.0.17", + "time", + "tokio", + "tokio-rustls", + "tracing", + "tss-esapi", + "x509-parser 0.18.1", +] + [[package]] name = "attestation-provider-server" version = "0.1.0" @@ -612,7 +644,7 @@ version = "0.0.1" dependencies = [ "alloy-rpc-client", "alloy-transport-http", - "attestation", + "attestation 0.0.1 (git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-crate)", "bytes", "futures-util", "http", @@ -638,10 +670,10 @@ dependencies = [ [[package]] name = "attested-tls" version = "0.0.1" -source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-crate#5c109dba74d4f9de58b4b846f480599752dfb1f9" +source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier#2e4273cd93670e705e789555c00d43ca9c1e4af2" dependencies = [ "anyhow", - "attestation", + "attestation 0.0.1 (git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier)", "ra-tls", "rcgen 0.14.7", "rustls", @@ -659,8 +691,8 @@ name = "attested-tls-proxy" version = "1.1.1" dependencies = [ "anyhow", - "attestation", - "attested-tls 0.0.1 (git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-crate)", + "attestation 0.0.1 (git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier)", + "attested-tls 0.0.1 (git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier)", "axum", "bytes", "clap", @@ -1072,7 +1104,7 @@ dependencies = [ [[package]] name = "cc-eventlog" version = "0.5.8" -source = "git+https://github.com/Dstack-TEE/dstack.git#4f602dddc0542cd34da031c90ac0b3a560f316ed" +source = "git+https://github.com/Dstack-TEE/dstack.git#f87c97728ad222a3f3553cf0fb756830f7634eb6" dependencies = [ "anyhow", "digest 0.10.7", @@ -1661,7 +1693,7 @@ dependencies = [ [[package]] name = "dstack-attest" version = "0.5.8" -source = "git+https://github.com/Dstack-TEE/dstack.git#4f602dddc0542cd34da031c90ac0b3a560f316ed" +source = "git+https://github.com/Dstack-TEE/dstack.git#f87c97728ad222a3f3553cf0fb756830f7634eb6" dependencies = [ "anyhow", "cc-eventlog", @@ -1687,7 +1719,7 @@ dependencies = [ [[package]] name = "dstack-types" version = "0.5.8" -source = "git+https://github.com/Dstack-TEE/dstack.git#4f602dddc0542cd34da031c90ac0b3a560f316ed" +source = "git+https://github.com/Dstack-TEE/dstack.git#f87c97728ad222a3f3553cf0fb756830f7634eb6" dependencies = [ "parity-scale-codec", "serde", @@ -2976,7 +3008,7 @@ dependencies = [ [[package]] name = "nested-tls" version = "0.0.1" -source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-crate#5c109dba74d4f9de58b4b846f480599752dfb1f9" +source = "git+https://github.com/flashbots/attested-tls?branch=peg%2Fattested-tls-expose-cert-verifier#2e4273cd93670e705e789555c00d43ca9c1e4af2" dependencies = [ "rustls", "tokio", @@ -3673,7 +3705,7 @@ checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" [[package]] name = "ra-tls" version = "0.5.8" -source = "git+https://github.com/Dstack-TEE/dstack.git#4f602dddc0542cd34da031c90ac0b3a560f316ed" +source = "git+https://github.com/Dstack-TEE/dstack.git#f87c97728ad222a3f3553cf0fb756830f7634eb6" dependencies = [ "anyhow", "bon", @@ -4480,7 +4512,7 @@ checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa" [[package]] name = "size-parser" version = "0.5.8" -source = "git+https://github.com/Dstack-TEE/dstack.git#4f602dddc0542cd34da031c90ac0b3a560f316ed" +source = "git+https://github.com/Dstack-TEE/dstack.git#f87c97728ad222a3f3553cf0fb756830f7634eb6" dependencies = [ "anyhow", "serde", @@ -4671,7 +4703,7 @@ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" [[package]] name = "tdx-attest" version = "0.5.8" -source = "git+https://github.com/Dstack-TEE/dstack.git#4f602dddc0542cd34da031c90ac0b3a560f316ed" +source = "git+https://github.com/Dstack-TEE/dstack.git#f87c97728ad222a3f3553cf0fb756830f7634eb6" dependencies = [ "anyhow", "cc-eventlog", diff --git a/Cargo.toml b/Cargo.toml index c285277..17b5749 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,9 +11,9 @@ repository = "https://github.com/flashbots/attested-tls-proxy" keywords = ["attested-TLS", "CVM", "TDX"] [dependencies] -attested-tls = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-crate" } -nested-tls = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-crate" } -attestation = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-crate" } +attested-tls = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-expose-cert-verifier" } +nested-tls = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-expose-cert-verifier" } +attestation = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-expose-cert-verifier" } tokio = { version = "1.50.0", features = ["full"] } tokio-rustls = { version = "0.26.4", default-features = false } x509-parser = { version = "0.18.0", features = ["verify"] } @@ -47,7 +47,7 @@ pin-project-lite = "0.2.16" [dev-dependencies] tempfile = "3.23.0" tdx-quote = { version = "0.0.5", features = ["mock"] } -attestation = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-crate", features = ["mock"] } +attestation = { git = "https://github.com/flashbots/attested-tls", branch = "peg/attested-tls-expose-cert-verifier", features = ["mock"] } tokio = { version = "1.48.0", features = ["full"] } jsonrpsee = { version = "0.26.0", features = ["server"] } diff --git a/src/lib.rs b/src/lib.rs index 88aa200..e7e012a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,7 +12,7 @@ mod http_version; #[cfg(test)] mod test_helpers; -use attestation::{AttestationError, AttestationVerifier}; +use attestation::{AttestationError, AttestationExchangeMessage, AttestationVerifier}; use attested_tls::{AttestedCertificateResolver, AttestedCertificateVerifier, AttestedTlsError}; use bytes::Bytes; use http::{HeaderMap, HeaderName, HeaderValue}; @@ -36,6 +36,12 @@ use tracing::{debug, error, warn}; use crate::http_version::{ALPN_H2, ALPN_HTTP11, HttpConnection, HttpSender, HttpVersion}; +/// The header name for giving attestation type +const ATTESTATION_TYPE_HEADER: &str = "X-Flashbots-Attestation-Type"; + +/// The header name for giving measurements +const MEASUREMENT_HEADER: &str = "X-Flashbots-Measurement"; + /// The header name for giving the forwarded for IP static X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for"); @@ -399,8 +405,31 @@ impl ProxyServer { ) -> Result<(), ProxyError> { debug!("[proxy-server] accepted connection"); + // Get attestation from the remote certificate from the inner session, if present. + let attestation = { + let (_io, server_connection) = tls_stream.get_ref(); + + match server_connection.peer_certificates() { + Some(remote_cert_chain) => remote_cert_chain + .first() + .and_then(|cert| { + match AttestedCertificateVerifier::extract_custom_attestation_from_cert(cert) + { + Ok(attestation) => Some(attestation), + Err(err) => { + warn!( + "Failed to extract remote attestation from inner-session certificate: {err}" + ); + None + } + } + }), + None => None, + } + }; + let http_version = HttpVersion::from_negotiated_protocol_server(&tls_stream); - Self::serve_tls_stream(tls_stream, http_version, target, client_addr).await + Self::serve_tls_stream(tls_stream, http_version, target, client_addr, attestation).await } async fn handle_inner_connection( @@ -410,8 +439,26 @@ impl ProxyServer { ) -> Result<(), ProxyError> { debug!("[proxy-server] accepted inner-only connection"); + // Get attestation from the remote certificate, if present + let attestation = { + let (_io, server_connection) = tls_stream.get_ref(); + + match server_connection.peer_certificates() { + Some(remote_cert_chain) => remote_cert_chain.first().and_then(|cert| { + match AttestedCertificateVerifier::extract_custom_attestation_from_cert(cert) { + Ok(attestation) => Some(attestation), + Err(err) => { + warn!("Failed to extract remote attestation from certificate: {err}"); + None + } + } + }), + None => None, + } + }; + let http_version = HttpVersion::from_negotiated_protocol_server(&tls_stream); - Self::serve_tls_stream(tls_stream, http_version, target, client_addr).await + Self::serve_tls_stream(tls_stream, http_version, target, client_addr, attestation).await } async fn serve_tls_stream( @@ -419,10 +466,25 @@ impl ProxyServer { http_version: HttpVersion, target: String, client_addr: SocketAddr, + attestation: Option, ) -> Result<(), ProxyError> where IO: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static, { + let (remote_attestation_type, measurements) = match attestation { + Some(attestation) => ( + Some(attestation.attestation_type), + match attestation.get_measurements() { + Ok(measurements) => measurements, + Err(err) => { + warn!("Failed to extract measurements from peer attestation: {err}"); + None + } + }, + ), + None => (None, None), + }; + // Setup a request handler let service = service_fn(move |mut req| { debug!("[proxy-server] Handling request {req:?}"); @@ -447,6 +509,30 @@ impl ProxyServer { update_header(headers, &X_FORWARDED_FOR, &new_x_forwarded_for); + // If we have measurements, from the remote peer, add them to the request header + let measurements = measurements.clone(); + + if let Some(measurements) = measurements { + match measurements.to_header_format() { + Ok(header_value) => { + headers.insert(MEASUREMENT_HEADER, header_value); + } + Err(e) => { + // This error is highly unlikely - that the measurement values fail to + // encode to JSON or fit in an HTTP header + error!("Failed to encode measurement values: {e}"); + } + } + } + + if let Some(remote_attestation_type) = remote_attestation_type { + update_header( + headers, + ATTESTATION_TYPE_HEADER, + remote_attestation_type.as_str(), + ); + } + let target = target.clone(); async move { match Self::handle_http_request(req, target).await { @@ -648,7 +734,7 @@ impl ProxyClient { let mut first = true; let mut ready_tx = Some(ready_tx); 'reconnect: loop { - let (mut sender, conn) = + let (mut sender, conn, attestation) = // Connect to the proxy server and provide / verify attestation match Self::setup_connection_with_backoff(&target, &nesting_tls_connector, first) .await @@ -678,6 +764,9 @@ impl ProxyClient { let (conn_done_tx, mut conn_done_rx) = tokio::sync::watch::channel::>(None); + let remote_attestation_type = attestation.attestation_type; + let measurements = attestation.get_measurements().ok().flatten(); + tokio::spawn(async move { let res = conn.await; let _ = conn_done_tx.send(res.err()); @@ -690,8 +779,26 @@ impl ProxyClient { debug!("[proxy-client] Read incoming request from source client: {req:?}"); // Attempt to forward it to the proxy server let (response, should_reconnect) = match sender.send_request(req).await { - Ok(resp) => { + Ok(mut resp) => { debug!("[proxy-client] Read response from proxy-server: {resp:?}"); + let headers = resp.headers_mut(); + if let Some(measurements) = measurements.clone() { + match measurements.to_header_format() { + Ok(header_value) => { + headers.insert(MEASUREMENT_HEADER, header_value); + } + Err(e) => { + error!("Failed to encode measurement values: {e}"); + } + } + } + + update_header( + headers, + ATTESTATION_TYPE_HEADER, + remote_attestation_type.as_str(), + ); + (Ok(resp.map(|b| b.boxed())), false) } Err(e) => { @@ -799,7 +906,7 @@ impl ProxyClient { target: &str, nesting_tls_connector: &NestingTlsConnector, should_bail: bool, - ) -> Result<(HttpSender, HttpConnection), ProxyError> { + ) -> Result<(HttpSender, HttpConnection, AttestationExchangeMessage), ProxyError> { let mut delay = Duration::from_secs(1); let max_delay = Duration::from_secs(SERVER_RECONNECT_MAX_BACKOFF_SECS); @@ -828,15 +935,28 @@ impl ProxyClient { async fn setup_connection( nesting_tls_connector: &NestingTlsConnector, target: &str, - ) -> Result<(HttpSender, HttpConnection), ProxyError> { + ) -> Result<(HttpSender, HttpConnection, AttestationExchangeMessage), ProxyError> { let outbound_stream = tokio::net::TcpStream::connect(target).await?; let domain = server_name_from_host(target)?; let tls_stream = nesting_tls_connector .connect(domain, outbound_stream) .await?; + debug!("[proxy-client] Connected to proxy server"); + let attestation = { + let (_io, server_connection) = tls_stream.get_ref(); + + let remote_cert_chain = server_connection + .peer_certificates() + .ok_or(ProxyError::NoCertificate)?; + + AttestedCertificateVerifier::extract_custom_attestation_from_cert( + remote_cert_chain.first().ok_or(ProxyError::NoCertificate)?, + )? + }; + // The attestation exchange is now complete - setup an HTTP client let http_version = HttpVersion::from_negotiated_protocol_client(&tls_stream); @@ -860,8 +980,7 @@ impl ProxyClient { } }; - // Return the HTTP client, as well as remote measurements - Ok((sender, conn)) + Ok((sender, conn, attestation)) } // Handle a request from the source client to the proxy server @@ -1056,6 +1175,11 @@ where #[cfg(test)] mod tests { use attestation::{AttestationType, measurements::MeasurementPolicy}; + use std::collections::HashMap; + use std::sync::{ + Arc, + atomic::{AtomicUsize, Ordering}, + }; use tokio_rustls::TlsConnector; use super::*; @@ -1064,6 +1188,43 @@ mod tests { generate_tls_config_with_client_auth, init_tracing, }; + fn expected_mock_measurements() -> HashMap { + let zero_measurement = "0".repeat(96); + HashMap::from([ + ("0".to_string(), zero_measurement.clone()), + ("1".to_string(), zero_measurement.clone()), + ("2".to_string(), zero_measurement.clone()), + ("3".to_string(), zero_measurement.clone()), + ("4".to_string(), zero_measurement), + ]) + } + + fn assert_mock_measurements(body: &str) { + let parsed: HashMap = serde_json::from_str(body).unwrap(); + assert_eq!(parsed, expected_mock_measurements()); + } + + fn assert_mock_measurements_header(headers: &http::HeaderMap) { + let body = headers + .get(MEASUREMENT_HEADER) + .and_then(|v| v.to_str().ok()) + .unwrap(); + assert_mock_measurements(body); + } + + fn assert_attestation_type_header(headers: &http::HeaderMap, expected: &str) { + assert_eq!( + headers + .get(ATTESTATION_TYPE_HEADER) + .and_then(|v| v.to_str().ok()), + Some(expected) + ); + } + + fn assert_no_measurements_header(headers: &http::HeaderMap) { + assert!(headers.get(MEASUREMENT_HEADER).is_none()); + } + #[test] fn proxy_alpn_protocols_prefer_http2() { let mut protocols = Vec::new(); @@ -1230,7 +1391,7 @@ mod tests { let nesting_tls_connector = NestingTlsConnector::new(Arc::new(outer_client_config), Arc::new(inner_client_config)); - let (sender, conn) = ProxyClient::setup_connection( + let (sender, conn, _attestation) = ProxyClient::setup_connection( &nesting_tls_connector, &format!("localhost:{}", proxy_addr.port()), ) @@ -1294,6 +1455,9 @@ mod tests { .await .unwrap(); + assert_attestation_type_header(res.headers(), "dcap-tdx"); + assert_mock_measurements_header(res.headers()); + let res_body = res.text().await.unwrap(); assert_eq!(res_body, "No measurements"); } @@ -1362,8 +1526,11 @@ mod tests { .await .unwrap(); + assert_attestation_type_header(res.headers(), "none"); + assert_no_measurements_header(res.headers()); + let res_body = res.text().await.unwrap(); - assert_eq!(res_body, "No measurements"); + assert_mock_measurements(&res_body); } // Server has no attestation, client has mock DCAP but no client auth @@ -1423,7 +1590,11 @@ mod tests { .await .unwrap(); - let _res_body = res.text().await.unwrap(); + assert_attestation_type_header(res.headers(), "none"); + assert_no_measurements_header(res.headers()); + + let res_body = res.text().await.unwrap(); + assert_eq!(res_body, "No measurements"); } // Server has mock DCAP, client has mock DCAP and client auth @@ -1490,12 +1661,16 @@ mod tests { let res = reqwest::get(format!("http://{}", proxy_client_addr)) .await .unwrap(); - assert_eq!(res.text().await.unwrap(), "No measurements"); + assert_attestation_type_header(res.headers(), "dcap-tdx"); + assert_mock_measurements_header(res.headers()); + assert_mock_measurements(&res.text().await.unwrap()); let res = reqwest::get(format!("http://{}", proxy_client_addr)) .await .unwrap(); - assert_eq!(res.text().await.unwrap(), "No measurements"); + assert_attestation_type_header(res.headers(), "dcap-tdx"); + assert_mock_measurements_header(res.headers()); + assert_mock_measurements(&res.text().await.unwrap()); } // Server has mock DCAP, client no attestation - just get the server certificate @@ -1692,6 +1867,7 @@ mod tests { // This is used to trigger a dropped connection to the proxy server let (connection_breaker_tx, connection_breaker_rx) = oneshot::channel(); + let (reconnected_tx, reconnected_rx) = oneshot::channel(); tokio::spawn(async move { let connection_handle = proxy_server.accept().await.unwrap(); @@ -1703,6 +1879,7 @@ mod tests { // Now accept another connection proxy_server.accept().await.unwrap(); + let _ = reconnected_tx.send(()); }); let proxy_client = ProxyClient::new_with_tls_config( @@ -1723,22 +1900,150 @@ mod tests { proxy_client.accept().await.unwrap(); }); - let _initial_response = reqwest::get(format!("http://{}", proxy_client_addr)) + let initial_response = reqwest::get(format!("http://{}", proxy_client_addr)) .await .unwrap(); + assert_attestation_type_header(initial_response.headers(), "dcap-tdx"); + assert_mock_measurements_header(initial_response.headers()); // Now break the connection connection_breaker_tx.send(()).unwrap(); + reconnected_rx.await.unwrap(); // Make another request let res = reqwest::get(format!("http://{}", proxy_client_addr)) .await .unwrap(); + assert_attestation_type_header(res.headers(), "dcap-tdx"); + assert_mock_measurements_header(res.headers()); + let res_body = res.text().await.unwrap(); assert_eq!(res_body, "No measurements"); } + #[tokio::test(flavor = "multi_thread")] + async fn http_proxy_does_not_retry_failed_request() { + init_tracing(); + + let request_count = Arc::new(AtomicUsize::new(0)); + let request_seen = Arc::new(tokio::sync::Notify::new()); + let (release_tx, release_rx) = tokio::sync::watch::channel(false); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let target_addr = listener.local_addr().unwrap(); + + let app = axum::Router::new().route( + "/", + axum::routing::get({ + let request_count = request_count.clone(); + let request_seen = request_seen.clone(); + let release_rx = release_rx.clone(); + + move || { + let request_count = request_count.clone(); + let request_seen = request_seen.clone(); + let mut release_rx = release_rx.clone(); + + async move { + request_count.fetch_add(1, Ordering::SeqCst); + request_seen.notify_waiters(); + + if !*release_rx.borrow() { + release_rx.changed().await.unwrap(); + } + + "ok" + } + } + }), + ); + + tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); + let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); + + let proxy_server = ProxyServer::new( + Some(OuterTlsConfig { + listen_addr: "127.0.0.1:0", + tls: OuterTlsMode::Preconfigured { + server_config, + certificate_name: certificate_identity_from_chain(&cert_chain).unwrap(), + }, + }), + Some("127.0.0.1:0"), + target_addr.to_string(), + AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + AttestationVerifier::expect_none(), + false, + ) + .await + .unwrap(); + + let proxy_addr = proxy_server.local_addr().unwrap(); + + let (connection_breaker_tx, connection_breaker_rx) = oneshot::channel(); + let (reconnected_tx, reconnected_rx) = oneshot::channel(); + + tokio::spawn(async move { + let connection_handle = proxy_server.accept().await.unwrap(); + connection_breaker_rx.await.unwrap(); + connection_handle.abort(); + proxy_server.accept().await.unwrap(); + let _ = reconnected_tx.send(()); + }); + + let proxy_client = ProxyClient::new_with_tls_config( + client_config, + "127.0.0.1:0".to_string(), + format!("localhost:{}", proxy_addr.port()), + AttestationGenerator::with_no_attestation(), + AttestationVerifier::mock(), + None, + ) + .await + .unwrap(); + + let proxy_client_addr = proxy_client.local_addr().unwrap(); + + tokio::spawn(async move { + proxy_client.accept().await.unwrap(); + proxy_client.accept().await.unwrap(); + }); + + let request_url = format!("http://{}", proxy_client_addr); + let failed_request = tokio::spawn(async move { reqwest::get(request_url).await.unwrap() }); + + loop { + if request_count.load(Ordering::SeqCst) > 0 { + break; + } + + request_seen.notified().await; + } + + connection_breaker_tx.send(()).unwrap(); + release_tx.send(true).unwrap(); + + let failed_response = failed_request.await.unwrap(); + assert_eq!(failed_response.status(), hyper::StatusCode::BAD_GATEWAY); + assert_eq!(request_count.load(Ordering::SeqCst), 1); + + reconnected_rx.await.unwrap(); + + let res = reqwest::get(format!("http://{}", proxy_client_addr)) + .await + .unwrap(); + + assert_attestation_type_header(res.headers(), "dcap-tdx"); + assert_mock_measurements_header(res.headers()); + assert_eq!(res.text().await.unwrap(), "ok"); + assert_eq!(request_count.load(Ordering::SeqCst), 2); + } + // Use HTTP 1.1 #[tokio::test(flavor = "multi_thread")] async fn http_proxy_with_http1() { @@ -1794,6 +2099,9 @@ mod tests { .await .unwrap(); + assert_attestation_type_header(res.headers(), "dcap-tdx"); + assert_mock_measurements_header(res.headers()); + let res_body = res.text().await.unwrap(); assert_eq!(res_body, "No measurements"); } diff --git a/src/test_helpers.rs b/src/test_helpers.rs index 431c5f8..b8509c0 100644 --- a/src/test_helpers.rs +++ b/src/test_helpers.rs @@ -12,6 +12,8 @@ use tokio_rustls::rustls::{ }; use tracing_subscriber::{EnvFilter, fmt}; +use crate::MEASUREMENT_HEADER; + static INIT: Once = Once::new(); /// Helper to generate a self-signed certificate for testing with a DNS subject name @@ -127,13 +129,12 @@ pub async fn example_http_service() -> SocketAddr { addr } -async fn get_handler(_headers: http::HeaderMap) -> impl IntoResponse { - // headers - // .get(MEASUREMENT_HEADER) - // .and_then(|v| v.to_str().ok()) - // .unwrap_or("No measurements") - // .to_string() - "No measurements".to_string() +async fn get_handler(headers: http::HeaderMap) -> impl IntoResponse { + headers + .get(MEASUREMENT_HEADER) + .and_then(|v| v.to_str().ok()) + .unwrap_or("No measurements") + .to_string() } pub fn init_tracing() {