From 28668fb64e1dbc11d75b114ea5148d05b513ab96 Mon Sep 17 00:00:00 2001 From: Benny Zlotnik Date: Tue, 21 Apr 2026 16:18:16 +0300 Subject: [PATCH 1/2] use multi-threaded XZ decompression via liblzma Replace xz2 crate with liblzma 0.4 (parallel feature) to enable lzma_stream_decoder_mt() This helps remove the decompression bottleneck from XZ decompression speeding flashing up by ~4x Signed-off-by: Benny Zlotnik Assisted-by: claude-opus-4.6 --- Cargo.lock | 43 +++---- Cargo.toml | 2 +- src/fls/decompress.rs | 100 ++++++++++++++++ src/fls/from_url.rs | 260 ++++++++-------------------------------- src/fls/magic_bytes.rs | 2 +- src/fls/oci/from_oci.rs | 27 ++++- src/fls/options.rs | 4 + src/fls/stream_utils.rs | 23 ++-- src/main.rs | 6 + tests/common/mod.rs | 2 +- 10 files changed, 219 insertions(+), 250 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index aa2294e..ac194de 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -416,6 +416,7 @@ dependencies = [ "hyper-util", "indicatif", "libc", + "liblzma", "nix", "openssl", "rcgen", @@ -429,7 +430,6 @@ dependencies = [ "tokio", "tokio-rustls", "wiremock", - "xz2", ] [[package]] @@ -963,6 +963,27 @@ dependencies = [ "windows-link", ] +[[package]] +name = "liblzma" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6033b77c21d1f56deeae8014eb9fbe7bdf1765185a6c508b5ca82eeaed7f899" +dependencies = [ + "liblzma-sys", + "num_cpus", +] + +[[package]] +name = "liblzma-sys" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a60851d15cd8c5346eca4ab8babff585be2ae4bc8097c067291d3ffe2add3b6" +dependencies = [ + "cc", + "libc", + "pkg-config", +] + [[package]] name = "libredox" version = "0.1.12" @@ -1007,17 +1028,6 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" -[[package]] -name = "lzma-sys" -version = "0.1.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fda04ab3764e6cde78b9974eec4f779acaba7c4e84b36eca3cf77c581b85d27" -dependencies = [ - "cc", - "libc", - "pkg-config", -] - [[package]] name = "memchr" version = "2.7.6" @@ -2345,15 +2355,6 @@ dependencies = [ "rustix", ] -[[package]] -name = "xz2" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "388c44dc09d76f1536602ead6d325eb532f5c122f17782bd57fb47baeeb767e2" -dependencies = [ - "lzma-sys", -] - [[package]] name = "yasna" version = "0.5.2" diff --git a/Cargo.toml b/Cargo.toml index 361a7f9..9a5773f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,7 @@ nix = { version = "0.29", features = ["ioctl"] } # Enable vendored OpenSSL for cross-compilation to musl targets # This ensures OpenSSL builds from source with musl compatibility openssl = { version = "0.10", features = ["vendored"] } -xz2 = "0.1" +liblzma = { version = "0.4", features = ["parallel"] } [dev-dependencies] http-body = "1.0.1" diff --git a/src/fls/decompress.rs b/src/fls/decompress.rs index 6084c35..de9f82d 100644 --- a/src/fls/decompress.rs +++ b/src/fls/decompress.rs @@ -1,7 +1,30 @@ +use crate::fls::byte_channel::ByteBoundedReceiver; +use crate::fls::compression::Compression; +use crate::fls::stream_utils::ChannelReader; +use bytes::Bytes; +use std::io::Read; use tokio::io::AsyncReadExt; use tokio::process::{Child, Command}; use tokio::sync::mpsc; +fn mb_to_bytes(mb: u64) -> u64 { + mb.saturating_mul(1024 * 1024) +} + +pub(crate) fn create_xz_decoder( + reader: R, + memlimit_mb: u64, +) -> Result, String> { + let memlimit = mb_to_bytes(memlimit_mb); + let stream = liblzma::stream::Stream::new_stream_decoder(memlimit, 0).map_err(|e| { + format!( + "Failed to create XZ decoder with {}MB limit: {}", + memlimit_mb, e + ) + })?; + Ok(liblzma::read::XzDecoder::new_stream(reader, stream)) +} + /// Determines the appropriate decompression command based on URL extension fn get_decompressor_command(url: &str) -> &'static str { let extension = url.rsplit('.').next().unwrap_or("").to_lowercase(); @@ -75,6 +98,83 @@ fn spawn_decompressor( Ok((process, cmd)) } +pub(crate) fn get_compression_from_url(url: &str) -> Compression { + let path = url.split('?').next().unwrap_or(url); + let path = path.split('#').next().unwrap_or(path); + let extension = path.rsplit('.').next().unwrap_or("").to_lowercase(); + match extension.as_str() { + "gz" => Compression::Gzip, + "xz" => Compression::Xz, + "zst" | "zstd" => Compression::Zstd, + _ => Compression::None, + } +} + +type DecompressorResult = ( + mpsc::Receiver>, + std::thread::JoinHandle>, +); + +pub(crate) fn start_inprocess_decompressor( + buffer_rx: ByteBoundedReceiver, + compression: Compression, + consumed_progress_tx: mpsc::UnboundedSender, + xz_memlimit_mb: u64, +) -> Result> { + let (decompressed_tx, decompressed_rx) = mpsc::channel::>(8); + + let handle = std::thread::Builder::new() + .name("decompressor".to_string()) + .spawn(move || { + let channel_reader = + ChannelReader::new_byte_bounded(buffer_rx).with_progress(consumed_progress_tx); + + let mut decoder: Box = match compression { + Compression::Xz => { + let num_threads = std::thread::available_parallelism() + .map(|n| n.get() as u32) + .unwrap_or(2); + let memlimit = mb_to_bytes(xz_memlimit_mb); + eprintln!( + "XZ decompression: {} threads, memory limit {}MB", + num_threads, xz_memlimit_mb + ); + let stream = liblzma::stream::MtStreamBuilder::new() + .threads(num_threads) + .memlimit_threading(memlimit) + .memlimit_stop(memlimit) + .decoder() + .map_err(|e| format!("Failed to create MT XZ decoder: {}", e))?; + Box::new(liblzma::read::XzDecoder::new_stream(channel_reader, stream)) + } + Compression::Gzip => Box::new(flate2::read::GzDecoder::new(channel_reader)), + Compression::None => Box::new(channel_reader), + Compression::Zstd => { + return Err("Zstd in-process decompression is not supported".to_string()); + } + }; + + let mut buf = vec![0u8; 8 * 1024 * 1024]; + loop { + let n = decoder + .read(&mut buf) + .map_err(|e| format!("Decompression error: {}", e))?; + if n == 0 { + break; + } + if decompressed_tx.blocking_send(buf[..n].to_vec()).is_err() { + return Err("Writer task closed, stopping decompression".to_string()); + } + } + Ok(()) + }) + .map_err(|e| -> Box { + format!("Failed to spawn decompressor thread: {}", e).into() + })?; + + Ok((decompressed_rx, handle)) +} + pub(crate) async fn spawn_stderr_reader( mut stderr: tokio::process::ChildStderr, error_tx: mpsc::UnboundedSender, diff --git a/src/fls/from_url.rs b/src/fls/from_url.rs index fc03f85..6e20da0 100644 --- a/src/fls/from_url.rs +++ b/src/fls/from_url.rs @@ -1,12 +1,13 @@ +use futures_util::StreamExt; use std::io; use std::time::Duration; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::sync::mpsc; use tokio::task::JoinHandle; use crate::fls::block_writer::AsyncBlockWriter; use crate::fls::byte_channel::byte_bounded_channel; -use crate::fls::decompress::{spawn_stderr_reader, start_decompressor_process}; +use crate::fls::compression::Compression; +use crate::fls::decompress::{get_compression_from_url, start_inprocess_decompressor}; use crate::fls::download_error::DownloadError; use crate::fls::error_handling::process_error_messages; use crate::fls::format_detector::{DetectionResult, FileFormat, FormatDetector}; @@ -36,19 +37,6 @@ async fn get_writer_error(handle: JoinHandle>) -> Box>, -) -> Box { - match handle.await { - Ok(Ok(())) => "Decompressor stdin closed unexpectedly".into(), - Ok(Err(e)) => e.into(), - Err(e) => format!("Decompressor writer task panicked: {}", e).into(), - } -} - use crate::fls::download_error::handle_download_retry; /// Execute a sequence of write commands on the block writer @@ -179,12 +167,14 @@ pub async fn flash_from_url( let http_options: HttpClientOptions = (&options).into(); let client = setup_http_client(&http_options).await?; - let (mut decompressor, decompressor_name) = start_decompressor_process(url).await?; - - // Extract stdio handles - let mut decompressor_stdin = decompressor.stdin.take().unwrap(); - let decompressor_stdout = decompressor.stdout.take().unwrap(); - let decompressor_stderr = decompressor.stderr.take().unwrap(); + let compression = get_compression_from_url(url); + if compression == Compression::Zstd { + return Err("Zstd in-process decompression is not supported".into()); + } + let is_compressed = compression != Compression::None; + if is_compressed { + eprintln!("Using decompressor: {} (in-process)", compression); + } // Create channels let (decompressed_progress_tx, mut decompressed_progress_rx) = mpsc::unbounded_channel::(); @@ -205,24 +195,43 @@ pub async fn flash_from_url( options.common.write_buffer_size_mb, )?; - // Spawn background task to read from decompressor and write to block device + // Create byte-bounded download buffer + let buffer_size_mb = options.common.buffer_size_mb; + let max_buffer_bytes = buffer_size_mb * 1024 * 1024; + + println!( + "Using download buffer: {} MB (byte-bounded)", + buffer_size_mb + ); + + let (buffer_tx, buffer_rx) = byte_bounded_channel::(max_buffer_bytes, 4096); + + // Channel for tracking bytes consumed from buffer by decompressor + let (decompressor_written_progress_tx, mut decompressor_written_progress_rx) = + mpsc::unbounded_channel::(); + + // Start in-process decompressor thread + let (mut decompressed_rx, decompressor_handle) = start_inprocess_decompressor( + buffer_rx, + compression, + decompressor_written_progress_tx, + options.common.xz_memlimit_mb, + )?; + + // Spawn background task to read decompressed data and write to block device let error_tx_clone = error_tx.clone(); let debug = options.common.debug; let writer_handle = { let writer = block_writer; tokio::spawn(async move { - let mut stdout = decompressor_stdout; - let mut buffer = vec![0u8; 8 * 1024 * 1024]; // 8MB buffer - - // Auto-detect sparse image format from initial data let mut detector = FormatDetector::new(); let mut parser: Option = None; let mut format_determined = false; loop { - let n = match stdout.read(&mut buffer).await { - Ok(0) => { - // EOF - check if we have incomplete format detection + let data = match decompressed_rx.recv().await { + Some(data) => data, + None => { if !format_determined { if let Some(buffered_data) = detector.finalize_at_eof() { if debug { @@ -236,20 +245,15 @@ pub async fn flash_from_url( } break; } - Ok(n) => n, - Err(e) => { - let _ = error_tx_clone - .send(format!("Error reading from decompressor stdout: {}", e)); - return Err(e); - } }; + let n = data.len(); if decompressed_progress_tx.send(n as u64).is_err() { break; } if !format_determined { - match detector.process(&buffer[..n]) { + match detector.process(&data) { DetectionResult::NeedMoreData => { if debug { eprintln!( @@ -263,7 +267,7 @@ pub async fn flash_from_url( consumed_bytes, consumed_from_input, } => { - let remaining = &buffer[consumed_from_input..n]; + let remaining = &data[consumed_from_input..n]; parser = handle_detected_format( format, consumed_bytes, @@ -281,81 +285,34 @@ pub async fn flash_from_url( } } } else if let Some(ref mut p) = parser { - process_sparse_data(p, &buffer[..n], &writer, &error_tx_clone, debug).await?; + process_sparse_data(p, &data, &writer, &error_tx_clone, debug).await?; } else { - write_regular_data(buffer[..n].to_vec(), &writer, &error_tx_clone).await?; + write_regular_data(data, &writer, &error_tx_clone).await?; } } - // Close writer and get final bytes written writer.close().await }) }; - tokio::spawn(spawn_stderr_reader( - decompressor_stderr, - error_tx.clone(), - decompressor_name, - )); - // Spawn message processors let error_processor = tokio::spawn(process_error_messages(error_rx)); // Main download loop with retry logic let mut progress = ProgressTracker::new(options.common.newline_progress, options.common.show_memory); - // Set whether we're actually decompressing (not using cat for uncompressed files) - progress.set_is_compressed(decompressor_name != "cat"); + progress.set_is_compressed(is_compressed); let update_interval = Duration::from_secs_f64(options.common.progress_interval_secs); let mut bytes_sent_to_decompressor: u64 = 0; let mut retry_count = 0; let debug = options.common.debug; - use futures_util::StreamExt; - - // Create byte-bounded download buffer (shared across all retry attempts) - let buffer_size_mb = options.common.buffer_size_mb; - let max_buffer_bytes = buffer_size_mb * 1024 * 1024; - - println!( - "Using download buffer: {} MB (byte-bounded)", - buffer_size_mb - ); - - // Create persistent byte-bounded channel for download buffering (lives across retries) - // max_items=4096 prevents unbounded item queuing; byte budget is the real bound - let (buffer_tx, mut buffer_rx) = byte_bounded_channel::(max_buffer_bytes, 4096); - - // Channels for tracking bytes actually written to decompressor - let (decompressor_written_progress_tx, mut decompressor_written_progress_rx) = - mpsc::unbounded_channel::(); - - // Spawn persistent task to write buffered chunks to decompressor - let decompressor_writer_handle = tokio::spawn(async move { - while let Some(chunk) = buffer_rx.recv().await { - let chunk_len = chunk.len() as u64; - if let Err(e) = decompressor_stdin.write_all(&chunk).await { - return Err(format!("Error writing to decompressor stdin: {}", e)); - } - // Notify that bytes were written to decompressor - let _ = decompressor_written_progress_tx.send(chunk_len); - } - // Close decompressor stdin when channel is closed - Ok::<(), String>(()) - }); - loop { - // Check if writer or decompressor has failed before attempting download/retry if writer_handle.is_finished() { eprintln!(); eprintln!("Writer task has terminated, stopping download"); return Err(get_writer_error(writer_handle).await); } - if decompressor_writer_handle.is_finished() { - eprintln!(); - eprintln!("Decompressor writer task has terminated, stopping download"); - return Err(get_decompressor_error(decompressor_writer_handle).await); - } // Resume from the HTTP download position, not the decompressor write position // The buffer may contain data that's been downloaded but not yet written to decompressor @@ -414,21 +371,11 @@ pub async fn flash_from_url( // Send to buffer - detect if it's blocking let send_start = std::time::Instant::now(); if buffer_tx.send(chunk).await.is_err() { - // Check if writer or decompressor has failed if writer_handle.is_finished() { eprintln!(); eprintln!("Writer task has terminated unexpectedly"); return Err(get_writer_error(writer_handle).await); } - if decompressor_writer_handle.is_finished() { - eprintln!(); - eprintln!( - "Decompressor writer task has terminated unexpectedly" - ); - return Err( - get_decompressor_error(decompressor_writer_handle).await - ); - } connection_error = Some(DownloadError::Other("Buffer channel closed".to_string())); connection_broken = true; @@ -500,17 +447,11 @@ pub async fn flash_from_url( // If connection broke, retry if connection_broken { - // First check if writer or decompressor has failed - if so, don't retry if writer_handle.is_finished() { eprintln!(); eprintln!("Connection interrupted and writer task has terminated"); return Err(get_writer_error(writer_handle).await); } - if decompressor_writer_handle.is_finished() { - eprintln!(); - eprintln!("Connection interrupted and decompressor writer task has terminated"); - return Err(get_decompressor_error(decompressor_writer_handle).await); - } if let Some(e) = connection_error { eprintln!("\nConnection interrupted: {}", e.format_error()); @@ -554,120 +495,21 @@ pub async fn flash_from_url( // Close buffer channel to signal end of download drop(buffer_tx); - // Poll for decompressor writer completion while showing progress - loop { - // Update progress from all channels - let mut updated = false; - - while let Ok(written_len) = decompressor_written_progress_rx.try_recv() { - progress.bytes_sent_to_decompressor += written_len; - updated = true; - } - - while let Ok(byte_count) = decompressed_progress_rx.try_recv() { - progress.bytes_decompressed += byte_count; - updated = true; - } - - while let Ok(written_bytes) = written_progress_rx.try_recv() { - progress.bytes_written = written_bytes; - updated = true; - } - - if updated { - let _ = progress.update_progress(None, update_interval, false); - } - - // Check if decompressor writer task is done - if decompressor_writer_handle.is_finished() { - break; - } - - // Small sleep to avoid busy waiting - tokio::time::sleep(Duration::from_millis(100)).await; - } - - // Get the result from the decompressor writer task - if let Err(e) = decompressor_writer_handle + let decompressor_result = tokio::task::spawn_blocking(move || decompressor_handle.join()) .await - .map_err(|e| e.to_string()) - .and_then(|r| r) - { + .map_err(|_| "Decompressor task panicked")? + .map_err(|_| "Decompressor thread panicked")?; + + if let Err(e) = decompressor_result { eprintln!(); return Err(e.into()); } - // Update any remaining progress while let Ok(byte_count) = decompressed_progress_rx.try_recv() { progress.bytes_decompressed += byte_count; } - - // Check if decompressor has already finished - let decompressor_already_done = match decompressor.try_wait() { - Ok(Some(status)) => { - if !status.success() { - eprintln!(); - return Err(format!( - "{} process failed with status: {:?}", - decompressor_name, - status.code() - ) - .into()); - } - true - } - Ok(None) => false, - Err(e) => { - eprintln!(); - return Err(e.into()); - } - }; - - // Only wait if decompressor is not already done - if !decompressor_already_done { - // Poll for decompressor completion while showing progress - loop { - // Update progress from channels - let mut updated = false; - - while let Ok(byte_count) = decompressed_progress_rx.try_recv() { - progress.bytes_decompressed += byte_count; - updated = true; - } - - while let Ok(written_bytes) = written_progress_rx.try_recv() { - progress.bytes_written = written_bytes; - updated = true; - } - - if updated { - let _ = progress.update_progress(None, update_interval, false); - } - - // Check if decompressor is done (non-blocking check) - match decompressor.try_wait() { - Ok(Some(status)) => { - if !status.success() { - eprintln!(); - return Err(format!( - "{} process failed with status: {:?}", - decompressor_name, - status.code() - ) - .into()); - } - break; - } - Ok(None) => { - // Still running, sleep briefly - tokio::time::sleep(Duration::from_millis(100)).await; - } - Err(e) => { - eprintln!(); - return Err(e.into()); - } - } - } + while let Ok(written_len) = decompressor_written_progress_rx.try_recv() { + progress.bytes_sent_to_decompressor += written_len; } // Capture the decompression rate and duration at completion diff --git a/src/fls/magic_bytes.rs b/src/fls/magic_bytes.rs index dc2d411..fb20894 100644 --- a/src/fls/magic_bytes.rs +++ b/src/fls/magic_bytes.rs @@ -143,8 +143,8 @@ fn decompress_gzip_sample(data: &[u8]) -> Result, String> { /// Decompress a sample of XZ data to analyze content type fn decompress_xz_sample(data: &[u8]) -> Result, String> { + use liblzma::read::XzDecoder; use std::io::Read; - use xz2::read::XzDecoder; let mut decoder = XzDecoder::new(data); let mut buffer = vec![0u8; 8192]; // Decompress up to 8KB diff --git a/src/fls/oci/from_oci.rs b/src/fls/oci/from_oci.rs index 4ab101d..cc40492 100644 --- a/src/fls/oci/from_oci.rs +++ b/src/fls/oci/from_oci.rs @@ -13,7 +13,6 @@ use futures_util::StreamExt; use reqwest::StatusCode; use tokio::io::AsyncWriteExt; use tokio::sync::mpsc; -use xz2::read::XzDecoder; use crate::fls::byte_channel::{byte_bounded_channel, ByteBoundedReceiver, ByteBoundedSender}; use crate::fls::decompress::start_decompressor_for_compression; @@ -454,6 +453,7 @@ async fn stream_blob_to_tar_files( .into_owned_fd() .map_err(|e| format!("Failed to convert decompressor stdout: {}", e))? .into(); + let xz_memlimit_mb = options.common.xz_memlimit_mb; let writer_handle = tokio::task::spawn_blocking(move || -> Result, String> { let reader = std::io::BufReader::new(std_stdout); @@ -488,7 +488,14 @@ async fn stream_blob_to_tar_files( })?; } Compression::Xz => { - let mut decoder = XzDecoder::new(combined); + let mut decoder = + crate::fls::decompress::create_xz_decoder(combined, xz_memlimit_mb) + .map_err(|e| { + format!( + "Failed to create XZ decoder for {}: {}", + file_name, e + ) + })?; std::io::copy(&mut decoder, &mut file).map_err(|e| { format!("Failed to write {}: {}", output_path.display(), e) })?; @@ -1954,6 +1961,7 @@ pub async fn flash_from_oci( // Spawn blocking task: HTTP rx -> gzip -> tar -> tar tx let file_pattern = options.file_pattern.clone(); let debug = options.common.debug; + let xz_memlimit_mb = options.common.xz_memlimit_mb; let tar_extractor_handle = tokio::task::spawn_blocking(move || { extract_tar_archive_from_stream( http_rx, @@ -1962,6 +1970,7 @@ pub async fn flash_from_oci( compression, compression_type, debug, + xz_memlimit_mb, ) }); @@ -1996,6 +2005,7 @@ fn extract_tar_stream_impl( tar_tx: mpsc::Sender>, file_pattern: Option<&str>, debug: bool, + xz_memlimit_mb: u64, ) -> Result<(), String> { if debug { eprintln!("[DEBUG] Tar extractor starting"); @@ -2032,8 +2042,8 @@ fn extract_tar_stream_impl( if debug { eprintln!("[DEBUG] Auto-detected XZ compression from magic bytes - decompressing for tar extraction"); } - // Decompress XZ before tar extraction (like gzip) - let xz_decoder = XzDecoder::new(magic_reader); + let xz_decoder = crate::fls::decompress::create_xz_decoder(magic_reader, xz_memlimit_mb) + .map_err(|e| format!("Failed to create XZ decoder: {}", e))?; Box::new(xz_decoder) } else { if debug { @@ -2567,6 +2577,7 @@ fn extract_tar_archive_from_stream( compression: LayerCompression, compression_type: Compression, debug: bool, + xz_memlimit_mb: u64, ) -> Result<(), String> { let reader = ChannelReader::new_byte_bounded(http_rx); @@ -2611,7 +2622,13 @@ fn extract_tar_archive_from_stream( } }; - extract_tar_stream_impl(decompressed_reader, tar_tx, file_pattern, debug) + extract_tar_stream_impl( + decompressed_reader, + tar_tx, + file_pattern, + debug, + xz_memlimit_mb, + ) } #[cfg(test)] diff --git a/src/fls/options.rs b/src/fls/options.rs index 8e51cf2..695ea3a 100644 --- a/src/fls/options.rs +++ b/src/fls/options.rs @@ -3,6 +3,8 @@ use std::path::PathBuf; pub const DEFAULT_MAX_RETRIES: usize = 10; pub const DEFAULT_RETRY_DELAY_SECS: u64 = 2; +pub const DEFAULT_XZ_MEMLIMIT_MB: u64 = 256; + /// Common options shared between URL and OCI flash operations #[derive(Debug, Clone)] pub struct FlashOptions { @@ -16,6 +18,7 @@ pub struct FlashOptions { pub progress_interval_secs: f64, pub newline_progress: bool, pub show_memory: bool, + pub xz_memlimit_mb: u64, } impl Default for FlashOptions { @@ -31,6 +34,7 @@ impl Default for FlashOptions { progress_interval_secs: 0.5, newline_progress: false, show_memory: false, + xz_memlimit_mb: DEFAULT_XZ_MEMLIMIT_MB, } } } diff --git a/src/fls/stream_utils.rs b/src/fls/stream_utils.rs index e6807ab..0193cc5 100644 --- a/src/fls/stream_utils.rs +++ b/src/fls/stream_utils.rs @@ -1,36 +1,34 @@ use crate::fls::byte_channel::ByteBoundedReceiver; -/// Stream utilities for async/sync bridging and download handling -/// -/// Provides reusable components for streaming data between async HTTP -/// downloads and sync processing (like tar extraction or decompression). use bytes::Bytes; use std::io::Read; +use tokio::sync::mpsc; -/// Reader that pulls bytes from a byte-bounded channel -/// -/// This bridges async HTTP streaming with synchronous readers -/// like tar::Archive or flate2::GzDecoder. pub struct ChannelReader { rx: ByteBoundedReceiver, current: Option, offset: usize, + progress_tx: Option>, } impl ChannelReader { - /// Create a new ChannelReader from a byte-bounded receiver pub fn new_byte_bounded(rx: ByteBoundedReceiver) -> Self { Self { rx, current: None, offset: 0, + progress_tx: None, } } + + pub fn with_progress(mut self, tx: mpsc::UnboundedSender) -> Self { + self.progress_tx = Some(tx); + self + } } impl Read for ChannelReader { fn read(&mut self, buf: &mut [u8]) -> std::io::Result { loop { - // If we have current data, use it if let Some(ref data) = self.current { let remaining = &data[self.offset..]; if !remaining.is_empty() { @@ -41,14 +39,15 @@ impl Read for ChannelReader { } } - // Need more data - blocking receive match self.rx.blocking_recv() { Some(data) => { + if let Some(ref tx) = self.progress_tx { + let _ = tx.send(data.len() as u64); + } self.current = Some(data); self.offset = 0; } None => { - // Channel closed - EOF return Ok(0); } } diff --git a/src/main.rs b/src/main.rs index 14a4b62..a1b5f06 100644 --- a/src/main.rs +++ b/src/main.rs @@ -72,6 +72,9 @@ enum Commands { /// Show memory statistics in progress display #[arg(long)] show_memory: bool, + /// XZ decompression memory limit in MB (exceeds = single-thread fallback, then error) + #[arg(long, default_value = "256")] + xz_memlimit: u64, /// Registry username for OCI authentication #[arg(short = 'u', long, env = "FLS_REGISTRY_USERNAME")] username: Option, @@ -136,6 +139,7 @@ async fn main() { username, password, file_pattern, + xz_memlimit, } => { // Detect URL scheme to determine handler let is_oci = url.starts_with("oci://"); @@ -188,6 +192,7 @@ async fn main() { progress_interval_secs: progress_interval, newline_progress, show_memory, + xz_memlimit_mb: xz_memlimit, }, username, password, @@ -265,6 +270,7 @@ async fn main() { progress_interval_secs: progress_interval, newline_progress, show_memory, + xz_memlimit_mb: xz_memlimit, }, max_retries, retry_delay_secs: retry_delay, diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 5d62adb..b8813f1 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -1,8 +1,8 @@ // Shared test utilities use flate2::write::GzEncoder; use flate2::Compression; +use liblzma::write::XzEncoder; use std::io::Write; -use xz2::write::XzEncoder; /// Generate deterministic test data of a given size #[allow(dead_code)] From 266d6f4eee316adffff77e142be2c9e8594dc913 Mon Sep 17 00:00:00 2001 From: Benny Zlotnik Date: Sun, 26 Apr 2026 11:33:29 +0300 Subject: [PATCH 2/2] oci: add MT xz support Signed-off-by: Benny Zlotnik Assisted-by: claude-opus-4.6 --- src/fls/decompress.rs | 43 ++++---- src/fls/oci/from_oci.rs | 214 +++++----------------------------------- 2 files changed, 50 insertions(+), 207 deletions(-) diff --git a/src/fls/decompress.rs b/src/fls/decompress.rs index de9f82d..f8cc3b3 100644 --- a/src/fls/decompress.rs +++ b/src/fls/decompress.rs @@ -7,7 +7,7 @@ use tokio::io::AsyncReadExt; use tokio::process::{Child, Command}; use tokio::sync::mpsc; -fn mb_to_bytes(mb: u64) -> u64 { +pub(crate) fn mb_to_bytes(mb: u64) -> u64 { mb.saturating_mul(1024 * 1024) } @@ -25,6 +25,29 @@ pub(crate) fn create_xz_decoder( Ok(liblzma::read::XzDecoder::new_stream(reader, stream)) } +pub(crate) fn create_mt_xz_decoder( + reader: R, + xz_memlimit_mb: u64, +) -> Result, String> { + let num_threads = std::thread::available_parallelism() + .map(|n| n.get() as u32) + .unwrap_or(2); + let memlimit = mb_to_bytes(xz_memlimit_mb); + eprintln!( + "XZ decompression: {} threads, memory limit {}MB", + num_threads, xz_memlimit_mb + ); + let stream = liblzma::stream::MtStreamBuilder::new() + .threads(num_threads) + .memlimit_threading(memlimit) + .memlimit_stop(memlimit) + .decoder() + .map_err(|e| format!("Failed to create MT XZ decoder: {}", e))?; + Ok(Box::new(liblzma::read::XzDecoder::new_stream( + reader, stream, + ))) +} + /// Determines the appropriate decompression command based on URL extension fn get_decompressor_command(url: &str) -> &'static str { let extension = url.rsplit('.').next().unwrap_or("").to_lowercase(); @@ -130,23 +153,7 @@ pub(crate) fn start_inprocess_decompressor( ChannelReader::new_byte_bounded(buffer_rx).with_progress(consumed_progress_tx); let mut decoder: Box = match compression { - Compression::Xz => { - let num_threads = std::thread::available_parallelism() - .map(|n| n.get() as u32) - .unwrap_or(2); - let memlimit = mb_to_bytes(xz_memlimit_mb); - eprintln!( - "XZ decompression: {} threads, memory limit {}MB", - num_threads, xz_memlimit_mb - ); - let stream = liblzma::stream::MtStreamBuilder::new() - .threads(num_threads) - .memlimit_threading(memlimit) - .memlimit_stop(memlimit) - .decoder() - .map_err(|e| format!("Failed to create MT XZ decoder: {}", e))?; - Box::new(liblzma::read::XzDecoder::new_stream(channel_reader, stream)) - } + Compression::Xz => create_mt_xz_decoder(channel_reader, xz_memlimit_mb)?, Compression::Gzip => Box::new(flate2::read::GzDecoder::new(channel_reader)), Compression::None => Box::new(channel_reader), Compression::Zstd => { diff --git a/src/fls/oci/from_oci.rs b/src/fls/oci/from_oci.rs index cc40492..2b1dfcf 100644 --- a/src/fls/oci/from_oci.rs +++ b/src/fls/oci/from_oci.rs @@ -56,7 +56,7 @@ struct DownloadContext { struct RawDiskDownloadParams { http_tx: ByteBoundedSender, writer_handle: tokio::task::JoinHandle>, - external_decompressor: Option, + is_compressed: bool, decompressed_progress_rx: mpsc::UnboundedReceiver, raw_written_progress_rx: mpsc::UnboundedReceiver, } @@ -69,12 +69,6 @@ struct ProcessingHandles { tar_extractor_handle: tokio::task::JoinHandle>, } -/// Components returned by external decompressor pipeline setup -struct ExternalDecompressorPipeline { - writer_handle: tokio::task::JoinHandle>, - decompressor: tokio::process::Child, -} - /// Components returned by pipeline setup struct TarPipelineComponents { http_tx: ByteBoundedSender, @@ -1423,151 +1417,22 @@ async fn coordinate_download_and_processing( Ok(()) } -/// Setup external decompressor pipeline for XZ compression -async fn setup_external_decompressor_pipeline( - http_rx: ByteBoundedReceiver, - block_writer: AsyncBlockWriter, - decompressed_progress_tx: mpsc::UnboundedSender, - debug: bool, -) -> Result> { - // XZ: Use external xzcat process - let (mut decompressor, decompressor_name) = start_decompressor_process("disk.img.xz").await?; - - let decompressor_stdin = decompressor.stdin.take().unwrap(); - let decompressor_stdout = decompressor.stdout.take().unwrap(); - let decompressor_stderr = decompressor.stderr.take().unwrap(); - - let (error_tx, error_rx) = mpsc::unbounded_channel::(); - - // Spawn stderr reader - tokio::spawn(spawn_stderr_reader( - decompressor_stderr, - error_tx.clone(), - decompressor_name, - )); - - // Spawn error processor - tokio::spawn(process_error_messages(error_rx)); - - // Spawn blocking task: read from channel and write to xzcat stdin - // First, create a sync file handle from the async stdin - #[cfg(unix)] - let stdin_fd = { - use std::os::unix::io::{AsRawFd, FromRawFd}; - let raw_fd = decompressor_stdin.as_raw_fd(); - // Duplicate the fd so we can use it in blocking context - let dup_fd = unsafe { libc::dup(raw_fd) }; - if dup_fd == -1 { - return Err(std::io::Error::last_os_error().into()); - } - // SAFETY: dup_fd is a valid file descriptor (we checked above) - unsafe { std::fs::File::from_raw_fd(dup_fd) } - }; - #[cfg(not(unix))] - let stdin_fd: std::fs::File = { - return Err("XZ streaming decompression is not supported on non-unix platforms".into()); - }; - - // Drop the original async stdin (the dup'd fd still points to the pipe) - drop(decompressor_stdin); - - let stdin_writer_handle = { - tokio::task::spawn_blocking(move || { - use std::io::Write as _; - let reader = ChannelReader::new_byte_bounded(http_rx); - let mut reader = reader; - let mut stdin = stdin_fd; - let mut buffer = vec![0u8; 1024 * 1024]; // 1MB chunks - - loop { - match reader.read(&mut buffer) { - Ok(0) => break, // EOF - Ok(n) => { - if let Err(e) = stdin.write_all(&buffer[..n]) { - return Err(format!("Error writing to xzcat: {}", e)); - } - } - Err(e) => return Err(format!("Error reading stream: {}", e)), - } - } - - drop(stdin); - Ok::<(), String>(()) - }) - }; - - // Spawn task: xzcat stdout -> block writer with sparse detection - let writer = block_writer; - let progress_tx = decompressed_progress_tx; - let writer_handle = tokio::spawn(async move { - let mut stdout = decompressor_stdout; - let mut buffer = vec![0u8; 8 * 1024 * 1024]; // 8MB buffer - - // Auto-detect sparse image format from initial data - let mut detector = FormatDetector::new(); - let mut parser: Option = None; - let mut format_determined = false; - - loop { - let n = match tokio::io::AsyncReadExt::read(&mut stdout, &mut buffer).await { - Ok(0) => { - finalize_format_at_eof(&mut detector, format_determined, &writer, debug) - .await?; - break; - } - Ok(n) => n, - Err(e) => return Err(e), - }; - - let _ = progress_tx.send(n as u64); - - process_buffer_with_format_detection( - &buffer, - n, - &mut detector, - &mut parser, - &mut format_determined, - &writer, - debug, - ) - .await?; - } - - // Wait for stdin writer to finish and propagate any errors - stdin_writer_handle - .await - .map_err(|e| std::io::Error::other(format!("Stdin writer task failed: {}", e)))? - .map_err(|e| std::io::Error::other(format!("Stdin writer error: {}", e)))?; - - writer.close().await - }); - - Ok(ExternalDecompressorPipeline { - writer_handle, - decompressor, - }) -} - -/// Setup in-process decompression pipeline for Gzip or None compression async fn setup_inprocess_decompression_pipeline( http_rx: ByteBoundedReceiver, block_writer: AsyncBlockWriter, decompressed_progress_tx: mpsc::UnboundedSender, compression_type: Compression, debug: bool, + xz_memlimit_mb: u64, ) -> Result>, Box> { - // Gzip or None: decompress in-process and write directly to block writer let writer = block_writer; let progress_tx = decompressed_progress_tx; - // Create an async channel for decompressed data let (data_tx, mut data_rx) = mpsc::channel::>(16); - // Spawn blocking task: read, decompress, send to async channel let reader_handle = tokio::task::spawn_blocking(move || { let reader = ChannelReader::new_byte_bounded(http_rx); - // Apply in-process gzip decompression if needed let processed_reader: Box = match compression_type { Compression::Gzip => { if debug { @@ -1575,7 +1440,13 @@ async fn setup_inprocess_decompression_pipeline( } Box::new(flate2::read::GzDecoder::new(reader)) } - _ => Box::new(reader), + Compression::Xz => { + crate::fls::decompress::create_mt_xz_decoder(reader, xz_memlimit_mb)? + } + Compression::None => Box::new(reader), + Compression::Zstd => { + return Err("Zstd in-process decompression is not supported".to_string()); + } }; let mut reader = processed_reader; @@ -1657,7 +1528,7 @@ async fn coordinate_raw_disk_download( let mut progress = ProgressTracker::new(options.common.newline_progress, options.common.show_memory); progress.set_content_length(Some(layer_size)); - progress.set_is_compressed(params.external_decompressor.is_some()); + progress.set_is_compressed(params.is_compressed); progress.bytes_received = initial_buffer.len() as u64; let update_interval = Duration::from_secs_f64(options.common.progress_interval_secs); @@ -1772,24 +1643,6 @@ async fn coordinate_raw_disk_download( progress.decompress_duration = Some(elapsed); progress.write_duration = Some(elapsed); - // Wait for external decompressor process and check exit status - if let Some(mut decompressor) = params.external_decompressor { - match decompressor.wait().await { - Ok(status) => { - if !status.success() { - return Err(format!( - "Decompressor process failed with exit code: {:?}", - status.code() - ) - .into()); - } - } - Err(e) => { - return Err(format!("Failed to wait for decompressor process: {}", e).into()); - } - } - } - // Final progress update let _ = progress.update_progress(Some(layer_size), update_interval, true); @@ -2520,48 +2373,31 @@ async fn flash_raw_disk_image_directly( byte_bounded_channel::(max_buffer_bytes, buffer_capacity); let (decompressed_progress_tx, decompressed_progress_rx) = mpsc::unbounded_channel::(); - // For gzip and none, we can decompress in-process and write directly to block writer - // For XZ, we need the external xzcat process - let needs_external_decompressor = compression_type == Compression::Xz; + if compression_type == Compression::Zstd { + return Err("Zstd in-process decompression is not supported".into()); + } if options.common.debug { eprintln!( - "[DEBUG] Compression type: {:?}, using external decompressor: {}", - compression_type, needs_external_decompressor + "[DEBUG] Compression type: {:?}, using in-process decompressor", + compression_type ); } - // Spawn the processing pipeline based on compression type - let (writer_handle, external_decompressor) = if needs_external_decompressor { - // XZ: Use external xzcat process - let pipeline = setup_external_decompressor_pipeline( - http_rx, - block_writer, - decompressed_progress_tx, - options.common.debug, - ) - .await?; - - (pipeline.writer_handle, Some(pipeline.decompressor)) - } else { - // Gzip or None: decompress in-process and write directly to block writer - let handle = setup_inprocess_decompression_pipeline( - http_rx, - block_writer, - decompressed_progress_tx, - compression_type, - options.common.debug, - ) - .await?; - - (handle, None) - }; + let writer_handle = setup_inprocess_decompression_pipeline( + http_rx, + block_writer, + decompressed_progress_tx, + compression_type, + options.common.debug, + options.common.xz_memlimit_mb, + ) + .await?; - // Coordinate download and processing let params = RawDiskDownloadParams { http_tx, writer_handle, - external_decompressor, + is_compressed: compression_type != Compression::None, decompressed_progress_rx, raw_written_progress_rx, };