From 810783faab38f7b543e290ec6b62a6710e6d9b89 Mon Sep 17 00:00:00 2001 From: Caelan Date: Sun, 16 Feb 2025 11:54:10 -0700 Subject: [PATCH] Fix: rust download not propagating errors --- Cargo.lock | 1 + src/lib-rust/Cargo.toml | 3 ++ src/lib-rust/src/download.rs | 88 +++++++++++++++++++++++++++++++++--- 3 files changed, 85 insertions(+), 7 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f8363708..5981e9f6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2187,6 +2187,7 @@ dependencies = [ "serde_json", "sha1", "sha2", + "tempfile", "thiserror 2.0.11", "ts-rs", "ureq", diff --git a/src/lib-rust/Cargo.toml b/src/lib-rust/Cargo.toml index 317b4ff4..9d164b4b 100644 --- a/src/lib-rust/Cargo.toml +++ b/src/lib-rust/Cargo.toml @@ -26,6 +26,9 @@ features = ["async", "delta"] name = "velopack" path = "src/lib.rs" +[dev-dependencies] +tempfile.workspace = true + [dependencies] log.workspace = true ureq.workspace = true diff --git a/src/lib-rust/src/download.rs b/src/lib-rust/src/download.rs index a82f305f..d27094b8 100644 --- a/src/lib-rust/src/download.rs +++ b/src/lib-rust/src/download.rs @@ -21,16 +21,17 @@ where let mut last_progress = 0; - while let Ok(size) = reader.read(&mut buffer) { + loop { + let size = reader.read(&mut buffer)?; // Explicitly propagate errors if size == 0 { break; // End of stream } file.write_all(&buffer[..size])?; downloaded += size as u64; - if total_size.is_some() { + if let Some(total) = total_size { // floor to nearest 5% to reduce message spam - let new_progress = (downloaded as f64 / total_size.unwrap() as f64 * 20.0).floor() as i16 * 5; + let new_progress = (downloaded as f64 / total as f64 * 20.0).floor() as i16 * 5; if new_progress > last_progress { last_progress = new_progress; progress(last_progress); @@ -70,7 +71,10 @@ fn test_download_file_reports_progress() { let mut prog_count = 0; let mut last_prog = 0; - download_url_to_file(test_file, "test_download_file_reports_progress.txt", |p| { + let tmpfile = tempfile::NamedTempFile::new().unwrap(); + let tmppath = tmpfile.path(); + + download_url_to_file(test_file, tmppath.to_str().unwrap(), |p| { assert!(p >= last_prog); prog_count += 1; last_prog = p; @@ -81,10 +85,80 @@ fn test_download_file_reports_progress() { assert!(prog_count <= 20); assert_eq!(last_prog, 100); - let p = std::path::Path::new("test_download_file_reports_progress.txt"); - let meta = p.metadata().unwrap(); + let meta = tmppath.metadata().unwrap(); let len = meta.len(); assert_eq!(len, 10 * 1024 * 1024); - std::fs::remove_file(p).unwrap(); } + +#[test] +fn test_interrupted_download() { + use std::io::Write; + use std::net::TcpListener; + use std::thread; + + // Start a simple HTTP server that cuts the connection mid-download + let listener = TcpListener::bind("127.0.0.1:0").expect("Failed to bind test server"); + let addr = listener.local_addr().unwrap(); + + thread::spawn(move || { + if let Ok((mut stream, _)) = listener.accept() { + // Write a partial HTTP response + let response = "HTTP/1.1 200 OK\r\nContent-Length: 100000\r\n\r\n"; + stream.write_all(response.as_bytes()).expect("Failed to write response"); + + // Send part of the data, then close the connection early + let partial_data = vec![0u8; 1024]; // 1 KB + stream.write_all(&partial_data).expect("Failed to write partial data"); + + // Connection closes here, simulating an interrupted download + thread::sleep(std::time::Duration::from_millis(100)); + } + }); + + let tmpfile = tempfile::NamedTempFile::new().unwrap(); + let result = download_url_to_file( + &format!("http://{}", addr), + tmpfile.path().to_str().unwrap(), + |_| {} + ); + + assert!(result.is_err(), "Download should fail due to connection interruption"); +} + +#[test] +fn test_successful_download() { + use std::io::Write; + use std::net::TcpListener; + use std::thread; + + // Start a simple HTTP server + let listener = TcpListener::bind("127.0.0.1:0").expect("Failed to bind test server"); + let addr = listener.local_addr().unwrap(); + + thread::spawn(move || { + if let Ok((mut stream, _)) = listener.accept() { + // Write a full HTTP response with full content + let response = "HTTP/1.1 200 OK\r\nContent-Length: 10240\r\n\r\n"; + stream.write_all(response.as_bytes()).expect("Failed to write response"); + + // Send full 10KB of data + let full_data = vec![0u8; 10240]; + stream.write_all(&full_data).expect("Failed to write full data"); + + // give client time to receive and disconnect + thread::sleep(std::time::Duration::from_millis(100)); + } + }); + + let tmpfile = tempfile::NamedTempFile::new().unwrap(); + let _ = download_url_to_file( + &format!("http://{}", addr), + tmpfile.path().to_str().unwrap(), + |_| {}, + ).unwrap(); + + // Verify that the downloaded file has the expected size + let metadata = tmpfile.path().metadata().unwrap(); + assert_eq!(metadata.len(), 10240, "Downloaded file size should match the expected content size"); +} \ No newline at end of file