diff --git a/crates/fetchkit/src/fetchers/default.rs b/crates/fetchkit/src/fetchers/default.rs index c1afb46..7da6eb2 100644 --- a/crates/fetchkit/src/fetchers/default.rs +++ b/crates/fetchkit/src/fetchers/default.rs @@ -292,7 +292,8 @@ impl Fetcher for DefaultFetcher { // THREAT[TM-DOS-001]: Read body with timeout and size limit // THREAT[TM-DOS-003]: Size limit also protects against compressed content bombs - let (body, truncated) = read_body_with_timeout(response, BODY_TIMEOUT, max_body_size).await; + let (body, truncated) = + read_body_with_timeout(response, BODY_TIMEOUT, max_body_size).await?; let size = body.len() as u64; // Convert to string @@ -438,7 +439,8 @@ impl Fetcher for DefaultFetcher { } // Read raw body (no binary rejection for file saves) - let (body, truncated) = read_body_with_timeout(response, BODY_TIMEOUT, max_body_size).await; + let (body, truncated) = + read_body_with_timeout(response, BODY_TIMEOUT, max_body_size).await?; let size = body.len() as u64; // Save through the FileSaver @@ -644,7 +646,7 @@ pub(crate) async fn read_body_with_timeout( response: reqwest::Response, timeout: Duration, max_size: usize, -) -> (Bytes, bool) { +) -> Result<(Bytes, bool), FetchError> { let mut body = Vec::new(); let mut stream = response.bytes_stream(); let deadline = tokio::time::Instant::now() + timeout; @@ -660,29 +662,31 @@ pub(crate) async fn read_body_with_timeout( let remaining = max_size.saturating_sub(body.len()); if remaining == 0 { warn!("Body size limit reached ({}), truncating", max_size); - return (Bytes::from(body), true); + return Ok((Bytes::from(body), true)); } if bytes.len() > remaining { body.extend_from_slice(&bytes[..remaining]); warn!("Body size limit reached ({}), truncating", max_size); - return (Bytes::from(body), true); + return Ok((Bytes::from(body), true)); } body.extend_from_slice(&bytes); } Some(Err(e)) => { error!("Error reading body chunk: {}", e); - let has_content = !body.is_empty(); - return (Bytes::from(body), has_content); + if body.is_empty() { + return Err(FetchError::from_reqwest(e)); + } + return Ok((Bytes::from(body), true)); } None => { // Stream complete - return (Bytes::from(body), false); + return Ok((Bytes::from(body), false)); } } } _ = timeout_future => { warn!("Body timeout reached, returning partial content"); - return (Bytes::from(body), true); + return Ok((Bytes::from(body), true)); } } } diff --git a/crates/fetchkit/src/fetchers/docs_site.rs b/crates/fetchkit/src/fetchers/docs_site.rs index 0db9208..2d08c46 100644 --- a/crates/fetchkit/src/fetchers/docs_site.rs +++ b/crates/fetchkit/src/fetchers/docs_site.rs @@ -239,7 +239,7 @@ async fn fetch_llms_txt_direct( } let max_body_size = options.max_body_size.unwrap_or(DEFAULT_MAX_BODY_SIZE); - let (body, truncated) = read_body_with_timeout(response, BODY_TIMEOUT, max_body_size).await; + let (body, truncated) = read_body_with_timeout(response, BODY_TIMEOUT, max_body_size).await?; let size = body.len() as u64; let mut content = String::from_utf8_lossy(&body).to_string(); diff --git a/crates/fetchkit/tests/integration.rs b/crates/fetchkit/tests/integration.rs index 3ad141a..e9af093 100644 --- a/crates/fetchkit/tests/integration.rs +++ b/crates/fetchkit/tests/integration.rs @@ -5,6 +5,8 @@ use fetchkit::{ HttpMethod, LocalFileSaver, Tool, }; use serde_json::json; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; use tower::Service; use wiremock::matchers::{method, path}; use wiremock::{Mock, MockServer, ResponseTemplate}; @@ -32,6 +34,25 @@ fn test_tool_with_save() -> Tool { .build() } +async fn spawn_malformed_chunked_server() -> String { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { + if let Ok((mut stream, _)) = listener.accept().await { + let mut buf = [0_u8; 1024]; + let _ = stream.read(&mut buf).await; + let _ = stream + .write_all( + b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\nTransfer-Encoding: chunked\r\n\r\nZZ\r\nboom\r\n0\r\n\r\n", + ) + .await; + } + }); + + format!("http://{addr}/") +} + #[tokio::test] async fn test_simple_get() { let mock_server = MockServer::start().await; @@ -55,6 +76,28 @@ async fn test_simple_get() { assert_eq!(resp.format, Some("raw".to_string())); } +#[tokio::test] +async fn test_malformed_chunked_body_returns_error() { + let req = FetchRequest::new(spawn_malformed_chunked_server().await); + let result = fetch_with_options(req, test_options()).await; + + assert!(matches!(result, Err(FetchError::RequestError(_)))); +} + +#[tokio::test] +async fn test_save_to_file_malformed_chunked_body_does_not_create_empty_file() { + let dir = tempfile::tempdir().unwrap(); + let saver = LocalFileSaver::new(Some(dir.path().to_path_buf())); + let req = + FetchRequest::new(spawn_malformed_chunked_server().await).save_to_file("malformed.txt"); + let result = test_tool_with_save() + .execute_with_saver(req, Some(&saver)) + .await; + + assert!(matches!(result, Err(FetchError::RequestError(_)))); + assert!(!dir.path().join("malformed.txt").exists()); +} + #[tokio::test] async fn test_head_request() { let mock_server = MockServer::start().await;